You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

161 lines
5.6 KiB

# -*- coding: utf-8 -*-
"""Backtest endpoints."""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api.deps import get_database_manager
from api.v1.schemas.backtest import (
BacktestRunRequest,
BacktestRunResponse,
BacktestResultItem,
BacktestResultsResponse,
PerformanceMetrics,
)
from api.v1.schemas.common import ErrorResponse
from src.services.backtest_service import BacktestService
from src.storage import DatabaseManager
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
# Router that the backtest endpoints below register themselves on; presumably
# mounted under an API prefix by the application — TODO confirm at mount site.
router: APIRouter = APIRouter()
@router.post(
    "/run",
    response_model=BacktestRunResponse,
    responses={
        200: {"description": "回测执行完成"},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="触发回测",
    description="对历史分析记录进行回测评估,并写入 backtest_results/backtest_summaries",
)
def run_backtest(
    request: BacktestRunRequest,
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> BacktestRunResponse:
    """Run a backtest over stored analysis records and return its statistics.

    (The OpenAPI description comes from ``description=`` above, not from
    this docstring.)

    Args:
        request: Backtest options; ``code``, ``force``, ``eval_window_days``,
            ``min_age_days`` and ``limit`` are forwarded unchanged to
            ``BacktestService.run_backtest``.
        db_manager: Injected database manager used to construct the service.

    Returns:
        A ``BacktestRunResponse`` built by unpacking the mapping returned by
        the service.

    Raises:
        HTTPException: Any ``HTTPException`` raised while running is passed
            through unchanged; every other exception is logged and turned
            into a 500 with an ``internal_error`` detail.
    """
    try:
        service = BacktestService(db_manager)
        stats = service.run_backtest(
            code=request.code,
            force=request.force,
            eval_window_days=request.eval_window_days,
            min_age_days=request.min_age_days,
            limit=request.limit,
        )
        return BacktestRunResponse(**stats)
    except HTTPException:
        # Deliberate HTTP errors keep their status code instead of being
        # re-wrapped as 500 (same policy as the /performance handlers).
        raise
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("回测执行失败: %s", exc)
        # NOTE(review): the raw exception text is returned to the client;
        # confirm that exposing internal error messages is acceptable.
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"回测执行失败: {exc}"},
        ) from exc
@router.get(
    "/results",
    response_model=BacktestResultsResponse,
    responses={
        200: {"description": "回测结果列表"},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取回测结果",
    description="分页获取回测结果,支持按股票代码过滤",
)
def get_backtest_results(
    code: Optional[str] = Query(None, description="股票代码筛选"),
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    page: int = Query(1, ge=1, description="页码"),
    limit: int = Query(20, ge=1, le=200, description="每页数量"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> BacktestResultsResponse:
    """Return one page of stored backtest results, optionally filtered.

    (The OpenAPI description comes from ``description=`` above, not from
    this docstring.)

    Args:
        code: Optional stock-code filter.
        eval_window_days: Optional evaluation-window filter (1..120).
        page: 1-based page number.
        limit: Page size (1..200).
        db_manager: Injected database manager used to construct the service.

    Returns:
        A ``BacktestResultsResponse`` echoing ``page``/``limit`` together with
        the ``total`` count and ``items`` reported by the service.

    Raises:
        HTTPException: Any ``HTTPException`` raised while querying is passed
            through unchanged; every other exception is logged and turned
            into a 500 with an ``internal_error`` detail.
    """
    try:
        service = BacktestService(db_manager)
        data = service.get_recent_evaluations(
            code=code, eval_window_days=eval_window_days, limit=limit, page=page
        )
        # Treat a missing *or* null "items"/"total" as empty instead of
        # failing with TypeError while iterating None or calling int(None).
        items = [BacktestResultItem(**item) for item in data.get("items") or []]
        total = int(data.get("total") or 0)
        return BacktestResultsResponse(
            total=total,
            page=page,
            limit=limit,
            items=items,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their status code instead of being
        # re-wrapped as 500 (same policy as the /performance handlers).
        raise
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("查询回测结果失败: %s", exc)
        # NOTE(review): the raw exception text is returned to the client;
        # confirm that exposing internal error messages is acceptable.
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"查询回测结果失败: {exc}"},
        ) from exc
# GET /performance — aggregated backtest metrics across all stocks.
# Intentionally documented with comments, not a docstring: the decorator has
# no ``description=``, so FastAPI would publish a docstring in the OpenAPI spec.
@router.get(
    "/performance",
    response_model=PerformanceMetrics,
    responses={
        200: {"description": "整体回测表现"},
        404: {"description": "无回测汇总", "model": ErrorResponse},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取整体回测表现",
)
def get_overall_performance(
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> PerformanceMetrics:
    try:
        # "overall" scope summary; no stock code applies at this scope.
        overall = BacktestService(db_manager).get_summary(
            scope="overall",
            code=None,
            eval_window_days=eval_window_days,
        )
        if overall is not None:
            return PerformanceMetrics(**overall)
        # Service reported no summary -> respond 404.
        raise HTTPException(
            status_code=404,
            detail={"error": "not_found", "message": "未找到整体回测汇总"},
        )
    except HTTPException:
        # Any HTTPException raised in the try (including the 404 above)
        # propagates unchanged.
        raise
    except Exception as exc:
        logger.error(f"查询整体表现失败: {exc}", exc_info=True)
        failure = {"error": "internal_error", "message": f"查询整体表现失败: {str(exc)}"}
        raise HTTPException(status_code=500, detail=failure)
# GET /performance/{code} — aggregated backtest metrics for a single stock.
# Intentionally documented with comments, not a docstring: the decorator has
# no ``description=``, so FastAPI would publish a docstring in the OpenAPI spec.
@router.get(
    "/performance/{code}",
    response_model=PerformanceMetrics,
    responses={
        200: {"description": "单股回测表现"},
        404: {"description": "无回测汇总", "model": ErrorResponse},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取单股回测表现",
)
def get_stock_performance(
    code: str,
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> PerformanceMetrics:
    try:
        # "stock" scope summary for the code taken from the URL path.
        per_stock = BacktestService(db_manager).get_summary(
            scope="stock",
            code=code,
            eval_window_days=eval_window_days,
        )
        if per_stock is not None:
            return PerformanceMetrics(**per_stock)
        # Service reported no summary for this code -> respond 404.
        raise HTTPException(
            status_code=404,
            detail={"error": "not_found", "message": f"未找到 {code} 的回测汇总"},
        )
    except HTTPException:
        # Any HTTPException raised in the try (including the 404 above)
        # propagates unchanged.
        raise
    except Exception as exc:
        logger.error(f"查询单股表现失败: {exc}", exc_info=True)
        failure = {"error": "internal_error", "message": f"查询单股表现失败: {str(exc)}"}
        raise HTTPException(status_code=500, detail=failure)