You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

283 lines
9.8 KiB

"""
配置管理接口 - 品种配置文件上传、批量获取、批量任务创建
"""
import json
import logging
import shutil
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Body, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from app.database import get_db
from app.services.collector import fetch_symbol_data
from app.services.cache import save_market_data, check_cache_status, get_cached_data, create_task
from app.services.scheduler import add_job
from app.schemas import CandleItem, TimeframeData, SymbolDataResponse
# All routes below are mounted under /config and grouped under one OpenAPI tag.
router = APIRouter(tags=["品种配置"], prefix="/config")
logger = logging.getLogger(name=__name__)
class BatchFetchRequest(BaseModel):
    """Request body for POST /config/batch-fetch-all."""
    # Comma-separated period names; when omitted the endpoint uses its default period list.
    periods: Optional[str] = None
    # Top-level key of the uploaded config to read symbols from (e.g. "futures" or "stock").
    data_type: str = "futures"
    selected_symbols: Optional[str] = None  # comma-separated contract codes; limits the batch to these codes
# Location of the uploaded symbol configuration: a "config" directory three
# levels above this module, holding a single JSON file.
CONFIG_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("config")
CONFIG_FILE = CONFIG_DIR.joinpath("symbols_config.json")
def _ensure_config_dir():
    """Create CONFIG_DIR, including any missing parent directories, if absent."""
    target = CONFIG_DIR
    target.mkdir(exist_ok=True, parents=True)
@router.get("")
def get_config():
    """Return the stored symbol configuration.

    When no configuration has been uploaded yet, an empty skeleton with
    "futures" and "stock" sections is returned instead.
    """
    _ensure_config_dir()
    if CONFIG_FILE.exists():
        return json.loads(CONFIG_FILE.read_text(encoding="utf-8"))
    return {"futures": {}, "stock": {}}
def _validate_config(data):
    """Return an error message if *data* is not a valid symbol config, else None.

    Every top-level value must be an object mapping display names to
    contract-code strings: the batch endpoints look up ``config[data_type]``
    and iterate it with ``.items()``, so anything else breaks them later.
    """
    for section, symbols in data.items():
        if not isinstance(symbols, dict):
            return f"配置项 {section} 必须是 JSON 对象"
        for name, code in symbols.items():
            # Codes must stay strings, e.g. a numeric 000001 would lose its leading zeros.
            if not isinstance(code, str):
                return f"配置项 {section}.{name} 的合约代码必须是字符串"
    return None


