You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

259 lines
11 KiB

# -*- coding: utf-8 -*-
"""Integration tests for backtest service and repository.
These tests run against a temporary SQLite DB (same approach as other tests)
and validate idempotency/force semantics, result field correctness,
summary creation, and query methods.
"""
import os
import tempfile
import unittest
from datetime import date, datetime
from src.config import Config
from src.core.backtest_engine import OVERALL_SENTINEL_CODE
from src.services.backtest_service import BacktestService
from src.storage import AnalysisHistory, BacktestResult, BacktestSummary, DatabaseManager, StockDaily
class BacktestServiceTestCase(unittest.TestCase):
    """Integration tests for ``BacktestService`` against a temporary SQLite DB.

    ``setUp`` seeds one ``AnalysisHistory`` row for 600519 (buy advice,
    stop_loss=95, take_profit=110, dated 2024-01-01) together with four
    ``StockDaily`` bars: the analysis day and three forward days.  Each test
    then runs the backtest with ``eval_window_days=3`` and ``min_age_days=0``
    and inspects what was persisted or returned.
    """

    def setUp(self) -> None:
        """Create an isolated database and seed the single-stock fixture."""
        self._temp_dir = tempfile.TemporaryDirectory()
        self._db_path = os.path.join(self._temp_dir.name, "test_backtest_service.db")

        env_overrides = {
            "DATABASE_PATH": self._db_path,
            "BACKTEST_EVAL_WINDOW_DAYS": "3",
        }
        # Remember the previous values (None == variable was unset) so that
        # tearDown can restore them and the overrides do not leak into
        # whatever test runs next in this process.
        self._saved_env = {key: os.environ.get(key) for key in env_overrides}
        os.environ.update(env_overrides)

        # Clear the cached Config (presumably a singleton -- confirm in
        # src.config) and the DatabaseManager instance so both are rebuilt
        # from the environment set above.
        Config._instance = None
        DatabaseManager.reset_instance()
        self.db = DatabaseManager.get_instance()

        # Ensure analysis is old enough for default min_age_days=14
        old_created_at = datetime(2024, 1, 1, 0, 0, 0)

        with self.db.get_session() as session:
            session.add(
                AnalysisHistory(
                    query_id="q1",
                    code="600519",
                    name="贵州茅台",
                    report_type="simple",
                    sentiment_score=80,
                    operation_advice="买入",
                    trend_prediction="看多",
                    analysis_summary="test",
                    stop_loss=95.0,
                    take_profit=110.0,
                    created_at=old_created_at,
                    context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
                )
            )

            # Analysis day close
            session.add(
                StockDaily(
                    code="600519",
                    date=date(2024, 1, 1),
                    open=100.0,
                    high=101.0,
                    low=99.0,
                    close=100.0,
                )
            )

            # Forward bars (3 days) that hit take-profit on day1
            session.add_all(
                [
                    StockDaily(code="600519", date=date(2024, 1, 2), high=111.0, low=100.0, close=105.0),
                    StockDaily(code="600519", date=date(2024, 1, 3), high=108.0, low=103.0, close=106.0),
                    StockDaily(code="600519", date=date(2024, 1, 4), high=109.0, low=104.0, close=107.0),
                ]
            )
            session.commit()

    def tearDown(self) -> None:
        """Dispose of the database and undo the environment changes from setUp."""
        try:
            DatabaseManager.reset_instance()
            self._temp_dir.cleanup()
        finally:
            # Restore the environment even if cleanup above raised.
            for key, previous in self._saved_env.items():
                if previous is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous
            # Drop the Config built from the overridden environment so it is
            # not reused once DATABASE_PATH points somewhere else again.
            Config._instance = None

    def _count_results(self) -> int:
        """Return the number of ``BacktestResult`` rows currently stored."""
        with self.db.get_session() as session:
            return session.query(BacktestResult).count()

    def test_force_semantics(self) -> None:
        """Non-force reruns save nothing; force reruns replace the existing row."""
        service = BacktestService(self.db)

        stats1 = service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
        self.assertEqual(stats1["saved"], 1)
        self.assertEqual(self._count_results(), 1)

        # Non-force should be idempotent
        stats2 = service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
        self.assertEqual(stats2["saved"], 0)
        self.assertEqual(self._count_results(), 1)

        # Force should replace existing result without unique constraint errors
        stats3 = service.run_backtest(code="600519", force=True, eval_window_days=3, min_age_days=0, limit=10)
        self.assertEqual(stats3["saved"], 1)
        self.assertEqual(self._count_results(), 1)

    def _run_and_get_result(self) -> BacktestResult:
        """Helper: run backtest and return the single BacktestResult row."""
        service = BacktestService(self.db)
        service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
        with self.db.get_session() as session:
            # .one() doubles as an assertion that exactly one row exists.
            return session.query(BacktestResult).one()

    def test_result_fields_correct(self) -> None:
        """Verify BacktestResult row contains correct evaluation values."""
        result = self._run_and_get_result()

        self.assertEqual(result.eval_status, "completed")
        self.assertEqual(result.code, "600519")
        self.assertEqual(result.analysis_date, date(2024, 1, 1))
        self.assertEqual(result.operation_advice, "买入")
        self.assertEqual(result.position_recommendation, "long")
        self.assertEqual(result.direction_expected, "up")

        # Prices: start is the 2024-01-01 close, end is the 2024-01-04 close
        self.assertAlmostEqual(result.start_price, 100.0)
        self.assertAlmostEqual(result.end_close, 107.0)
        self.assertAlmostEqual(result.stock_return_pct, 7.0)

        # Direction & outcome
        self.assertEqual(result.outcome, "win")
        self.assertTrue(result.direction_correct)

        # Target hits -- first forward bar (2024-01-02) high=111 >= take_profit=110
        self.assertTrue(result.hit_take_profit)
        self.assertFalse(result.hit_stop_loss)
        self.assertEqual(result.first_hit, "take_profit")
        self.assertEqual(result.first_hit_trading_days, 1)
        self.assertEqual(result.first_hit_date, date(2024, 1, 2))

        # Simulated execution
        self.assertAlmostEqual(result.simulated_entry_price, 100.0)
        self.assertAlmostEqual(result.simulated_exit_price, 110.0)
        self.assertEqual(result.simulated_exit_reason, "take_profit")
        self.assertAlmostEqual(result.simulated_return_pct, 10.0)

    def test_summaries_created_after_run(self) -> None:
        """Verify both overall and per-stock BacktestSummary rows are created."""
        service = BacktestService(self.db)
        service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)

        with self.db.get_session() as session:
            # Overall summary uses sentinel code
            overall = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "overall",
                BacktestSummary.code == OVERALL_SENTINEL_CODE,
            ).first()
            self.assertIsNotNone(overall)
            self.assertEqual(overall.total_evaluations, 1)
            self.assertEqual(overall.completed_count, 1)
            self.assertEqual(overall.win_count, 1)
            self.assertEqual(overall.loss_count, 0)
            self.assertAlmostEqual(overall.win_rate_pct, 100.0)

            # Stock-level summary
            stock = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock",
                BacktestSummary.code == "600519",
            ).first()
            self.assertIsNotNone(stock)
            self.assertEqual(stock.total_evaluations, 1)
            self.assertEqual(stock.completed_count, 1)
            self.assertEqual(stock.win_count, 1)

    def test_get_summary_overall_returns_sentinel_as_none(self) -> None:
        """Verify get_summary reports the overall summary's code as None."""
        service = BacktestService(self.db)
        service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)

        summary = service.get_summary(scope="overall", code=None)
        self.assertIsNotNone(summary)
        self.assertIsNone(summary["code"])
        self.assertEqual(summary["scope"], "overall")
        self.assertEqual(summary["win_count"], 1)

    def test_get_recent_evaluations(self) -> None:
        """Verify get_recent_evaluations returns correct paginated results."""
        service = BacktestService(self.db)
        service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)

        data = service.get_recent_evaluations(code="600519", limit=10, page=1)
        self.assertEqual(data["total"], 1)
        self.assertEqual(data["page"], 1)
        self.assertEqual(data["limit"], 10)
        self.assertEqual(len(data["items"]), 1)

        item = data["items"][0]
        self.assertEqual(item["code"], "600519")
        self.assertEqual(item["outcome"], "win")
        self.assertEqual(item["direction_expected"], "up")
        self.assertTrue(item["direction_correct"])

    def test_multi_stock_summaries(self) -> None:
        """Verify separate summaries for multiple stocks + correct overall aggregate."""
        old_created_at = datetime(2024, 1, 1, 0, 0, 0)

        with self.db.get_session() as session:
            # Second stock with sell advice -- price drops (win for cash/down)
            session.add(
                AnalysisHistory(
                    query_id="q2",
                    code="000001",
                    name="平安银行",
                    report_type="simple",
                    sentiment_score=30,
                    operation_advice="卖出",
                    trend_prediction="看空",
                    analysis_summary="test2",
                    stop_loss=None,
                    take_profit=None,
                    created_at=old_created_at,
                    context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
                )
            )
            session.add(
                StockDaily(code="000001", date=date(2024, 1, 1), open=10.0, high=10.2, low=9.8, close=10.0)
            )
            session.add_all([
                StockDaily(code="000001", date=date(2024, 1, 2), high=10.0, low=9.5, close=9.6),
                StockDaily(code="000001", date=date(2024, 1, 3), high=9.7, low=9.3, close=9.4),
                StockDaily(code="000001", date=date(2024, 1, 4), high=9.5, low=9.0, close=9.1),
            ])
            session.commit()

        service = BacktestService(self.db)
        # code=None: evaluate every stored analysis, not just one stock
        stats = service.run_backtest(code=None, force=False, eval_window_days=3, min_age_days=0, limit=10)
        self.assertEqual(stats["saved"], 2)
        self.assertEqual(stats["completed"], 2)

        with self.db.get_session() as session:
            # Each stock has its own summary
            s1 = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock", BacktestSummary.code == "600519"
            ).first()
            s2 = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock", BacktestSummary.code == "000001"
            ).first()
            self.assertIsNotNone(s1)
            self.assertIsNotNone(s2)
            self.assertEqual(s1.win_count, 1)
            self.assertEqual(s2.win_count, 1)

            # Overall aggregates both
            overall = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "overall",
                BacktestSummary.code == OVERALL_SENTINEL_CODE,
            ).first()
            self.assertIsNotNone(overall)
            self.assertEqual(overall.total_evaluations, 2)
            self.assertEqual(overall.completed_count, 2)
            self.assertEqual(overall.win_count, 2)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()