You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
5.6 KiB
161 lines
5.6 KiB
|
2 months ago
|
# -*- coding: utf-8 -*-
|
||
|
|
"""Backtest endpoints."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
|
|
|
||
|
|
from api.deps import get_database_manager
|
||
|
|
from api.v1.schemas.backtest import (
|
||
|
|
BacktestRunRequest,
|
||
|
|
BacktestRunResponse,
|
||
|
|
BacktestResultItem,
|
||
|
|
BacktestResultsResponse,
|
||
|
|
PerformanceMetrics,
|
||
|
|
)
|
||
|
|
from api.v1.schemas.common import ErrorResponse
|
||
|
|
from src.services.backtest_service import BacktestService
|
||
|
|
from src.storage import DatabaseManager
|
||
|
|
|
||
|
|
# Module-level logger shared by every endpoint in this file.
logger = logging.getLogger(__name__)

# Router that the backtest endpoints below register on; presumably mounted
# under a backtest prefix by the application — TODO confirm at the include site.
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post(
    "/run",
    response_model=BacktestRunResponse,
    responses={
        200: {"description": "回测执行完成"},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="触发回测",
    description="对历史分析记录进行回测评估,并写入 backtest_results/backtest_summaries",
)
def run_backtest(
    request: BacktestRunRequest,
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> BacktestRunResponse:
    """Trigger a backtest evaluation of historical analysis records.

    Forwards ``code``, ``force``, ``eval_window_days``, ``min_age_days`` and
    ``limit`` from the request body to ``BacktestService.run_backtest`` and
    builds the response model from the returned statistics mapping.

    Args:
        request: Backtest parameters taken from the request body.
        db_manager: Database manager injected via ``get_database_manager``.

    Returns:
        A ``BacktestRunResponse`` built from the service's stats.

    Raises:
        HTTPException: An ``HTTPException`` raised while running is passed
            through unchanged. Any other error is logged and reported as a
            500 with an ``internal_error`` payload.
    """
    try:
        service = BacktestService(db_manager)
        stats = service.run_backtest(
            code=request.code,
            force=request.force,
            eval_window_days=request.eval_window_days,
            min_age_days=request.min_age_days,
            limit=request.limit,
        )
        return BacktestRunResponse(**stats)
    except HTTPException:
        # Keep deliberate HTTP errors (status + detail) intact instead of
        # masking them as a generic 500; matches the /performance endpoints.
        raise
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("回测执行失败: %s", exc)
        # NOTE(review): the raw exception text is echoed to the client; confirm
        # that exposing internal error details here is intended.
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"回测执行失败: {str(exc)}"},
        ) from exc
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/results",
    response_model=BacktestResultsResponse,
    responses={
        200: {"description": "回测结果列表"},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取回测结果",
    description="分页获取回测结果,支持按股票代码过滤",
)
def get_backtest_results(
    code: Optional[str] = Query(None, description="股票代码筛选"),
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    page: int = Query(1, ge=1, description="页码"),
    limit: int = Query(20, ge=1, le=200, description="每页数量"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> BacktestResultsResponse:
    """Return a page of stored backtest results.

    Args:
        code: Optional stock code filter.
        eval_window_days: Optional evaluation-window filter (1..120).
        page: 1-based page number.
        limit: Page size (1..200).
        db_manager: Database manager injected via ``get_database_manager``.

    Returns:
        A ``BacktestResultsResponse`` with the service-reported ``total``, the
        requested ``page`` and ``limit`` echoed back, and the page's items.

    Raises:
        HTTPException: An ``HTTPException`` raised while querying is passed
            through unchanged. Any other error is logged and reported as a
            500 with an ``internal_error`` payload.
    """
    try:
        service = BacktestService(db_manager)
        data = service.get_recent_evaluations(
            code=code, eval_window_days=eval_window_days, limit=limit, page=page
        )
        # `or` (rather than a .get default) also covers keys present with a
        # None value, which would otherwise raise TypeError and surface as 500.
        items = [BacktestResultItem(**item) for item in (data.get("items") or [])]
        total = int(data.get("total") or 0)
        return BacktestResultsResponse(
            total=total,
            page=page,
            limit=limit,
            items=items,
        )
    except HTTPException:
        # Keep deliberate HTTP errors intact; matches the /performance endpoints.
        raise
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("查询回测结果失败: %s", exc)
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"查询回测结果失败: {str(exc)}"},
        ) from exc
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/performance",
    response_model=PerformanceMetrics,
    responses={
        200: {"description": "整体回测表现"},
        404: {"description": "无回测汇总", "model": ErrorResponse},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取整体回测表现",
)
def get_overall_performance(
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> PerformanceMetrics:
    """Return the aggregated (``scope="overall"``) backtest summary.

    Args:
        eval_window_days: Optional evaluation-window filter (1..120).
        db_manager: Database manager injected via ``get_database_manager``.

    Returns:
        ``PerformanceMetrics`` built from the overall summary mapping.

    Raises:
        HTTPException: 404 (``not_found``) when the service has no overall
            summary. 500 (``internal_error``) for any other failure.
    """
    try:
        overall = BacktestService(db_manager).get_summary(
            scope="overall",
            code=None,
            eval_window_days=eval_window_days,
        )
        if overall is not None:
            return PerformanceMetrics(**overall)
        raise HTTPException(
            status_code=404,
            detail={"error": "not_found", "message": "未找到整体回测汇总"},
        )
    except HTTPException:
        # Let the 404 above escape without being rewrapped as a 500.
        raise
    except Exception as exc:
        logger.error(f"查询整体表现失败: {exc}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"查询整体表现失败: {str(exc)}"},
        )
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/performance/{code}",
    response_model=PerformanceMetrics,
    responses={
        200: {"description": "单股回测表现"},
        404: {"description": "无回测汇总", "model": ErrorResponse},
        500: {"description": "服务器错误", "model": ErrorResponse},
    },
    summary="获取单股回测表现",
)
def get_stock_performance(
    code: str,
    eval_window_days: Optional[int] = Query(None, ge=1, le=120, description="评估窗口过滤"),
    db_manager: DatabaseManager = Depends(get_database_manager),
) -> PerformanceMetrics:
    """Return the per-stock (``scope="stock"``) backtest summary for ``code``.

    Args:
        code: Stock code taken from the URL path.
        eval_window_days: Optional evaluation-window filter (1..120).
        db_manager: Database manager injected via ``get_database_manager``.

    Returns:
        ``PerformanceMetrics`` built from that stock's summary mapping.

    Raises:
        HTTPException: 404 (``not_found``) when no summary exists for the
            stock. 500 (``internal_error``) for any other failure.
    """
    try:
        backtest_service = BacktestService(db_manager)
        stock_summary = backtest_service.get_summary(
            scope="stock", code=code, eval_window_days=eval_window_days
        )
        if stock_summary is None:
            missing = {"error": "not_found", "message": f"未找到 {code} 的回测汇总"}
            raise HTTPException(status_code=404, detail=missing)
        metrics = PerformanceMetrics(**stock_summary)
    except HTTPException:
        # Pass the 404 through as-is rather than converting it into a 500.
        raise
    except Exception as exc:
        logger.error(f"查询单股表现失败: {exc}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={"error": "internal_error", "message": f"查询单股表现失败: {str(exc)}"},
        )
    return metrics
|
||
|
|
|