|
|
"""
|
|
|
配置管理接口 - 品种配置文件上传、批量获取、批量任务创建
|
|
|
"""
|
|
|
import json
|
|
|
import logging
|
|
|
import shutil
|
|
|
from pathlib import Path
|
|
|
from typing import Optional
|
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Body
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from sqlalchemy.orm import Session
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from app.database import get_db
|
|
|
from app.services.collector import fetch_symbol_data
|
|
|
from app.services.cache import save_market_data, check_cache_status, get_cached_data, create_task
|
|
|
from app.services.scheduler import add_job
|
|
|
from app.schemas import CandleItem, TimeframeData, SymbolDataResponse
|
|
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Every route in this module is mounted under "/config" and grouped under one OpenAPI tag.
router = APIRouter(tags=["品种配置"], prefix="/config")
|
|
|
|
|
|
|
|
|
class BatchFetchRequest(BaseModel):
    """Request body for ``POST /config/batch-fetch-all``.

    Every field is optional; omitted fields take the defaults declared below.
    """

    # Comma-separated period names (e.g. "5min,daily"); None lets the endpoint pick its default set.
    periods: Optional[str] = None
    # Top-level key of the symbol config to read (e.g. "futures" or "stock").
    data_type: str = "futures"
    # Comma-separated contract codes to restrict the batch to; None means every configured symbol.
    selected_symbols: Optional[str] = None
|
|
|
|
|
|
# Location of the persisted symbol configuration: a "config" directory three
# `.parent` hops above this file (presumably the project root — confirm).
CONFIG_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("config")
CONFIG_FILE = CONFIG_DIR.joinpath("symbols_config.json")
|
|
|
|
|
|
|
|
|
def _ensure_config_dir():
    """Make sure ``CONFIG_DIR`` exists, creating any missing parent directories."""
    directory = CONFIG_DIR
    directory.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
|
|
|
@router.get("")
def get_config():
    """
    Return the currently stored symbol configuration.

    Returns:
        The parsed JSON content of ``CONFIG_FILE``, or
        ``{"futures": {}, "stock": {}}`` when nothing has been uploaded yet.

    Raises:
        HTTPException(500): the stored file is not valid UTF-8 JSON (for
            example it was hand-edited or truncated by an interrupted write).
    """
    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        return {"futures": {}, "stock": {}}

    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Surface a clear, actionable error instead of an unhandled traceback.
        logger.exception("配置文件解析失败: %s", CONFIG_FILE)
        raise HTTPException(status_code=500, detail="配置文件已损坏,请重新上传") from None
