|
|
"""
|
|
|
缓存服务 - SQLite 数据库操作
|
|
|
"""
|
|
|
import json
|
|
|
import logging
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from app.models import MarketData, ScheduledTask, SymbolTimestamp
|
|
|
from app.config import CACHE_TTL_SECONDS
|
|
|
|
|
|
# Module-level logger named after this module; used below for cache-update messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# ===== 市场数据缓存 =====
|
|
|
|
|
|
def is_cache_valid(
    db: Session,
    symbol: str,
    data_type: str,
    period: str,
    ttl_seconds: int = CACHE_TTL_SECONDS,
) -> bool:
    """Tell whether the cached row for symbol/data_type/period is still within its TTL.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type stored alongside the row.
        period: Period key of the cached row.
        ttl_seconds: Maximum allowed age in seconds.

    Returns:
        False when no row exists; otherwise True iff the row's age (now minus
        fetched_at) is strictly less than ttl_seconds.
    """
    cached = (
        db.query(MarketData)
        .filter_by(symbol=symbol, data_type=data_type, period=period)
        .first()
    )
    if not cached:
        return False
    elapsed = datetime.now() - cached.fetched_at
    return elapsed.total_seconds() < ttl_seconds
|
|
|
|
|
|
|
|
|
def check_cache_status(
    db: Session,
    symbol: str,
    data_type: str,
    periods: List[str],
    ttl_seconds: int = CACHE_TTL_SECONDS,
) -> dict:
    """Split the requested periods into those with a valid cache entry and those without.

    Each period is checked once, in order, via is_cache_valid.

    Returns:
        {
            "all_valid": bool,          # True when no period is missing
            "valid_periods": [...],     # periods with a still-valid cache row
            "missing_periods": [...],   # periods that are absent or expired
        }
    """
    verdicts = [
        (period, is_cache_valid(db, symbol, data_type, period, ttl_seconds))
        for period in periods
    ]
    valid_periods = [period for period, ok in verdicts if ok]
    missing_periods = [period for period, ok in verdicts if not ok]
    return {
        "all_valid": not missing_periods,
        "valid_periods": valid_periods,
        "missing_periods": missing_periods,
    }
|
|
|
|
|
|
|
|
|
def save_market_data(db: Session, symbol: str, data: Dict) -> Optional[MarketData]:
    """Store a fetch result in the cache (one MarketData row per period) and bump the
    symbol's refresh timestamp.

    Existing rows for the same symbol/type/period are overwritten in place; missing
    ones are inserted. All rows written by one call share the same fetched_at.

    Args:
        db: Database session; committed by this function.
        symbol: Instrument code.
        data: Full payload returned by the fetch script. Keys read here are
            "type" (defaults to "futures"), "current_price", and "timeframes"
            (period -> candles; a missing or None value is treated as empty).

    Returns:
        The MarketData row written last by this call. If the payload had no
        timeframes, the most recently fetched existing row for symbol/type,
        or None when there is none.
    """
    now = datetime.now()
    # Resolve once so every row, the timestamp update and the fallback lookup agree.
    data_type = data.get("type", "futures")
    current_price = data.get("current_price")
    # `or {}` also guards against an explicit "timeframes": None in the payload.
    timeframes = data.get("timeframes") or {}

    saved: Optional[MarketData] = None
    for period, candles in timeframes.items():
        candles_json = json.dumps(candles, ensure_ascii=False)
        record = db.query(MarketData).filter_by(
            symbol=symbol,
            data_type=data_type,
            period=period,
        ).first()

        if record:
            record.candles_json = candles_json
            record.current_price = current_price
            record.fetched_at = now
            record.candle_count = len(candles)
        else:
            record = MarketData(
                symbol=symbol,
                data_type=data_type,
                period=period,
                candles_json=candles_json,
                current_price=current_price,
                fetched_at=now,
                candle_count=len(candles),
            )
            db.add(record)
        saved = record

    # Update the per-symbol refresh timestamp.
    update_symbol_timestamp(db, symbol, data_type, now)

    db.commit()
    logger.info("缓存已更新: %s, %d 个周期", symbol, len(timeframes))

    # Every row written above has fetched_at == now, so the last one written is as
    # "latest" as any; returning it avoids a re-query whose tie-break is arbitrary.
    if saved is not None:
        return saved
    return db.query(MarketData).filter_by(
        symbol=symbol,
        data_type=data_type,
    ).order_by(MarketData.fetched_at.desc()).first()