@router.post("/upload")
async def upload_config(
    file: Optional[UploadFile] = File(None),
    request: Request = None,
):
    """Upload the symbol configuration file (JSON).

    The config may be sent either as a multipart file (``file``) or as a raw
    JSON request body. Expected format::

        {
            "futures": {"沪银": "AG2606", "沪金": "AU2606"},
            "stock": {"平安银行": "000001"}
        }

    Returns:
        A dict with a success message, the number of "futures" and "stock"
        symbols, and the stored config itself.

    Raises:
        HTTPException(400): no file and an empty body, undecodable or invalid
            JSON, or a config that does not match the format above.
    """
    _ensure_config_dir()
    if file:
        content = await file.read()
    else:
        # No multipart file supplied: read the raw request body as JSON.
        content = await request.body()
        if not content:
            raise HTTPException(status_code=400, detail="请提供配置文件或JSON数据")

    try:
        data = json.loads(content)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # json.loads on bytes raises UnicodeDecodeError (not JSONDecodeError)
        # for input that is not valid UTF-8/16/32.
        raise HTTPException(status_code=400, detail="无效的 JSON 格式") from None

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="配置文件必须是 JSON 对象")
    # Validate before writing so a malformed config is never persisted.
    error = _validate_config(data)
    if error:
        raise HTTPException(status_code=400, detail=error)

    # Write to a sibling temp file and swap it in, so a reader never sees a
    # truncated or half-written config.
    tmp_file = CONFIG_FILE.with_name(CONFIG_FILE.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
    tmp_file.replace(CONFIG_FILE)

    return {
        "message": "配置文件上传成功",
        "futures_symbols": len(data.get("futures", {})),
        "stock_symbols": len(data.get("stock", {})),
        "symbols": data,
    }
def _build_timeframe(period, candles, fetched_at):
    """Convert the raw candle dicts of one period into a TimeframeData.

    Rows that carry their timestamp under ``time`` (and not ``datetime``) are
    copied with that key renamed, so ``CandleItem(**row)`` receives ``datetime``.
    """
    normalized = []
    for candle in candles:
        row = dict(candle)
        if "time" in row and "datetime" not in row:
            row["datetime"] = row.pop("time")
        normalized.append(row)
    return TimeframeData(
        period=period,
        candles=[CandleItem(**row) for row in normalized],
        candle_count=len(normalized),
        fetched_at=fetched_at,
    )


def _collect_symbol(db, symbol, data_type, period_list):
    """Gather every requested period for one symbol, reusing valid cache entries.

    Returns:
        ``(response, error, from_cache)``. ``response`` is a SymbolDataResponse,
        or None when the live fetch returned no timeframes; in that case
        ``error`` holds the collector's error message. ``from_cache`` is True
        only when every period was served from the cache.
    """
    status = check_cache_status(db, symbol, data_type, period_list)
    if status["all_valid"]:
        cached = get_cached_data(db, symbol, data_type, period_list)
        if cached:
            fetched_at = cached.get("timestamp", "")
            timeframes = [
                _build_timeframe(period, candles, fetched_at)
                for period, candles in cached["timeframes"].items()
            ]
            response = SymbolDataResponse(
                symbol=symbol,
                data_type=data_type,
                current_price=cached.get("current_price"),
                timeframes=timeframes,
                source="cache",
            )
            return response, None, True
        # The cache reported every period valid but returned nothing
        # (presumably it expired in between): fetch all periods live.
        need_fetch, valid_periods = list(period_list), []
    else:
        need_fetch, valid_periods = status["missing_periods"], status["valid_periods"]

    logger.info("需要采集的周期: %s", need_fetch)
    result = fetch_symbol_data(symbol, data_type, need_fetch)
    if not result.get("timeframes"):
        return None, result.get("error", "未知错误"), False

    logger.info("采集到 %d 个周期的数据,开始保存", len(result["timeframes"]))
    save_market_data(db, symbol, result)

    # Merge still-valid cached periods with the freshly fetched ones (live data wins).
    all_timeframes = {}
    if valid_periods:
        existing = get_cached_data(db, symbol, data_type, valid_periods)
        if existing:
            all_timeframes.update(existing["timeframes"])
    all_timeframes.update(result["timeframes"])

    # NOTE(review): cached periods are stamped with the live fetch time here,
    # not their original cache time — confirm this is intended.
    fetched_at = result.get("timestamp", "")
    timeframes = [
        _build_timeframe(period, all_timeframes[period], fetched_at)
        for period in period_list
        if all_timeframes.get(period)
    ]
    response = SymbolDataResponse(
        symbol=symbol,
        data_type=data_type,
        current_price=result.get("current_price"),
        timeframes=timeframes,
        source="live+cache" if valid_periods else "live",
    )
    return response, None, False


@router.post("/batch-fetch-all")
def batch_fetch_all(
    request: BatchFetchRequest,
    db: Session = Depends(get_db),
):
    """Fetch market data for every configured symbol of ``request.data_type``.

    Smart caching: periods that are already cached and valid are not requested
    again; only missing periods are fetched live and merged with cached ones.

    ``request.selected_symbols`` (comma-separated contract codes) limits the
    batch to those codes. ``request.periods`` (comma-separated) overrides the
    default period list.

    Each symbol is processed independently. A failure, whether an empty fetch
    result or an exception, is recorded under ``failed`` and the batch continues.

    Raises:
        HTTPException(400): no config uploaded, no symbols for ``data_type``,
            or none of ``selected_symbols`` is in the config.
    """
    periods = request.periods
    data_type = request.data_type
    selected_symbols = request.selected_symbols

    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        raise HTTPException(status_code=400, detail="请先上传品种配置文件")
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)

    symbols_dict = config.get(data_type, {})
    if not symbols_dict:
        raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")

    # Restrict to the requested contract codes when given.
    if selected_symbols:
        wanted = [s.strip() for s in selected_symbols.split(",") if s.strip()]
        symbols_dict = {name: code for name, code in symbols_dict.items() if code in wanted}
        if not symbols_dict:
            raise HTTPException(status_code=400, detail="选定的合约不在配置中")

    # Blank entries (e.g. from a trailing comma) are dropped; use the defaults if none remain.
    period_list = [p.strip() for p in (periods or "").split(",") if p.strip()] or [
        "5min", "15min", "30min", "60min", "daily",
    ]

    results = {
        "total": len(symbols_dict),
        "success": [],
        "failed": [],
        "cached": [],  # symbols served entirely from the cache
        "details": {},
    }
    for name, symbol in symbols_dict.items():
        logger.info("处理品种: %s (%s)", name, symbol)
        try:
            response, error, from_cache = _collect_symbol(db, symbol, data_type, period_list)
        except Exception as exc:
            # Isolate per-symbol failures so one bad symbol does not abort the
            # batch; roll back so the session stays usable for the rest.
            db.rollback()
            logger.exception("采集失败: %s", symbol)
            results["failed"].append({"name": name, "symbol": symbol, "error": str(exc)})
            continue

        if response is None:
            logger.error("采集失败: %s, 错误: %s", symbol, error)
            results["failed"].append({"name": name, "symbol": symbol, "error": error})
            continue

        if from_cache:
            results["cached"].append({"name": name, "symbol": symbol})
        else:
            logger.info("采集成功: %s", symbol)
        results["details"][symbol] = response
        results["success"].append({"name": name, "symbol": symbol})
    return results
