|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
===================================
|
|
|
A股自选股智能分析系统 - 分析历史存储单元测试
|
|
|
===================================
|
|
|
|
|
|
职责:
|
|
|
1. 验证分析历史保存逻辑
|
|
|
2. 验证上下文快照保存开关
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import tempfile
|
|
|
import unittest
|
|
|
|
|
|
from src.config import Config
|
|
|
from src.storage import DatabaseManager, AnalysisHistory
|
|
|
from src.analyzer import AnalysisResult
|
|
|
|
|
|
|
|
|
class AnalysisHistoryTestCase(unittest.TestCase):
    """Tests for persisting analysis history through ``DatabaseManager``.

    Every test runs against its own database file inside a fresh temporary
    directory. The ``DATABASE_PATH`` environment variable is redirected for
    the duration of a test and put back afterwards, so the process
    environment is left the way it was found.
    """

    def setUp(self) -> None:
        """Create an isolated database for the test about to run."""
        self._temp_dir = tempfile.TemporaryDirectory()
        self._db_path = os.path.join(self._temp_dir.name, "test_analysis_history.db")

        # Remember the previous value (None when unset) so tearDown can
        # restore it instead of leaving a path to a deleted directory behind.
        self._original_db_path = os.environ.get("DATABASE_PATH")
        os.environ["DATABASE_PATH"] = self._db_path

        # Drop the cached singletons before asking for a new manager --
        # presumably so they are rebuilt from the updated environment.
        Config._instance = None
        DatabaseManager.reset_instance()
        self.db = DatabaseManager.get_instance()

    def tearDown(self) -> None:
        """Release the database, delete the temp directory, restore the environment."""
        try:
            # Reset the manager before removing the directory that holds
            # its database file.
            DatabaseManager.reset_instance()
            self._temp_dir.cleanup()
        finally:
            # Restore the environment even if the cleanup above raised.
            if self._original_db_path is None:
                os.environ.pop("DATABASE_PATH", None)
            else:
                os.environ["DATABASE_PATH"] = self._original_db_path
            # Clear the Config cache again so later code does not reuse an
            # instance created while DATABASE_PATH pointed at the temp file.
            Config._instance = None

    def _build_result(self) -> AnalysisResult:
        """Return a fixed ``AnalysisResult`` for stock code 600519."""
        return AnalysisResult(
            code="600519",
            name="贵州茅台",
            sentiment_score=78,
            trend_prediction="看多",
            operation_advice="持有",
            analysis_summary="基本面稳健,短期震荡",
        )

    def test_save_analysis_history_with_snapshot(self) -> None:
        """Saving with ``save_snapshot=True`` stores the snapshot and price levels."""
        result = self._build_result()
        # Sniper points are given as free text and as a bare number; the
        # assertions below expect each one to be stored as a float.
        result.dashboard = {
            "battle_plan": {
                "sniper_points": {
                    "ideal_buy": "理想买入点:125.5元",
                    "secondary_buy": "120",
                    "stop_loss": "止损位:110元",
                    "take_profit": "目标位:150.0元",
                }
            }
        }
        context_snapshot = {"enhanced_context": {"code": "600519"}}

        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_001",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot=context_snapshot,
            save_snapshot=True,
        )

        self.assertEqual(saved, 1)

        history = self.db.get_analysis_history(code="600519", days=7, limit=10)
        self.assertEqual(len(history), 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            self.assertIsNotNone(row, "未找到保存的历史记录")
            self.assertEqual(row.query_id, "query_001")
            self.assertIsNotNone(row.context_snapshot)
            self.assertEqual(row.ideal_buy, 125.5)
            self.assertEqual(row.secondary_buy, 120.0)
            self.assertEqual(row.stop_loss, 110.0)
            self.assertEqual(row.take_profit, 150.0)

    def test_save_analysis_history_without_snapshot(self) -> None:
        """Saving with ``save_snapshot=False`` leaves ``context_snapshot`` empty."""
        result = self._build_result()

        # A snapshot is passed on purpose: the flag, not the absence of
        # data, must be what keeps it out of the stored row.
        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_002",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot={"foo": "bar"},
            save_snapshot=False,
        )

        self.assertEqual(saved, 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            self.assertIsNotNone(row, "未找到保存的历史记录")
            self.assertEqual(row.query_id, "query_002")
            self.assertIsNone(row.context_snapshot)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
|