"""
Configuration management API.

Endpoints for uploading the symbol configuration file, batch-fetching market
data for the configured symbols, and batch-creating scheduled tasks.
"""
import json
import logging
import shutil
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Body
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel

from app.database import get_db
from app.services.collector import fetch_symbol_data
from app.services.cache import save_market_data, check_cache_status, get_cached_data, create_task
from app.services.scheduler import add_job
from app.schemas import CandleItem, TimeframeData, SymbolDataResponse

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/config", tags=["品种配置"])


class BatchFetchRequest(BaseModel):
    """Request body for the batch fetch endpoint."""

    periods: Optional[str] = None  # comma-separated periods; defaults apply when empty
    data_type: str = "futures"  # top-level section of the config file to read
    selected_symbols: Optional[str] = None  # comma-separated contract codes


# Location of the persisted symbol configuration.
CONFIG_DIR = Path(__file__).resolve().parent.parent.parent / "config"
CONFIG_FILE = CONFIG_DIR / "symbols_config.json"

# Periods used when the caller does not supply any.
_DEFAULT_PERIODS = ("5min", "15min", "30min", "60min", "daily")

# Top-level config sections that must be JSON objects (name -> code).
_CONFIG_SECTIONS = ("futures", "stock")


def _ensure_config_dir():
    """Create the configuration directory (and parents) if it is missing."""
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)


def _parse_periods(periods: Optional[str]) -> list:
    """Split a comma-separated period string, dropping blanks; fall back to defaults."""
    if periods:
        parsed = [p.strip() for p in periods.split(",") if p.strip()]
        if parsed:
            return parsed
    return list(_DEFAULT_PERIODS)


def _load_symbols(data_type: str) -> dict:
    """Return the ``data_type`` section of the config file.

    Raises:
        HTTPException: 400 if no config file has been uploaded, or if the
            requested section is missing or empty.
    """
    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        raise HTTPException(status_code=400, detail="请先上传品种配置文件")

    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)

    symbols_dict = config.get(data_type, {})
    if not symbols_dict:
        raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")
    return symbols_dict


def _normalize_candle(candle) -> dict:
    """Copy a candle mapping, renaming ``time`` to ``datetime`` when only ``time`` is set."""
    candle_dict = dict(candle)
    if "time" in candle_dict and "datetime" not in candle_dict:
        candle_dict["datetime"] = candle_dict.pop("time")
    return candle_dict


def _to_timeframe(period, candles, fetched_at) -> TimeframeData:
    """Build a ``TimeframeData`` for one period from raw candle mappings."""
    items = [CandleItem(**_normalize_candle(c)) for c in candles]
    return TimeframeData(
        period=period,
        candles=items,
        candle_count=len(items),
        fetched_at=fetched_at,
    )


def _response_from_cache(db, symbol, data_type, period_list) -> SymbolDataResponse:
    """Build the response for a symbol whose requested periods are all cached."""
    cached = get_cached_data(db, symbol, data_type, period_list)
    fetched_at = cached.get("timestamp", "")
    timeframes = [
        _to_timeframe(period, candles, fetched_at)
        for period, candles in cached["timeframes"].items()
    ]
    return SymbolDataResponse(
        symbol=symbol,
        data_type=data_type,
        current_price=cached.get("current_price"),
        timeframes=timeframes,
        source="cache",
    )


def _response_from_fetch(db, symbol, data_type, period_list, status, result) -> SymbolDataResponse:
    """Persist freshly fetched data and merge it with still-valid cached periods."""
    save_market_data(db, symbol, result)

    valid_periods = status["valid_periods"]
    all_timeframes = {}
    if valid_periods:
        existing = get_cached_data(db, symbol, data_type, valid_periods)
        if existing:
            all_timeframes.update(existing["timeframes"])
    # Fresh data wins over cached data for any period present in both.
    all_timeframes.update(result["timeframes"])

    fetched_at = result.get("timestamp", "")
    # Follow the requested period order and skip periods with no candles.
    timeframes = [
        _to_timeframe(period, all_timeframes[period], fetched_at)
        for period in period_list
        if all_timeframes.get(period)
    ]
    return SymbolDataResponse(
        symbol=symbol,
        data_type=data_type,
        current_price=result.get("current_price"),
        timeframes=timeframes,
        source="live+cache" if valid_periods else "live",
    )


@router.get("")
def get_config():
    """Return the current symbol configuration.

    Returns empty ``futures`` and ``stock`` sections when no configuration
    file has been uploaded yet.
    """
    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        return {"futures": {}, "stock": {}}
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        return json.load(f)