@router.post("/batch-tasks")
def batch_create_tasks(
    periods: Optional[str] = None,
    interval_seconds: int = 300,
    data_type: str = "futures",
    db: Session = Depends(get_db),
):
    """Create one scheduled collection task per configured symbol of ``data_type``.

    Args:
        periods: Comma-separated period names. Blank entries are ignored, and
            the default period list is used when none remain.
        interval_seconds: Interval stored on each task and passed to the
            scheduler; must be at least 1.
        data_type: Config section to read symbols from.
        db: Database session (injected).

    Returns:
        ``{"total", "created", "failed"}``. A failure for one symbol is recorded
        in ``failed`` and does not stop the others.

    Raises:
        HTTPException(400): non-positive interval, no config uploaded, or no
            symbols for ``data_type``.
    """
    # A non-positive interval cannot describe a periodic schedule.
    if interval_seconds < 1:
        raise HTTPException(status_code=400, detail="interval_seconds 必须大于 0")

    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        raise HTTPException(status_code=400, detail="请先上传品种配置文件")
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)

    symbols_dict = config.get(data_type, {})
    if not symbols_dict:
        raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")

    # Blank entries (e.g. from a trailing comma) are dropped; use the defaults if none remain.
    period_list = [p.strip() for p in (periods or "").split(",") if p.strip()] or [
        "5min", "15min", "30min", "60min", "daily",
    ]

    results = {"total": len(symbols_dict), "created": [], "failed": []}
    for name, symbol in symbols_dict.items():
        try:
            task = create_task(
                db=db,
                symbol=symbol,
                data_type=data_type,
                periods=period_list,
                interval_seconds=interval_seconds,
            )
            job_id = add_job(task.id, task.interval_seconds)
            task.job_id = job_id
            db.commit()
            db.refresh(task)
        except Exception as exc:
            # Roll back so the session stays usable for the remaining symbols,
            # and log instead of silently swallowing the error.
            db.rollback()
            logger.exception("创建任务失败: %s (%s)", name, symbol)
            # NOTE(review): if add_job succeeded but the commit failed, the
            # scheduler job is left without a persisted job_id — consider removing it here.
            results["failed"].append({
                "name": name,
                "symbol": symbol,
                "error": str(exc),
            })
            continue

        results["created"].append({
            "name": name,
            "symbol": symbol,
            "task_id": task.id,
            "job_id": job_id,
            "interval": interval_seconds,
        })
    return results