|
|
|
|
|
|
|
|
|
def update_symbol_timestamp(db: Session, symbol: str, data_type: str, refresh_time: datetime) -> None:
    """Create or update the SymbolTimestamp row for symbol/data_type, then commit.

    An existing row gets last_refresh_at = refresh_time and its refresh_count
    incremented. A new row starts with refresh_count = 1.

    Note: this commits the session, so any other pending changes in db (for example
    rows the caller added just before) are committed as well.
    """
    timestamp_record = db.query(SymbolTimestamp).filter_by(
        symbol=symbol,
        data_type=data_type,
    ).first()

    if timestamp_record:
        timestamp_record.last_refresh_at = refresh_time
        # Treat a NULL counter as 0 rather than failing with `None + 1`.
        timestamp_record.refresh_count = (timestamp_record.refresh_count or 0) + 1
    else:
        timestamp_record = SymbolTimestamp(
            symbol=symbol,
            data_type=data_type,
            last_refresh_at=refresh_time,
            refresh_count=1,
        )
        db.add(timestamp_record)

    db.commit()
|
|
|
|
|
|
|
|
|
def get_symbol_timestamp(db: Session, symbol: str, data_type: str = "futures") -> Optional[datetime]:
    """Return the recorded last refresh time for symbol/data_type, or None if no row exists."""
    match = (
        db.query(SymbolTimestamp)
        .filter_by(symbol=symbol, data_type=data_type)
        .first()
    )
    if not match:
        return None
    return match.last_refresh_at
|
|
|
|
|
|
|
|
|
def needs_refresh(db: Session, symbol: str, data_type: str = "futures", threshold_seconds: int = 300) -> bool:
    """Decide whether a symbol's data is old enough to be refetched.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type.
        threshold_seconds: Age threshold in seconds (default 300, i.e. 5 minutes).

    Returns:
        True when the symbol was never refreshed or its last refresh is at least
        threshold_seconds old; False when the data is still fresh.
    """
    last_refresh = get_symbol_timestamp(db, symbol, data_type)
    if last_refresh is None:
        return True  # never refreshed, so a refresh is needed

    age = (datetime.now() - last_refresh).total_seconds()
    # `>=` matches is_cache_valid (valid only while age < ttl), so the two
    # helpers agree on whether data exactly at the threshold is stale.
    return age >= threshold_seconds
|
|
|
|
|
|
|
|
|
def get_latest_cached(
    db: Session,
    symbol: str,
    data_type: str = "futures",
    period: Optional[str] = None,
) -> List[MarketData]:
    """Return cached MarketData rows for symbol/data_type, newest fetched_at first.

    When period is truthy, only rows for that period are returned.
    """
    criteria = {"symbol": symbol, "data_type": data_type}
    if period:
        criteria["period"] = period
    return (
        db.query(MarketData)
        .filter_by(**criteria)
        .order_by(MarketData.fetched_at.desc())
        .all()
    )
