You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
259 lines
11 KiB
259 lines
11 KiB
|
2 months ago
|
# -*- coding: utf-8 -*-
|
||
|
|
"""Integration tests for backtest service and repository.
|
||
|
|
|
||
|
|
These tests run against a temporary SQLite DB (same approach as other tests)
|
||
|
|
and validate idempotency/force semantics, result field correctness,
|
||
|
|
summary creation, and query methods.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import tempfile
|
||
|
|
import unittest
|
||
|
|
from datetime import date, datetime
|
||
|
|
|
||
|
|
from src.config import Config
|
||
|
|
from src.core.backtest_engine import OVERALL_SENTINEL_CODE
|
||
|
|
from src.services.backtest_service import BacktestService
|
||
|
|
from src.storage import AnalysisHistory, BacktestResult, BacktestSummary, DatabaseManager, StockDaily
|
||
|
|
|
||
|
|
|
||
|
|
class BacktestServiceTestCase(unittest.TestCase):
|
||
|
|
def setUp(self) -> None:
|
||
|
|
self._temp_dir = tempfile.TemporaryDirectory()
|
||
|
|
self._db_path = os.path.join(self._temp_dir.name, "test_backtest_service.db")
|
||
|
|
os.environ["DATABASE_PATH"] = self._db_path
|
||
|
|
os.environ["BACKTEST_EVAL_WINDOW_DAYS"] = "3"
|
||
|
|
|
||
|
|
Config._instance = None
|
||
|
|
DatabaseManager.reset_instance()
|
||
|
|
self.db = DatabaseManager.get_instance()
|
||
|
|
|
||
|
|
# Ensure analysis is old enough for default min_age_days=14
|
||
|
|
old_created_at = datetime(2024, 1, 1, 0, 0, 0)
|
||
|
|
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
session.add(
|
||
|
|
AnalysisHistory(
|
||
|
|
query_id="q1",
|
||
|
|
code="600519",
|
||
|
|
name="贵州茅台",
|
||
|
|
report_type="simple",
|
||
|
|
sentiment_score=80,
|
||
|
|
operation_advice="买入",
|
||
|
|
trend_prediction="看多",
|
||
|
|
analysis_summary="test",
|
||
|
|
stop_loss=95.0,
|
||
|
|
take_profit=110.0,
|
||
|
|
created_at=old_created_at,
|
||
|
|
context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
# Analysis day close
|
||
|
|
session.add(
|
||
|
|
StockDaily(
|
||
|
|
code="600519",
|
||
|
|
date=date(2024, 1, 1),
|
||
|
|
open=100.0,
|
||
|
|
high=101.0,
|
||
|
|
low=99.0,
|
||
|
|
close=100.0,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
# Forward bars (3 days) that hit take-profit on day1
|
||
|
|
session.add_all(
|
||
|
|
[
|
||
|
|
StockDaily(code="600519", date=date(2024, 1, 2), high=111.0, low=100.0, close=105.0),
|
||
|
|
StockDaily(code="600519", date=date(2024, 1, 3), high=108.0, low=103.0, close=106.0),
|
||
|
|
StockDaily(code="600519", date=date(2024, 1, 4), high=109.0, low=104.0, close=107.0),
|
||
|
|
]
|
||
|
|
)
|
||
|
|
session.commit()
|
||
|
|
|
||
|
|
def tearDown(self) -> None:
|
||
|
|
DatabaseManager.reset_instance()
|
||
|
|
self._temp_dir.cleanup()
|
||
|
|
|
||
|
|
def _count_results(self) -> int:
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
return session.query(BacktestResult).count()
|
||
|
|
|
||
|
|
def test_force_semantics(self) -> None:
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
|
||
|
|
stats1 = service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
self.assertEqual(stats1["saved"], 1)
|
||
|
|
self.assertEqual(self._count_results(), 1)
|
||
|
|
|
||
|
|
# Non-force should be idempotent
|
||
|
|
stats2 = service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
self.assertEqual(stats2["saved"], 0)
|
||
|
|
self.assertEqual(self._count_results(), 1)
|
||
|
|
|
||
|
|
# Force should replace existing result without unique constraint errors
|
||
|
|
stats3 = service.run_backtest(code="600519", force=True, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
self.assertEqual(stats3["saved"], 1)
|
||
|
|
self.assertEqual(self._count_results(), 1)
|
||
|
|
|
||
|
|
def _run_and_get_result(self) -> BacktestResult:
|
||
|
|
"""Helper: run backtest and return the single BacktestResult row."""
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
return session.query(BacktestResult).one()
|
||
|
|
|
||
|
|
def test_result_fields_correct(self) -> None:
|
||
|
|
"""Verify BacktestResult row contains correct evaluation values."""
|
||
|
|
result = self._run_and_get_result()
|
||
|
|
|
||
|
|
self.assertEqual(result.eval_status, "completed")
|
||
|
|
self.assertEqual(result.code, "600519")
|
||
|
|
self.assertEqual(result.analysis_date, date(2024, 1, 1))
|
||
|
|
self.assertEqual(result.operation_advice, "买入")
|
||
|
|
self.assertEqual(result.position_recommendation, "long")
|
||
|
|
self.assertEqual(result.direction_expected, "up")
|
||
|
|
|
||
|
|
# Prices
|
||
|
|
self.assertAlmostEqual(result.start_price, 100.0)
|
||
|
|
self.assertAlmostEqual(result.end_close, 107.0)
|
||
|
|
self.assertAlmostEqual(result.stock_return_pct, 7.0)
|
||
|
|
|
||
|
|
# Direction & outcome
|
||
|
|
self.assertEqual(result.outcome, "win")
|
||
|
|
self.assertTrue(result.direction_correct)
|
||
|
|
|
||
|
|
# Target hits -- day2 high=111 >= take_profit=110
|
||
|
|
self.assertTrue(result.hit_take_profit)
|
||
|
|
self.assertFalse(result.hit_stop_loss)
|
||
|
|
self.assertEqual(result.first_hit, "take_profit")
|
||
|
|
self.assertEqual(result.first_hit_trading_days, 1)
|
||
|
|
self.assertEqual(result.first_hit_date, date(2024, 1, 2))
|
||
|
|
|
||
|
|
# Simulated execution
|
||
|
|
self.assertAlmostEqual(result.simulated_entry_price, 100.0)
|
||
|
|
self.assertAlmostEqual(result.simulated_exit_price, 110.0)
|
||
|
|
self.assertEqual(result.simulated_exit_reason, "take_profit")
|
||
|
|
self.assertAlmostEqual(result.simulated_return_pct, 10.0)
|
||
|
|
|
||
|
|
def test_summaries_created_after_run(self) -> None:
|
||
|
|
"""Verify both overall and per-stock BacktestSummary rows are created."""
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
# Overall summary uses sentinel code
|
||
|
|
overall = session.query(BacktestSummary).filter(
|
||
|
|
BacktestSummary.scope == "overall",
|
||
|
|
BacktestSummary.code == OVERALL_SENTINEL_CODE,
|
||
|
|
).first()
|
||
|
|
self.assertIsNotNone(overall)
|
||
|
|
self.assertEqual(overall.total_evaluations, 1)
|
||
|
|
self.assertEqual(overall.completed_count, 1)
|
||
|
|
self.assertEqual(overall.win_count, 1)
|
||
|
|
self.assertEqual(overall.loss_count, 0)
|
||
|
|
self.assertAlmostEqual(overall.win_rate_pct, 100.0)
|
||
|
|
|
||
|
|
# Stock-level summary
|
||
|
|
stock = session.query(BacktestSummary).filter(
|
||
|
|
BacktestSummary.scope == "stock",
|
||
|
|
BacktestSummary.code == "600519",
|
||
|
|
).first()
|
||
|
|
self.assertIsNotNone(stock)
|
||
|
|
self.assertEqual(stock.total_evaluations, 1)
|
||
|
|
self.assertEqual(stock.completed_count, 1)
|
||
|
|
self.assertEqual(stock.win_count, 1)
|
||
|
|
|
||
|
|
def test_get_summary_overall_returns_sentinel_as_none(self) -> None:
|
||
|
|
"""Verify get_summary translates __overall__ sentinel back to None."""
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
|
||
|
|
summary = service.get_summary(scope="overall", code=None)
|
||
|
|
self.assertIsNotNone(summary)
|
||
|
|
self.assertIsNone(summary["code"])
|
||
|
|
self.assertEqual(summary["scope"], "overall")
|
||
|
|
self.assertEqual(summary["win_count"], 1)
|
||
|
|
|
||
|
|
def test_get_recent_evaluations(self) -> None:
|
||
|
|
"""Verify get_recent_evaluations returns correct paginated results."""
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
service.run_backtest(code="600519", force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
|
||
|
|
data = service.get_recent_evaluations(code="600519", limit=10, page=1)
|
||
|
|
self.assertEqual(data["total"], 1)
|
||
|
|
self.assertEqual(data["page"], 1)
|
||
|
|
self.assertEqual(data["limit"], 10)
|
||
|
|
self.assertEqual(len(data["items"]), 1)
|
||
|
|
|
||
|
|
item = data["items"][0]
|
||
|
|
self.assertEqual(item["code"], "600519")
|
||
|
|
self.assertEqual(item["outcome"], "win")
|
||
|
|
self.assertEqual(item["direction_expected"], "up")
|
||
|
|
self.assertTrue(item["direction_correct"])
|
||
|
|
|
||
|
|
def test_multi_stock_summaries(self) -> None:
|
||
|
|
"""Verify separate summaries for multiple stocks + correct overall aggregate."""
|
||
|
|
old_created_at = datetime(2024, 1, 1, 0, 0, 0)
|
||
|
|
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
# Second stock with sell advice -- price drops (win for cash/down)
|
||
|
|
session.add(
|
||
|
|
AnalysisHistory(
|
||
|
|
query_id="q2",
|
||
|
|
code="000001",
|
||
|
|
name="平安银行",
|
||
|
|
report_type="simple",
|
||
|
|
sentiment_score=30,
|
||
|
|
operation_advice="卖出",
|
||
|
|
trend_prediction="看空",
|
||
|
|
analysis_summary="test2",
|
||
|
|
stop_loss=None,
|
||
|
|
take_profit=None,
|
||
|
|
created_at=old_created_at,
|
||
|
|
context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
|
||
|
|
)
|
||
|
|
)
|
||
|
|
session.add(
|
||
|
|
StockDaily(code="000001", date=date(2024, 1, 1), open=10.0, high=10.2, low=9.8, close=10.0)
|
||
|
|
)
|
||
|
|
session.add_all([
|
||
|
|
StockDaily(code="000001", date=date(2024, 1, 2), high=10.0, low=9.5, close=9.6),
|
||
|
|
StockDaily(code="000001", date=date(2024, 1, 3), high=9.7, low=9.3, close=9.4),
|
||
|
|
StockDaily(code="000001", date=date(2024, 1, 4), high=9.5, low=9.0, close=9.1),
|
||
|
|
])
|
||
|
|
session.commit()
|
||
|
|
|
||
|
|
service = BacktestService(self.db)
|
||
|
|
stats = service.run_backtest(code=None, force=False, eval_window_days=3, min_age_days=0, limit=10)
|
||
|
|
self.assertEqual(stats["saved"], 2)
|
||
|
|
self.assertEqual(stats["completed"], 2)
|
||
|
|
|
||
|
|
with self.db.get_session() as session:
|
||
|
|
# Each stock has its own summary
|
||
|
|
s1 = session.query(BacktestSummary).filter(
|
||
|
|
BacktestSummary.scope == "stock", BacktestSummary.code == "600519"
|
||
|
|
).first()
|
||
|
|
s2 = session.query(BacktestSummary).filter(
|
||
|
|
BacktestSummary.scope == "stock", BacktestSummary.code == "000001"
|
||
|
|
).first()
|
||
|
|
self.assertIsNotNone(s1)
|
||
|
|
self.assertIsNotNone(s2)
|
||
|
|
self.assertEqual(s1.win_count, 1)
|
||
|
|
self.assertEqual(s2.win_count, 1)
|
||
|
|
|
||
|
|
# Overall aggregates both
|
||
|
|
overall = session.query(BacktestSummary).filter(
|
||
|
|
BacktestSummary.scope == "overall",
|
||
|
|
BacktestSummary.code == OVERALL_SENTINEL_CODE,
|
||
|
|
).first()
|
||
|
|
self.assertIsNotNone(overall)
|
||
|
|
self.assertEqual(overall.total_evaluations, 2)
|
||
|
|
self.assertEqual(overall.completed_count, 2)
|
||
|
|
self.assertEqual(overall.win_count, 2)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
unittest.main()
|
||
|
|
|