You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
6.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
数据接口 - 批量获取 / 获取最新缓存
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.orm import Session
from app.database import get_db
from app.schemas import (
BatchFetchRequest,
BatchFetchResponse,
LatestDataResponse,
CandleItem,
TimeframeData,
SymbolDataResponse,
)
from app.services.collector import fetch_symbol_data, fetch_batch
from app.services.cache import (
save_market_data,
get_cached_data,
get_latest_cached,
check_cache_status,
)
from app.config import CACHE_TTL_SECONDS
from datetime import datetime
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the data endpoints: every route below is mounted under the
# "/data" prefix and tagged "数据" ("data").
router = APIRouter(prefix="/data", tags=["数据"])
def _to_candle_items(candles):
    """Convert raw candle dicts into ``CandleItem`` models.

    Candles read back from the cache may carry their timestamp under the key
    ``time`` instead of ``datetime`` (``get_latest`` performs the same rename
    on cached data), so the key is normalized before building each model.
    The input dicts are copied, never mutated.
    """
    items = []
    for raw in candles:
        candle = dict(raw)
        if "time" in candle and "datetime" not in candle:
            candle["datetime"] = candle.pop("time")
        items.append(CandleItem(**candle))
    return items


@router.post("/batch-fetch", response_model=BatchFetchResponse)
def batch_fetch(req: BatchFetchRequest, db: Session = Depends(get_db)):
    """
    Batch-fetch data for the requested symbols and periods.

    Smart caching: periods whose cache entries are still valid are not
    requested again. For each symbol:

    * every requested period is cached and valid -> served entirely from
      the cache (``source="cache"``);
    * some periods are missing -> only those are fetched live, saved to the
      cache, and merged with the still-valid cached periods
      (``source="live+cache"``; live data wins if a period appears in both);
    * the live fetch returns no timeframes -> the symbol goes to ``failed``
      and ``details[symbol]`` is ``{"error": <message>}``.

    Returns:
        BatchFetchResponse with the ``success`` / ``failed`` symbol lists and
        the per-symbol ``details``.
    """
    periods = req.periods
    data_type = req.data_type
    success = []
    failed = []
    details = {}

    for sym in req.symbols:
        status = check_cache_status(db, sym, data_type, periods)

        # Only read the full cache when the status check says it is all valid.
        cached = (
            get_cached_data(db, sym, data_type, periods)
            if status["all_valid"]
            else None
        )
        if cached:
            logger.info("[%s] 缓存全部命中,跳过采集", sym)
            cache_ts = cached.get("timestamp", "")
            details[sym] = SymbolDataResponse(
                symbol=sym,
                data_type=data_type,
                current_price=cached.get("current_price"),
                timeframes=[
                    TimeframeData(
                        period=p,
                        candles=_to_candle_items(candles),
                        candle_count=len(candles),
                        fetched_at=cache_ts,
                    )
                    for p, candles in cached["timeframes"].items()
                ],
                source="cache",
            )
            success.append(sym)
            continue

        if status["all_valid"]:
            # The status check reported a fully valid cache but the read came
            # back empty (e.g. entries expired in between). Refetch every
            # period instead of crashing on a missing cache payload.
            logger.warning(
                "[%s] cache reported valid but returned no data; refetching all periods",
                sym,
            )
            need_fetch = list(periods)
            valid_periods = []
        else:
            need_fetch = status["missing_periods"]
            valid_periods = status["valid_periods"]

        logger.info("[%s] 缓存部分缺失,需要采集: %s", sym, need_fetch)
        result = fetch_symbol_data(sym, data_type, need_fetch)
        if not result.get("timeframes"):
            failed.append(sym)
            details[sym] = {"error": result.get("error", "未知错误")}
            continue

        save_market_data(db, sym, result)
        success.append(sym)

        # Merge the still-valid cached periods with the freshly fetched ones,
        # remembering where each period's data (and timestamp) came from.
        merged = {}
        fetched_at_by_period = {}
        if valid_periods:
            existing = get_cached_data(db, sym, data_type, valid_periods)
            if existing:
                merged.update(existing["timeframes"])
                fetched_at_by_period.update(
                    dict.fromkeys(existing["timeframes"], existing.get("timestamp", ""))
                )
        live_ts = result.get("timestamp", "")
        merged.update(result["timeframes"])
        fetched_at_by_period.update(dict.fromkeys(result["timeframes"], live_ts))

        # Report periods in request order, skipping any that have no candles.
        timeframes = [
            TimeframeData(
                period=p,
                candles=_to_candle_items(merged[p]),
                candle_count=len(merged[p]),
                fetched_at=fetched_at_by_period.get(p, live_ts),
            )
            for p in periods
            if merged.get(p)
        ]
        details[sym] = SymbolDataResponse(
            symbol=sym,
            data_type=data_type,
            current_price=result.get("current_price"),
            timeframes=timeframes,
            source="live+cache",
        )

    return BatchFetchResponse(
        success=success,
        failed=failed,
        details=details,
    )