|
|
|
|
|
|
|
|
|
@router.post("/upload")
def upload_config(
    file: Optional[UploadFile] = File(None),
    json_config: Optional[dict] = Body(None, embed=False),
):
    """
    Upload the symbol configuration (JSON) and persist it to ``CONFIG_FILE``.

    The configuration is taken from the uploaded file when one is given,
    otherwise from ``json_config``. Expected shape::

        {
            "futures": {"沪银": "AG2606", "沪金": "AU2606"},
            "stock": {"平安银行": "000001"}
        }

    Returns:
        A dict with a message, the number of ``futures`` and ``stock``
        entries, and the stored configuration under ``symbols``.

    Raises:
        HTTPException(400): nothing was supplied; the upload is not valid
            UTF-8 JSON; the top level is not an object; or a ``futures`` /
            ``stock`` section is present but not an object.
    """
    # NOTE(review): FastAPI documents that File/Form parameters cannot be
    # combined with JSON Body parameters in one request (the body is parsed as
    # multipart/form-data), so the ``json_config`` branch may be unreachable
    # for application/json clients — confirm against the frontend.
    _ensure_config_dir()

    if file:
        content = file.file.read()
        try:
            data = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Non-UTF-8 bytes raise UnicodeDecodeError, not JSONDecodeError;
            # both are client errors.
            raise HTTPException(status_code=400, detail="无效的 JSON 格式") from None
    elif json_config:
        data = json_config
    else:
        raise HTTPException(status_code=400, detail="请提供配置文件或JSON数据")

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="配置文件必须是 JSON 对象")

    # The counts below and the batch endpoints call len()/.items() on these
    # sections, so reject non-object values here rather than failing later.
    for section in ("futures", "stock"):
        if section in data and not isinstance(data[section], dict):
            raise HTTPException(status_code=400, detail=f"配置项 {section} 必须是 JSON 对象")

    # Write to a sibling temp file and swap it in with Path.replace (os.replace),
    # so an interrupted write never leaves a truncated config for readers.
    tmp_file = CONFIG_FILE.with_name(CONFIG_FILE.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
    tmp_file.replace(CONFIG_FILE)

    return {
        "message": "配置文件上传成功",
        "futures_symbols": len(data.get("futures", {})),
        "stock_symbols": len(data.get("stock", {})),
        "symbols": data,
    }
|
|
|
|
|
|
|
|
|
def _candles_to_timeframe(period, candles, fetched_at):
    """Build one ``TimeframeData`` for *period* from raw candle mappings.

    Each candle is copied; its ``time`` key is renamed to ``datetime`` unless a
    ``datetime`` key is already present, then it is validated via ``CandleItem``.
    """
    normalized = []
    for raw in candles:
        candle = dict(raw)
        if "time" in candle and "datetime" not in candle:
            candle["datetime"] = candle.pop("time")
        normalized.append(candle)
    return TimeframeData(
        period=period,
        candles=[CandleItem(**candle) for candle in normalized],
        candle_count=len(normalized),
        fetched_at=fetched_at,
    )


@router.post("/batch-fetch-all")
def batch_fetch_all(
    request: BatchFetchRequest,
    db: Session = Depends(get_db),
):
    """
    Fetch market data for every configured symbol of ``request.data_type``.

    Periods still valid in the cache are not requested again: only missing
    periods are fetched live, saved, and merged with the cached ones. A failure
    on one symbol is recorded in ``failed`` and does not abort the batch.

    Returns:
        A dict with ``total`` (symbols processed), ``success`` / ``failed`` /
        ``cached`` lists of ``{"name", "symbol"}`` entries (``failed`` entries
        also carry ``error``; cache hits appear in both ``cached`` and
        ``success``), and ``details`` mapping each successful symbol to its
        ``SymbolDataResponse``.

    Raises:
        HTTPException(400): no config uploaded, no symbols of that type, the
            section is not an object, or none of ``selected_symbols`` match.
        HTTPException(500): the stored config file cannot be parsed.
    """
    data_type = request.data_type

    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        raise HTTPException(status_code=400, detail="请先上传品种配置文件")

    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.exception("配置文件解析失败: %s", CONFIG_FILE)
        raise HTTPException(status_code=500, detail="配置文件已损坏,请重新上传") from None
    if not isinstance(config, dict):
        raise HTTPException(status_code=500, detail="配置文件已损坏,请重新上传")

    symbols_dict = config.get(data_type, {})
    if not symbols_dict:
        raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")
    if not isinstance(symbols_dict, dict):
        raise HTTPException(status_code=400, detail=f"配置中 {data_type} 类型的品种格式无效")

    # Restrict the batch to the requested contract codes (comma-separated).
    if request.selected_symbols:
        wanted = [s.strip() for s in request.selected_symbols.split(",") if s.strip()]
        symbols_dict = {name: code for name, code in symbols_dict.items() if code in wanted}
        if not symbols_dict:
            raise HTTPException(status_code=400, detail="选定的合约不在配置中")

    # Blank entries (e.g. "5min,,daily") are dropped; if none remain use the default set.
    period_list = [p.strip() for p in (request.periods or "").split(",") if p.strip()]
    if not period_list:
        period_list = ["5min", "15min", "30min", "60min", "daily"]

    results = {
        "total": len(symbols_dict),
        "success": [],
        "failed": [],
        "cached": [],  # symbols served entirely from the cache
        "details": {},
    }

    for name, symbol in symbols_dict.items():
        logger.info("处理品种: %s (%s)", name, symbol)
        try:
            status = check_cache_status(db, symbol, data_type, period_list)
            if status["all_valid"]:
                cached = get_cached_data(db, symbol, data_type, period_list)
                if cached:
                    fetched_at = cached.get("timestamp", "")
                    results["details"][symbol] = SymbolDataResponse(
                        symbol=symbol,
                        data_type=data_type,
                        current_price=cached.get("current_price"),
                        timeframes=[
                            _candles_to_timeframe(p, candles, fetched_at)
                            for p, candles in cached["timeframes"].items()
                        ],
                        source="cache",
                    )
                    results["cached"].append({"name": name, "symbol": symbol})
                    results["success"].append({"name": name, "symbol": symbol})
                    continue
                # Reported valid but nothing came back (e.g. evicted between the
                # two calls): fall back to fetching every period live.
                valid_periods, need_fetch = [], period_list
            else:
                valid_periods = status["valid_periods"]
                need_fetch = status["missing_periods"]

            logger.info("需要采集的周期: %s", need_fetch)
            result = fetch_symbol_data(symbol, data_type, need_fetch)

            if not result.get("timeframes"):
                error_msg = result.get("error", "未知错误")
                logger.error("采集失败: %s, 错误: %s", symbol, error_msg)
                results["failed"].append({"name": name, "symbol": symbol, "error": error_msg})
                continue

            logger.info("采集到 %d 个周期的数据,开始保存", len(result["timeframes"]))
            save_market_data(db, symbol, result)

            # Merge still-valid cached periods with the fresh ones; fresh data wins.
            all_timeframes = {}
            existing = get_cached_data(db, symbol, data_type, valid_periods) if valid_periods else None
            if existing:
                all_timeframes.update(existing["timeframes"])
            all_timeframes.update(result["timeframes"])

            fetched_at = result.get("timestamp", "")
            results["details"][symbol] = SymbolDataResponse(
                symbol=symbol,
                data_type=data_type,
                current_price=result.get("current_price"),
                timeframes=[
                    _candles_to_timeframe(p, all_timeframes[p], fetched_at)
                    for p in period_list
                    if all_timeframes.get(p)
                ],
                # Only claim cache involvement when cached data was actually merged.
                source="live+cache" if existing else "live",
            )
            results["success"].append({"name": name, "symbol": symbol})
            logger.info("采集成功: %s", symbol)
        except Exception as exc:
            # One bad symbol must not abort the whole batch. Roll back so the
            # Session stays usable for the remaining symbols.
            db.rollback()
            logger.exception("采集失败: %s", symbol)
            results["failed"].append({"name": name, "symbol": symbol, "error": str(exc)})

    return results
|
|
|
|
|
|
|
|
|
@router.post("/batch-tasks")
def batch_create_tasks(
    periods: Optional[str] = None,
    interval_seconds: int = 300,
    data_type: str = "futures",
    db: Session = Depends(get_db),
):
    """
    Create a scheduled collection task for every configured symbol of *data_type*.

    Args:
        periods: Comma-separated period names; omitted/blank -> default set.
        interval_seconds: Interval passed to ``create_task``; must be > 0.
        data_type: Top-level key of the symbol config to read.
        db: Database session (injected).

    Returns:
        ``{"total", "created", "failed"}``. ``created`` entries carry the task
        id, scheduler job id and interval; ``failed`` entries carry the error.

    Raises:
        HTTPException(400): non-positive interval, no config uploaded, or no
            (well-formed) symbols of that type.
        HTTPException(500): the stored config file cannot be parsed.
    """
    if interval_seconds <= 0:
        raise HTTPException(status_code=400, detail="interval_seconds 必须大于 0")

    _ensure_config_dir()
    if not CONFIG_FILE.exists():
        raise HTTPException(status_code=400, detail="请先上传品种配置文件")

    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.exception("配置文件解析失败: %s", CONFIG_FILE)
        raise HTTPException(status_code=500, detail="配置文件已损坏,请重新上传") from None
    if not isinstance(config, dict):
        raise HTTPException(status_code=500, detail="配置文件已损坏,请重新上传")

    symbols_dict = config.get(data_type, {})
    if not symbols_dict:
        raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")
    if not isinstance(symbols_dict, dict):
        raise HTTPException(status_code=400, detail=f"配置中 {data_type} 类型的品种格式无效")

    # Blank entries (e.g. "5min,,daily") are dropped; if none remain use the default set.
    period_list = [p.strip() for p in (periods or "").split(",") if p.strip()]
    if not period_list:
        period_list = ["5min", "15min", "30min", "60min", "daily"]

    results = {"total": len(symbols_dict), "created": [], "failed": []}

    for name, symbol in symbols_dict.items():
        try:
            task = create_task(
                db=db,
                symbol=symbol,
                data_type=data_type,
                periods=period_list,
                interval_seconds=interval_seconds,
            )
            job_id = add_job(task.id, task.interval_seconds)
            task.job_id = job_id
            db.commit()
            db.refresh(task)
        except Exception as exc:
            # After a failed flush/commit the SQLAlchemy Session must be rolled
            # back; otherwise every following symbol fails as well.
            # NOTE(review): if add_job succeeded but the commit failed, the
            # scheduler job is left without a persisted job_id — consider
            # removing it here.
            db.rollback()
            logger.exception("创建定时任务失败: %s", symbol)
            results["failed"].append({"name": name, "symbol": symbol, "error": str(exc)})
            continue

        results["created"].append({
            "name": name,
            "symbol": symbol,
            "task_id": task.id,
            "job_id": job_id,
            "interval": interval_seconds,
        })

    return results
|