@router.post("/upload")
def upload_config(
    file: Optional[UploadFile] = File(None),
    json_config: Optional[dict] = Body(None, embed=False),
):
    """
    Upload the symbol configuration (JSON).

    Either an uploaded file or a JSON body may be supplied; the file takes
    precedence when both are present.

    Example format:
    {
        "futures": {"沪银": "AG2606", "沪金": "AU2606"},
        "stock": {"平安银行": "000001"}
    }

    Returns:
        A dict with a status message, the number of entries in the
        ``futures`` and ``stock`` sections, and the stored configuration.

    Raises:
        HTTPException: 400 if nothing was supplied, the payload is not valid
            JSON (including undecodable bytes), the top level is not an
            object, or a ``futures``/``stock`` section is not an object.
    """
    _ensure_config_dir()

    if file:
        content = file.file.read()
        try:
            data = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # json.loads on bytes raises UnicodeDecodeError (not
            # JSONDecodeError) for undecodable input; both are client errors.
            raise HTTPException(status_code=400, detail="无效的 JSON 格式") from exc
    elif json_config:
        data = json_config
    else:
        raise HTTPException(status_code=400, detail="请提供配置文件或JSON数据")

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="配置文件必须是 JSON 对象")

    # Validate before writing: the batch endpoints call .items() on these
    # sections, so a malformed one must never reach the disk.
    for section in _CONFIG_SECTIONS:
        if not isinstance(data.get(section, {}), dict):
            raise HTTPException(
                status_code=400,
                detail=f"配置项 {section} 必须是 JSON 对象",
            )

    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

    return {
        "message": "配置文件上传成功",
        "futures_symbols": len(data.get("futures", {})),
        "stock_symbols": len(data.get("stock", {})),
        "symbols": data,
    }


@router.post("/batch-fetch-all")
def batch_fetch_all(
    request: BatchFetchRequest,
    db: Session = Depends(get_db),
):
    """
    Fetch data for every configured symbol of the requested type.

    Periods reported as valid by the cache are served from it; only the
    missing periods are fetched. A failure on one symbol is recorded under
    ``failed`` and does not abort the rest of the batch.

    Returns:
        A dict with ``total``, ``success``, ``failed``, ``cached`` (symbols
        served entirely from cache; these also appear in ``success``) and
        ``details`` (symbol code -> ``SymbolDataResponse``).

    Raises:
        HTTPException: 400 if no config exists, the section is empty, or none
            of the selected symbols are in the config.
    """
    data_type = request.data_type
    symbols_dict = _load_symbols(data_type)

    # Restrict to the explicitly selected contract codes, if any were given.
    if request.selected_symbols:
        wanted = {s.strip() for s in request.selected_symbols.split(",") if s.strip()}
        symbols_dict = {name: code for name, code in symbols_dict.items() if code in wanted}
        if not symbols_dict:
            raise HTTPException(status_code=400, detail="选定的合约不在配置中")

    period_list = _parse_periods(request.periods)

    results = {
        "total": len(symbols_dict),
        "success": [],
        "failed": [],
        "cached": [],  # symbols served entirely from cache
        "details": {},
    }

    for name, symbol in symbols_dict.items():
        logger.info("处理品种: %s (%s)", name, symbol)
        try:
            status = check_cache_status(db, symbol, data_type, period_list)

            if status["all_valid"]:
                results["details"][symbol] = _response_from_cache(
                    db, symbol, data_type, period_list
                )
                results["cached"].append({"name": name, "symbol": symbol})
                results["success"].append({"name": name, "symbol": symbol})
                continue

            need_fetch = status["missing_periods"]
            logger.info("需要采集的周期: %s", need_fetch)
            result = fetch_symbol_data(symbol, data_type, need_fetch)

            if not result.get("timeframes"):
                error_msg = result.get("error", "未知错误")
                logger.error("采集失败: %s, 错误: %s", symbol, error_msg)
                results["failed"].append({
                    "name": name,
                    "symbol": symbol,
                    "error": error_msg,
                })
                continue

            logger.info("采集到 %s 个周期的数据,开始保存", len(result["timeframes"]))
            results["details"][symbol] = _response_from_fetch(
                db, symbol, data_type, period_list, status, result
            )
            results["success"].append({"name": name, "symbol": symbol})
            logger.info("采集成功: %s", symbol)
        except Exception as exc:
            # Isolate the failure to this symbol. Roll back so a failed
            # flush/commit does not leave the session unusable for the rest.
            db.rollback()
            logger.exception("处理品种失败: %s", symbol)
            results["failed"].append({
                "name": name,
                "symbol": symbol,
                "error": str(exc),
            })

    return results


@router.post("/batch-tasks")
def batch_create_tasks(
    periods: Optional[str] = None,
    interval_seconds: int = 300,
    data_type: str = "futures",
    db: Session = Depends(get_db),
):
    """
    Create a scheduled task for every configured symbol of the given type.

    Each symbol is handled independently: a failure is recorded under
    ``failed`` and the remaining symbols are still processed.

    Returns:
        A dict with ``total``, ``created`` (name, symbol, task id, job id,
        interval) and ``failed`` (name, symbol, error text).

    Raises:
        HTTPException: 400 if no config exists or the section is empty.
    """
    symbols_dict = _load_symbols(data_type)
    period_list = _parse_periods(periods)

    results = {"total": len(symbols_dict), "created": [], "failed": []}

    for name, symbol in symbols_dict.items():
        try:
            task = create_task(
                db=db,
                symbol=symbol,
                data_type=data_type,
                periods=period_list,
                interval_seconds=interval_seconds,
            )
            job_id = add_job(task.id, task.interval_seconds)
            task.job_id = job_id
            db.commit()
            db.refresh(task)
            results["created"].append({
                "name": name,
                "symbol": symbol,
                "task_id": task.id,
                "job_id": job_id,
                "interval": interval_seconds,
            })
        except Exception as e:
            # A failed commit leaves the session needing a rollback before it
            # can be used again for the next symbol.
            db.rollback()
            logger.exception("创建任务失败: %s", symbol)
            results["failed"].append({
                "name": name,
                "symbol": symbol,
                "error": str(e),
            })

    return results