You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

114 lines
3.6 KiB

# -*- coding: utf-8 -*-
"""
===================================
A股自选股智能分析系统 - 分析历史存储单元测试
===================================
职责
1. 验证分析历史保存逻辑
2. 验证上下文快照保存开关
"""
import os
import tempfile
import unittest
from src.config import Config
from src.storage import DatabaseManager, AnalysisHistory
from src.analyzer import AnalysisResult
class AnalysisHistoryTestCase(unittest.TestCase):
    """Unit tests for persisting analysis history records."""

    def setUp(self) -> None:
        """Create an isolated SQLite database for each test case.

        Every piece of global state touched here (the temp directory, the
        ``DATABASE_PATH`` environment variable and the cached ``Config``
        singleton) is undone through ``addCleanup``. Cleanups run even when
        ``setUp`` itself raises (``tearDown`` does not), and they run after
        ``tearDown``, i.e. after the database manager has been reset.
        """
        self._temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._temp_dir.cleanup)
        self._db_path = os.path.join(self._temp_dir.name, "test_analysis_history.db")

        # Point the app at the temp database, and restore the caller's value
        # afterwards so DATABASE_PATH does not leak into later tests.
        self.addCleanup(
            self._restore_env, "DATABASE_PATH", os.environ.get("DATABASE_PATH")
        )
        os.environ["DATABASE_PATH"] = self._db_path

        # Drop the cached Config so it is rebuilt from the patched environment
        # (presumably Config reads DATABASE_PATH when constructed - confirm), and
        # drop it again afterwards so later tests do not reuse a config that
        # points at the deleted temp database.
        Config._instance = None
        self.addCleanup(setattr, Config, "_instance", None)

        DatabaseManager.reset_instance()
        self.db = DatabaseManager.get_instance()

    def tearDown(self) -> None:
        """Release the database manager before the temp directory is removed."""
        DatabaseManager.reset_instance()

    @staticmethod
    def _restore_env(key: str, value: "str | None") -> None:
        """Set ``os.environ[key]`` back to *value*, or remove it if *value* is None."""
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value

    def _build_result(self) -> AnalysisResult:
        """Build a minimal analysis result for stock 600519."""
        return AnalysisResult(
            code="600519",
            name="贵州茅台",
            sentiment_score=78,
            trend_prediction="看多",
            operation_advice="持有",
            analysis_summary="基本面稳健,短期震荡",
        )

    def test_save_analysis_history_with_snapshot(self) -> None:
        """Saving with save_snapshot=True stores the context snapshot and prices."""
        result = self._build_result()
        # The sniper-point values mix text labels/units with numbers; the
        # assertions below expect the storage layer to extract the numeric price.
        result.dashboard = {
            "battle_plan": {
                "sniper_points": {
                    "ideal_buy": "理想买入点125.5元",
                    "secondary_buy": "120",
                    "stop_loss": "止损位110元",
                    "take_profit": "目标位150.0元",
                }
            }
        }
        context_snapshot = {"enhanced_context": {"code": "600519"}}

        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_001",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot=context_snapshot,
            save_snapshot=True,
        )
        self.assertEqual(saved, 1)

        history = self.db.get_analysis_history(code="600519", days=7, limit=10)
        self.assertEqual(len(history), 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            # Explicit None check (rather than assertIsNotNone) so type
            # checkers narrow `row` for the attribute accesses below.
            if row is None:
                self.fail("未找到保存的历史记录")
            self.assertEqual(row.query_id, "query_001")
            self.assertIsNotNone(row.context_snapshot)
            self.assertEqual(row.ideal_buy, 125.5)
            self.assertEqual(row.secondary_buy, 120.0)
            self.assertEqual(row.stop_loss, 110.0)
            self.assertEqual(row.take_profit, 150.0)

    def test_save_analysis_history_without_snapshot(self) -> None:
        """Saving with save_snapshot=False leaves context_snapshot empty."""
        result = self._build_result()

        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_002",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot={"foo": "bar"},
            save_snapshot=False,
        )
        self.assertEqual(saved, 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            if row is None:
                self.fail("未找到保存的历史记录")
            self.assertIsNone(row.context_snapshot)
if __name__ == "__main__":
    # Entry point when the file is executed directly rather than via a test runner;
    # discovers and runs the TestCase classes defined in this module.
    unittest.main(module="__main__")