You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

389 lines
11 KiB

"""
缓存管理路由
"""
from typing import List
from fastapi import APIRouter, Depends, Query, BackgroundTasks
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.schemas.base import ResponseModel, PaginatedResponse
from app.schemas.cache import (
DetectMissingRequest, DetectMissingResponse,
BatchCacheRequest, CacheTaskResponse, CacheStatusResponse,
AllDataRequest
)
from app.services.cache_service import CacheService
from app.core.security import get_current_user
from app.models.user import User
from app.utils.date_utils import parse_date, format_date
# Router that collects the cache-management endpoints defined in this module.
router = APIRouter()
def run_cache_task(
    security_type: str,
    period_type: str,
    start_date_str: str,
    end_date_str: str,
    contract_type: str,
    task_id: int
):
    """Execute a cache task in the background.

    Scheduled through ``BackgroundTasks`` by the ``/cache-all-missing``
    endpoint. It opens its own session via ``SessionLocal`` (rather than the
    request's ``get_db`` session) and always closes it.

    Args:
        security_type: Security type forwarded to the service (e.g. "stock").
        period_type: Bar period forwarded to the service (e.g. "daily").
        start_date_str: Start date string, parsed with ``parse_date``.
        end_date_str: End date string, parsed with ``parse_date``.
        contract_type: Contract type forwarded to the service.
        task_id: Id of the task record created by the request handler.

    Any exception is logged together with its traceback and is not re-raised:
    nothing awaits the result of a background task.
    """
    import logging

    # Imported lazily; request handlers use get_db instead of SessionLocal.
    from app.db.session import SessionLocal

    db = SessionLocal()
    try:
        service = CacheService(db)
        start_date = parse_date(start_date_str)
        end_date = parse_date(end_date_str)
        service._execute_cache_task(
            task_id,
            security_type,
            period_type,
            start_date,
            end_date,
            contract_type,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped;
        # the task id makes the failing background job identifiable.
        # NOTE(review): the task record's status is not updated here; confirm
        # _execute_cache_task marks the task as failed itself, otherwise it may
        # remain in its initial state indefinitely.
        logging.getLogger(__name__).exception(
            "后台缓存任务失败: %s (task_id=%s)", e, task_id
        )
    finally:
        db.close()
def run_fill_missing_task(
    security_type: str,
    period_type: str,
    start_date_str: str,
    end_date_str: str,
    missing_codes: List[str],
    task_id: int
):
    """Fill in missing data for the given codes in the background.

    Scheduled through ``BackgroundTasks`` by the ``/fill-missing`` endpoint.
    It opens its own session via ``SessionLocal`` (rather than the request's
    ``get_db`` session) and always closes it.

    Args:
        security_type: Security type forwarded to the service.
        period_type: Bar period forwarded to the service.
        start_date_str: Start date string, parsed with ``parse_date``.
        end_date_str: End date string, parsed with ``parse_date``.
        missing_codes: Codes reported as missing by the detection pass.
        task_id: Id of the task record created by the request handler.

    Any exception is logged together with its traceback and is not re-raised:
    nothing awaits the result of a background task.
    """
    import logging

    # Imported lazily; request handlers use get_db instead of SessionLocal.
    from app.db.session import SessionLocal

    db = SessionLocal()
    try:
        service = CacheService(db)
        start_date = parse_date(start_date_str)
        end_date = parse_date(end_date_str)
        service._execute_fill_missing_task(
            task_id,
            security_type,
            period_type,
            start_date,
            end_date,
            missing_codes,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped;
        # the task id makes the failing background job identifiable.
        # NOTE(review): the task record's status is not updated here; confirm
        # _execute_fill_missing_task marks the task as failed itself.
        logging.getLogger(__name__).exception(
            "后台补齐任务失败: %s (task_id=%s)", e, task_id
        )
    finally:
        db.close()
@router.post("/fill-missing", response_model=ResponseModel)
async def fill_missing_data(
    request: AllDataRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fill in missing data asynchronously, processing only the missing codes.

    A detection pass runs first. If it reports no missing codes, the endpoint
    returns immediately with ``task_id`` set to ``None``. Otherwise it creates a
    fill-missing task record, schedules ``run_fill_missing_task`` in the
    background, and returns a snapshot of the new task.
    """
    # NOTE(review): the service calls below are not awaited, so they run
    # synchronously on the event loop; if they do blocking I/O, consider a
    # plain `def` endpoint so FastAPI runs it in its threadpool.
    service = CacheService(db)
    start, end = parse_date(request.start_date), parse_date(request.end_date)
    sec_type, period = request.security_type, request.period_type

    # The detection result decides which codes the background job handles.
    detection = service.detect_all_missing_data(
        sec_type, period, start, end, request.contract_type
    )
    codes_to_fill = detection.get("missing_code_list", [])

    if codes_to_fill:
        task = service._create_fill_missing_task(
            sec_type, period, start, end, codes_to_fill
        )
        # The raw date strings are handed on; the background job re-parses them.
        background_tasks.add_task(
            run_fill_missing_task,
            sec_type,
            period,
            request.start_date,
            request.end_date,
            codes_to_fill,
            task.id,
        )
        return ResponseModel(data={
            "task_id": task.id,
            "task_name": task.task_name,
            "status": task.status,
            "total_count": task.total_count,
            "progress": float(task.progress),
            "missing_count": len(codes_to_fill),
        })

    return ResponseModel(data={
        "task_id": None,
        "message": "没有缺失数据需要补齐",
        "missing_count": 0,
    })
@router.post("/cache-all-missing", response_model=ResponseModel)
async def cache_all_missing_data(
    request: AllDataRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cache all missing data asynchronously.

    Creates a cache task record, schedules ``run_cache_task`` in the
    background, and returns a snapshot of the newly created task.
    """
    service = CacheService(db)
    start = parse_date(request.start_date)
    end = parse_date(request.end_date)
    kind_and_period = (request.security_type, request.period_type)

    task = service._create_cache_task(
        *kind_and_period, start, end, request.contract_type
    )

    # The background job receives the raw date strings and parses them itself.
    background_tasks.add_task(
        run_cache_task,
        *kind_and_period,
        request.start_date,
        request.end_date,
        request.contract_type,
        task.id,
    )

    snapshot = {
        "task_id": task.id,
        "task_name": task.task_name,
        "status": task.status,
        "total_count": task.total_count,
        "progress": float(task.progress),
    }
    return ResponseModel(data=snapshot)
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
    request: AllDataRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Detect missing data across all codes for the requested range.

    The service's detection summary is returned unchanged as the response data.
    """
    service = CacheService(db)
    date_range = (parse_date(request.start_date), parse_date(request.end_date))
    summary = service.detect_all_missing_data(
        request.security_type,
        request.period_type,
        *date_range,
        request.contract_type,
    )
    return ResponseModel(data=summary)
def _missing_ratio(expected, actual):
    """Return the missing fraction ``(expected - actual) / expected``.

    Returns 0 when nothing is expected. ``None`` counts are treated as 0, so a
    partially populated detail row cannot crash the whole response.
    """
    expected = expected or 0
    if expected <= 0:
        return 0
    return (expected - (actual or 0)) / expected


@router.post("/detect-missing", response_model=ResponseModel)
async def detect_missing_data(
    request: DetectMissingRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Detect missing data for an explicit list of codes.

    Runs a detection task, then reports, for each requested code that has
    missing detail rows, the dates involved with their expected and actual
    counts and the missing ratio. Codes without missing rows are omitted, and
    the requested order is kept.
    """
    service = CacheService(db)
    start = parse_date(request.start_date)
    end = parse_date(request.end_date)
    task = service.detect_missing_data(
        request.security_type,
        request.period_type,
        start,
        end,
        request.code_list,
    )

    # Group the missing detail rows by code in a single pass. Previously every
    # requested code rescanned the full detail list (O(codes x details)).
    missing_by_code = {}
    for d in service.get_task_details(task.id):
        if d.is_missing:
            missing_by_code.setdefault(d.code, []).append(d)

    missing_info = [
        {
            "code": code,
            "missing_dates": [
                {
                    "date": format_date(d.trade_date),
                    "expected": d.expected_count,
                    "actual": d.actual_count,
                    "missing_ratio": _missing_ratio(d.expected_count, d.actual_count),
                }
                for d in missing_by_code[code]
            ],
        }
        for code in request.code_list
        if code in missing_by_code
    ]

    return ResponseModel(data={
        "task_id": task.id,
        "total_codes": len(request.code_list),
        "missing_codes": missing_info,
    })
@router.post("/batch-cache", response_model=ResponseModel[CacheTaskResponse])
async def batch_cache_data(
    request: BatchCacheRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cache data for a batch of codes and return the resulting task."""
    service = CacheService(db)
    start = parse_date(request.start_date)
    end = parse_date(request.end_date)
    created = service.batch_cache_data(
        request.security_type, request.period_type, start, end, request.code_list
    )
    payload = CacheTaskResponse.model_validate(created)
    return ResponseModel(data=payload)
@router.get("/tasks", response_model=ResponseModel)
async def get_cache_tasks(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List cache tasks, one page at a time.

    Task items are converted to ``CacheTaskResponse``. The pagination fields
    are copied from the service result.
    """
    page_result = CacheService(db).get_tasks(page, page_size)
    data = {
        "items": [CacheTaskResponse.model_validate(item) for item in page_result["items"]],
    }
    for key in ("total", "page", "page_size", "total_pages"):
        data[key] = page_result[key]
    return ResponseModel(data=data)
@router.get("/tasks/{task_id}", response_model=ResponseModel)
async def get_cache_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a cache task together with its per-code/per-date detail rows.

    If the task does not exist, the response carries ``code=404`` in its body.
    """
    service = CacheService(db)
    task = service.get_task(task_id)
    if not task:
        return ResponseModel(code=404, message="任务不存在")

    def as_dict(row):
        # One detail row rendered as a plain dict; trade_date is ISO-formatted.
        return {
            "id": row.id,
            "code": row.code,
            "trade_date": None if not row.trade_date else row.trade_date.isoformat(),
            "expected_count": row.expected_count,
            "actual_count": row.actual_count,
            "is_missing": bool(row.is_missing),
            "status": row.status,
            "error_message": row.error_message,
        }

    rows = service.get_task_details(task_id)
    return ResponseModel(data={
        "task": CacheTaskResponse.model_validate(task),
        "details": list(map(as_dict, rows)),
    })
@router.delete("/tasks/{task_id}", response_model=ResponseModel)
async def cancel_cache_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cancel a cache task.

    If the service reports the cancellation failed, the response carries
    ``code=400`` in its body.
    """
    cancelled = CacheService(db).cancel_task(task_id)
    if not cancelled:
        return ResponseModel(code=400, message="任务不存在或已完成")
    return ResponseModel(message="任务已取消")
@router.get("/status/{code}", response_model=ResponseModel)
async def get_cache_status(
    code: str,
    security_type: str = Query("stock", description="证券类型: stock, future"),
    period_type: str = Query("daily", description="周期类型: daily, min1, min5, etc."),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the cache status of a single code for the given security type and period."""
    return ResponseModel(
        data=CacheService(db).get_cache_status(code, security_type, period_type)
    )
@router.get("/future-varieties", response_model=ResponseModel)
async def get_future_varieties(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the list of futures varieties under the ``varieties`` key."""
    return ResponseModel(data={"varieties": CacheService(db).get_future_varieties()})
@router.get("/main-contracts", response_model=ResponseModel)
async def get_main_contracts(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the main contract of every variety, as reported by the default SDK connection.

    If no default connection is available, the response carries ``code=500``
    in its body.
    """
    from app.services.sdk_manager import sdk_manager

    connection = sdk_manager.get_default_connection()
    if connection:
        return ResponseModel(data={"main_contracts": connection.get_all_main_contracts()})
    return ResponseModel(code=500, message="SDK连接失败")
from app.utils.date_utils import format_date