|
|
|
|
|
|
|
|
|
def get_cached_data(
    db: Session,
    symbol: str,
    data_type: str = "futures",
    periods: Optional[List[str]] = None,
) -> Optional[Dict]:
    """Assemble the cached multi-period data for symbol/data_type.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type (default "futures").
        periods: If non-empty, only these periods are loaded; otherwise all cached periods.

    Returns:
        None when nothing is cached. Otherwise a dict in the same shape as the fetch
        script's payload ("symbol", "type", "current_price", "timestamp",
        "timeframes"), plus "is_fresh" and "fetched_at". "timestamp" and
        "fetched_at" are the newest fetched_at among the loaded rows. "is_fresh"
        is True when that newest row is younger than CACHE_TTL_SECONDS.
    """
    query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type)
    if periods:
        query = query.filter(MarketData.period.in_(periods))

    records = query.all()
    if not records:
        return None

    # Check whether the cache has expired.
    now = datetime.now()
    newest = max(r.fetched_at for r in records)
    # NOTE(review): freshness is judged by the newest row only, so stale periods can
    # be returned with is_fresh=True; switch to min(...) if every period must be fresh.
    is_fresh = (now - newest).total_seconds() < CACHE_TTL_SECONDS

    timeframes = {r.period: json.loads(r.candles_json) for r in records}

    # The query has no ORDER BY, so "first non-None price" would be arbitrary;
    # take the price from the most recently fetched row that has one.
    priced = [r for r in records if r.current_price is not None]
    current_price = (
        max(priced, key=lambda r: r.fetched_at).current_price if priced else None
    )

    return {
        "symbol": symbol,
        "type": data_type,
        "current_price": current_price,
        "timestamp": newest.isoformat(),
        "timeframes": timeframes,
        "is_fresh": is_fresh,
        "fetched_at": newest.isoformat(),
    }
|
|
|
|
|
|
|
|
|
# ===== 定时任务管理 =====
|
|
|
|
|
|
def create_task(
    db: Session,
    symbol: str,
    data_type: str,
    periods: str,
    interval_seconds: int,
    task_type: str = "interval",
    run_time: Optional[str] = None,
) -> ScheduledTask:
    """Create or update the scheduled-task configuration for symbol/data_type.

    If a task already exists for the same symbol/data_type, its settings are
    overwritten, it is re-enabled and its updated_at is set to now. Otherwise a
    new enabled task is inserted. In both cases the session is committed and the
    refreshed task is returned.
    """
    # Same attribute order as the columns are assigned for an existing task.
    settings = {
        "periods": periods,
        "interval_seconds": interval_seconds,
        "task_type": task_type,
        "run_time": run_time,
        "enabled": True,
    }

    task = db.query(ScheduledTask).filter_by(
        symbol=symbol, data_type=data_type
    ).first()

    if task:
        for attr, value in settings.items():
            setattr(task, attr, value)
        task.updated_at = datetime.now()
    else:
        task = ScheduledTask(symbol=symbol, data_type=data_type, **settings)
        db.add(task)

    db.commit()
    db.refresh(task)
    return task
|
|
|
|
|
|
|
|
|
def list_tasks(db: Session) -> List[ScheduledTask]:
    """Return every scheduled task, most recently created first."""
    newest_first = ScheduledTask.created_at.desc()
    return db.query(ScheduledTask).order_by(newest_first).all()
|
|
|
|
|
|
|
|
|
def get_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Return the scheduled task whose id equals task_id, or None if there is none."""
    matches = db.query(ScheduledTask).filter_by(id=task_id)
    return matches.first()
|
|
|
|
|
|
|
|
|
def disable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Disable the task with the given id.

    Returns the refreshed task after committing, or None (without committing)
    when no task has that id.
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        return task

    task.enabled = False
    task.updated_at = datetime.now()
    db.commit()
    db.refresh(task)
    return task
|
|
|
|
|
|
|
|
|
def enable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Enable the task with the given id.

    Returns the refreshed task after committing, or None (without committing)
    when no task has that id.
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        return task

    task.enabled = True
    task.updated_at = datetime.now()
    db.commit()
    db.refresh(task)
    return task
|
|
|
|
|
|
|
|
|
def delete_task(db: Session, task_id: int) -> bool:
    """Delete the task with the given id.

    Returns True if a task was found, deleted and committed; False otherwise.
    """
    doomed = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not doomed:
        return False

    db.delete(doomed)
    db.commit()
    return True
|
|
|
|
|
|
|
|
|
def update_task_status(
    db: Session, task_id: int, status: str
) -> None:
    """Record a task execution: set last_run to now and last_status to status, then commit.

    If no task has the given id, nothing is written. A warning is logged instead
    of silently ignoring the update, and no exception is raised.
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        logger.warning("更新任务状态失败, 任务不存在: task_id=%s, status=%s", task_id, status)
        return

    task.last_run = datetime.now()
    task.last_status = status
    db.commit()
|