@router.get("/latest/{symbol}", response_model=SymbolDataResponse)
def get_latest(
    symbol: str,
    data_type: str = "futures",
    period: Optional[str] = None,
    db: Session = Depends(get_db),
):
    """
    Return the most recent cached data for *symbol*.

    When *period* is given, only that period is requested from the cache;
    otherwise every cached period is returned. The response ``source`` is
    ``"cache"`` when the cache marks the data fresh and ``"cache_stale"``
    otherwise.

    Raises:
        HTTPException: 404 when nothing is cached for the symbol.
    """
    wanted = None if not period else [period]
    snapshot = get_cached_data(db, symbol, data_type, wanted)
    if not snapshot:
        raise HTTPException(status_code=404, detail=f"未找到 {symbol} 的缓存数据")

    def _rename_time_key(raw):
        # Cached candles may store their timestamp as "time"; CandleItem
        # is built from a "datetime" key, so copy and rename when needed.
        fields = dict(raw)
        if "datetime" not in fields and "time" in fields:
            fields["datetime"] = fields.pop("time")
        return fields

    frames = []
    for tf_period, raw_candles in snapshot["timeframes"].items():
        renamed = [_rename_time_key(raw) for raw in raw_candles]
        frames.append(
            TimeframeData(
                period=tf_period,
                candles=[CandleItem(**fields) for fields in renamed],
                candle_count=len(renamed),
                fetched_at=snapshot.get("timestamp", ""),
            )
        )

    if snapshot.get("is_fresh", False):
        origin = "cache"
    else:
        origin = "cache_stale"

    return SymbolDataResponse(
        symbol=symbol,
        data_type=data_type,
        current_price=snapshot.get("current_price"),
        timeframes=frames,
        source=origin,
    )
@router.get("/latest/{symbol}/{period}")
def get_latest_by_period(
    symbol: str,
    period: str,
    data_type: str = "futures",
    db: Session = Depends(get_db),
):
    """
    Return the latest cached candles for a single *symbol* / *period* pair.

    The candles are returned exactly as ``get_cached_data`` provides them
    (an empty list if the period is absent from the cached timeframes),
    together with the current price, fetch timestamp and freshness flag.

    Raises:
        HTTPException: 404 when nothing is cached for the symbol/period.
    """
    snapshot = get_cached_data(db, symbol, data_type, [period])
    if not snapshot:
        raise HTTPException(status_code=404, detail=f"未找到 {symbol} {period} 的缓存")

    period_candles = snapshot["timeframes"].get(period, [])
    payload = dict(
        symbol=symbol,
        period=period,
        data_type=data_type,
        candles=period_candles,
        candle_count=len(period_candles),
    )
    payload["current_price"] = snapshot.get("current_price")
    payload["fetched_at"] = snapshot.get("timestamp")
    payload["is_fresh"] = snapshot.get("is_fresh", False)
    return payload
@router.get("/cache-status/{symbol}")
def cache_status(symbol: str, db: Session = Depends(get_db)):
    """
    Report the cache status of every cached period for *symbol*.

    For each record returned by ``get_latest_cached`` the response lists the
    period, candle count, fetch time (ISO 8601), age in seconds (rounded) and
    whether that age is below ``CACHE_TTL_SECONDS``.

    Returns:
        ``{"symbol", "cached_periods": [], "status": "no_data"}`` when nothing
        is cached, otherwise ``{"symbol", "cached_periods": [...], "status": "ok"}``.
    """
    records = get_latest_cached(db, symbol)
    if not records:
        return {"symbol": symbol, "cached_periods": [], "status": "no_data"}

    # One reference instant for all records. ``fetched_at`` is treated as
    # naive local time (the original assumption), but if the column yields
    # timezone-aware values, subtracting a naive "now" would raise TypeError.
    # So keep an aware copy of the same instant as well (astimezone() on a
    # naive datetime interprets it as local time).
    now = datetime.now()
    now_aware = now.astimezone()

    periods_info = []
    for r in records:
        fetched_at = r.fetched_at
        if fetched_at is None:
            # No fetch time recorded: age is unknown, so never report it as fresh.
            periods_info.append({
                "period": r.period,
                "candle_count": r.candle_count,
                "fetched_at": None,
                "age_seconds": None,
                "is_fresh": False,
            })
            continue

        reference = now if fetched_at.tzinfo is None else now_aware
        age_seconds = (reference - fetched_at).total_seconds()
        periods_info.append({
            "period": r.period,
            "candle_count": r.candle_count,
            "fetched_at": fetched_at.isoformat(),
            "age_seconds": round(age_seconds, 0),
            "is_fresh": age_seconds < CACHE_TTL_SECONDS,
        })

    return {
        "symbol": symbol,
        "cached_periods": periods_info,
        "status": "ok",
    }