You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
3.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
===================================
A股自选股智能分析系统 - 分析历史存储单元测试
===================================
职责:
1. 验证分析历史保存逻辑
2. 验证上下文快照保存开关
"""
import os
import tempfile
import unittest
from src.config import Config
from src.storage import DatabaseManager, AnalysisHistory
from src.analyzer import AnalysisResult
class AnalysisHistoryTestCase(unittest.TestCase):
    """Storage tests for analysis history records.

    Each test runs against its own SQLite file inside a private temporary
    directory, so rows written by one test are never visible to another.
    """

    def setUp(self) -> None:
        """Create an isolated database for each test case.

        Every resource is paired with an ``addCleanup`` call immediately after
        it is acquired, so the process environment, the cached singletons and
        the temporary directory are all restored even if a later step of
        ``setUp`` (or the test itself) raises. Cleanups run in reverse
        registration order: singletons are reset first, then
        ``DATABASE_PATH`` is restored, and the temporary directory is removed
        last (after nothing should hold the database file any more).
        """
        self._temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._temp_dir.cleanup)
        self._db_path = os.path.join(self._temp_dir.name, "test_analysis_history.db")

        # Put the caller's DATABASE_PATH back afterwards so this test's
        # temporary path does not leak into later tests or other modules.
        self.addCleanup(
            self._restore_database_path, os.environ.get("DATABASE_PATH")
        )
        os.environ["DATABASE_PATH"] = self._db_path

        # Drop cached singletons so get_instance() below builds a fresh
        # manager (presumably reading DATABASE_PATH via Config — confirm),
        # and drop them again after the test so nothing keeps referring to
        # the soon-to-be-deleted database file.
        self.addCleanup(self._reset_singletons)
        self._reset_singletons()
        self.db = DatabaseManager.get_instance()

    @staticmethod
    def _reset_singletons() -> None:
        """Discard the cached ``Config`` and ``DatabaseManager`` instances."""
        Config._instance = None
        DatabaseManager.reset_instance()

    @staticmethod
    def _restore_database_path(previous: "str | None") -> None:
        """Restore ``DATABASE_PATH`` to *previous*, unsetting it if it was absent."""
        if previous is None:
            os.environ.pop("DATABASE_PATH", None)
        else:
            os.environ["DATABASE_PATH"] = previous

    def _build_result(self) -> AnalysisResult:
        """Build a minimal ``AnalysisResult`` for stock 600519 used by the tests."""
        return AnalysisResult(
            code="600519",
            name="贵州茅台",
            sentiment_score=78,
            trend_prediction="看多",
            operation_advice="持有",
            analysis_summary="基本面稳健,短期震荡",
        )

    def test_save_analysis_history_with_snapshot(self) -> None:
        """Saving with ``save_snapshot=True`` stores the context snapshot.

        Also checks that the sniper-point strings in the dashboard (some with
        surrounding text such as a currency suffix) are stored as the numeric
        values asserted below.
        """
        result = self._build_result()
        result.dashboard = {
            "battle_plan": {
                "sniper_points": {
                    "ideal_buy": "理想买入点125.5元",
                    "secondary_buy": "120",
                    "stop_loss": "止损位110元",
                    "take_profit": "目标位150.0元",
                }
            }
        }
        context_snapshot = {"enhanced_context": {"code": "600519"}}

        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_001",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot=context_snapshot,
            save_snapshot=True,
        )
        self.assertEqual(saved, 1)

        history = self.db.get_analysis_history(code="600519", days=7, limit=10)
        self.assertEqual(len(history), 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            # Explicit None check (rather than assertIsNotNone) also narrows
            # the type of ``row`` for static type checkers.
            if row is None:
                self.fail("未找到保存的历史记录")
            self.assertEqual(row.query_id, "query_001")
            self.assertIsNotNone(row.context_snapshot)
            self.assertEqual(row.ideal_buy, 125.5)
            self.assertEqual(row.secondary_buy, 120.0)
            self.assertEqual(row.stop_loss, 110.0)
            self.assertEqual(row.take_profit, 150.0)

    def test_save_analysis_history_without_snapshot(self) -> None:
        """With ``save_snapshot=False`` no ``context_snapshot`` is written."""
        result = self._build_result()
        saved = self.db.save_analysis_history(
            result=result,
            query_id="query_002",
            report_type="simple",
            news_content="新闻摘要",
            context_snapshot={"foo": "bar"},
            save_snapshot=False,
        )
        self.assertEqual(saved, 1)

        with self.db.get_session() as session:
            row = session.query(AnalysisHistory).first()
            if row is None:
                self.fail("未找到保存的历史记录")
            # Make sure we are inspecting the row this test just wrote.
            self.assertEqual(row.query_id, "query_002")
            self.assertIsNone(row.context_snapshot)
if __name__ == "__main__":
    # Entry point when the file is executed as a script rather than
    # being collected by a test runner.
    unittest.main()