""" 缓存服务 - SQLite 数据库操作 """ import json import logging from datetime import datetime, timedelta from typing import Dict, List, Optional from sqlalchemy.orm import Session from app.models import MarketData, ScheduledTask, SymbolTimestamp from app.config import CACHE_TTL_SECONDS logger = logging.getLogger(__name__) # ===== 市场数据缓存 ===== def is_cache_valid( db: Session, symbol: str, data_type: str, period: str, ttl_seconds: int = CACHE_TTL_SECONDS, ) -> bool: """检查指定品种+周期的缓存是否在有效期内""" record = db.query(MarketData).filter_by( symbol=symbol, data_type=data_type, period=period, ).first() if not record: return False age = (datetime.now() - record.fetched_at).total_seconds() return age < ttl_seconds def check_cache_status( db: Session, symbol: str, data_type: str, periods: List[str], ttl_seconds: int = CACHE_TTL_SECONDS, ) -> dict: """ 检查一组周期的缓存状态。 Returns: { "all_valid": bool, # 所有周期都有有效缓存 "valid_periods": [...], "missing_periods": [...], } """ valid = [] missing = [] for p in periods: if is_cache_valid(db, symbol, data_type, p, ttl_seconds): valid.append(p) else: missing.append(p) return { "all_valid": len(missing) == 0, "valid_periods": valid, "missing_periods": missing, } def save_market_data(db: Session, symbol: str, data: Dict) -> MarketData: """ 保存采集结果到缓存,并同步更新合约时间戳。 Args: symbol: 品种代码 data: 采集脚本返回的完整数据 Returns: 保存的 MarketData 记录 """ now = datetime.now() # 按 period 拆分存储(每个周期一条记录) for period, candles in data.get("timeframes", {}).items(): record = db.query(MarketData).filter_by( symbol=symbol, data_type=data.get("type", "futures"), period=period, ).first() candles_json = json.dumps(candles, ensure_ascii=False) if record: record.candles_json = candles_json record.current_price = data.get("current_price") record.fetched_at = now record.candle_count = len(candles) else: record = MarketData( symbol=symbol, data_type=data.get("type", "futures"), period=period, candles_json=candles_json, current_price=data.get("current_price"), fetched_at=now, candle_count=len(candles), ) db.add(record) # 更新合约时间戳 update_symbol_timestamp(db, symbol, data.get("type", "futures"), now) db.commit() logger.info(f"缓存已更新: {symbol}, {len(data.get('timeframes', {}))} 个周期") # 返回最新的一条作为代表 return db.query(MarketData).filter_by( symbol=symbol, data_type=data.get("type", "futures"), ).order_by(MarketData.fetched_at.desc()).first() def update_symbol_timestamp(db: Session, symbol: str, data_type: str, refresh_time: datetime) -> None: """更新或创建合约时间戳记录""" timestamp_record = db.query(SymbolTimestamp).filter_by( symbol=symbol, data_type=data_type ).first() if timestamp_record: timestamp_record.last_refresh_at = refresh_time timestamp_record.refresh_count += 1 else: timestamp_record = SymbolTimestamp( symbol=symbol, data_type=data_type, last_refresh_at=refresh_time, refresh_count=1 ) db.add(timestamp_record) db.commit() def get_symbol_timestamp(db: Session, symbol: str, data_type: str = "futures") -> Optional[datetime]: """获取合约最后刷新时间""" record = db.query(SymbolTimestamp).filter_by( symbol=symbol, data_type=data_type ).first() return record.last_refresh_at if record else None def needs_refresh(db: Session, symbol: str, data_type: str = "futures", threshold_seconds: int = 300) -> bool: """ 检查合约是否需要刷新(数据是否超过阈值时间) Args: db: 数据库会话 symbol: 品种代码 data_type: 数据类型 threshold_seconds: 阈值时间(秒),默认300秒(5分钟) Returns: True 表示需要刷新,False 表示数据仍然新鲜 """ last_refresh = get_symbol_timestamp(db, symbol, data_type) if last_refresh is None: return True # 从未刷新过,需要刷新 age = (datetime.now() - last_refresh).total_seconds() return age > threshold_seconds def get_latest_cached( db: Session, symbol: str, data_type: str = "futures", period: Optional[str] = None, ) -> List[MarketData]: """获取最新缓存数据""" query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type) if period: query = query.filter_by(period=period) return query.order_by(MarketData.fetched_at.desc()).all() def get_cached_data( db: Session, symbol: str, data_type: str = "futures", periods: Optional[List[str]] = None, ) -> Optional[Dict]: """ 从缓存中获取完整的多周期数据。 Returns: 与采集脚本相同格式的数据,或 None """ query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type) if periods: query = query.filter(MarketData.period.in_(periods)) records = query.all() if not records: return None # 检查缓存是否过期 now = datetime.now() newest = max(r.fetched_at for r in records) is_fresh = (now - newest).total_seconds() < CACHE_TTL_SECONDS timeframes = {} current_price = None for r in records: timeframes[r.period] = json.loads(r.candles_json) if current_price is None: current_price = r.current_price return { "symbol": symbol, "type": data_type, "current_price": current_price, "timestamp": newest.isoformat(), "timeframes": timeframes, "is_fresh": is_fresh, "fetched_at": newest.isoformat(), } # ===== 定时任务管理 ===== def create_task( db: Session, symbol: str, data_type: str, periods: str, interval_seconds: int, task_type: str = "interval", run_time: Optional[str] = None, ) -> ScheduledTask: """创建定时任务配置""" existing = db.query(ScheduledTask).filter_by( symbol=symbol, data_type=data_type ).first() if existing: existing.periods = periods existing.interval_seconds = interval_seconds existing.task_type = task_type existing.run_time = run_time existing.enabled = True existing.updated_at = datetime.now() db.commit() db.refresh(existing) return existing task = ScheduledTask( symbol=symbol, data_type=data_type, periods=periods, interval_seconds=interval_seconds, task_type=task_type, run_time=run_time, enabled=True, ) db.add(task) db.commit() db.refresh(task) return task def list_tasks(db: Session) -> List[ScheduledTask]: """列出所有任务""" return db.query(ScheduledTask).order_by(ScheduledTask.created_at.desc()).all() def get_task(db: Session, task_id: int) -> Optional[ScheduledTask]: """获取单个任务""" return db.query(ScheduledTask).filter_by(id=task_id).first() def disable_task(db: Session, task_id: int) -> Optional[ScheduledTask]: """禁用任务""" task = db.query(ScheduledTask).filter_by(id=task_id).first() if task: task.enabled = False task.updated_at = datetime.now() db.commit() db.refresh(task) return task def enable_task(db: Session, task_id: int) -> Optional[ScheduledTask]: """启用任务""" task = db.query(ScheduledTask).filter_by(id=task_id).first() if task: task.enabled = True task.updated_at = datetime.now() db.commit() db.refresh(task) return task def delete_task(db: Session, task_id: int) -> bool: """删除任务""" task = db.query(ScheduledTask).filter_by(id=task_id).first() if task: db.delete(task) db.commit() return True return False def update_task_status( db: Session, task_id: int, status: str ) -> None: """更新任务执行状态""" task = db.query(ScheduledTask).filter_by(id=task_id).first() if task: task.last_run = datetime.now() task.last_status = status db.commit()