You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

321 lines
8.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
缓存服务 - SQLite 数据库操作
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from app.models import MarketData, ScheduledTask, SymbolTimestamp
from app.config import CACHE_TTL_SECONDS
# Module-level logger named after this module; used by save_market_data to report cache writes.
logger = logging.getLogger(__name__)
# ===== 市场数据缓存 =====
def is_cache_valid(
    db: Session,
    symbol: str,
    data_type: str,
    period: str,
    ttl_seconds: int = CACHE_TTL_SECONDS,
) -> bool:
    """Tell whether the cached row for one symbol/period is still within its TTL.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type the row was stored under.
        period: Candle period of the row.
        ttl_seconds: Maximum allowed age in seconds (defaults to CACHE_TTL_SECONDS).

    Returns:
        False when no MarketData row matches; otherwise True iff
        ``datetime.now() - fetched_at`` is strictly less than ``ttl_seconds``.
    """
    cached = (
        db.query(MarketData)
        .filter_by(symbol=symbol, data_type=data_type, period=period)
        .first()
    )
    if not cached:
        return False
    elapsed = datetime.now() - cached.fetched_at
    return elapsed.total_seconds() < ttl_seconds
def check_cache_status(
    db: Session,
    symbol: str,
    data_type: str,
    periods: List[str],
    ttl_seconds: int = CACHE_TTL_SECONDS,
) -> dict:
    """Classify each requested period as having a valid cache entry or not.

    Each period is checked with ``is_cache_valid`` in the order given.

    Returns:
        A dict with:
            "all_valid": True when no period is missing (also True for an empty list),
            "valid_periods": periods whose cache is still fresh, in input order,
            "missing_periods": periods with no row or an expired row, in input order.
    """
    # Evaluate validity once per period, preserving order (and duplicates).
    checked = [
        (period, is_cache_valid(db, symbol, data_type, period, ttl_seconds))
        for period in periods
    ]
    fresh = [period for period, ok in checked if ok]
    stale = [period for period, ok in checked if not ok]
    return {
        "all_valid": not stale,
        "valid_periods": fresh,
        "missing_periods": stale,
    }
def save_market_data(db: Session, symbol: str, data: Dict) -> Optional[MarketData]:
    """Upsert fetched candles into the cache and refresh the symbol's timestamp.

    One MarketData row is kept per (symbol, data_type, period): an existing row
    is overwritten in place, otherwise a new row is inserted. Every row written
    by this call gets the same ``fetched_at`` value.

    Args:
        db: Database session.
        symbol: Instrument code.
        data: Full payload returned by the fetch script. Keys read here:
            "type" (defaults to "futures"), "timeframes" (period -> candle list),
            and "current_price".

    Returns:
        The most recently fetched MarketData row for this symbol/data_type, or
        None when no such row exists (e.g. empty "timeframes" and nothing cached
        before).
    """
    now = datetime.now()
    data_type = data.get("type", "futures")
    current_price = data.get("current_price")
    # ``or {}`` also covers an explicit ``"timeframes": None`` in the payload,
    # which would otherwise raise AttributeError on ``.items()``.
    timeframes = data.get("timeframes") or {}

    # Store each period as its own row.
    for period, candles in timeframes.items():
        candles_json = json.dumps(candles, ensure_ascii=False)
        record = db.query(MarketData).filter_by(
            symbol=symbol,
            data_type=data_type,
            period=period,
        ).first()
        if record:
            record.candles_json = candles_json
            record.current_price = current_price
            record.fetched_at = now
            record.candle_count = len(candles)
        else:
            db.add(
                MarketData(
                    symbol=symbol,
                    data_type=data_type,
                    period=period,
                    candles_json=candles_json,
                    current_price=current_price,
                    fetched_at=now,
                    candle_count=len(candles),
                )
            )

    # update_symbol_timestamp commits the session itself; the explicit commit
    # below keeps this function self-contained regardless of that helper.
    update_symbol_timestamp(db, symbol, data_type, now)
    db.commit()
    logger.info("缓存已更新: %s, %d 个周期", symbol, len(timeframes))

    # Return the newest row as a representative of this symbol's cache.
    return db.query(MarketData).filter_by(
        symbol=symbol,
        data_type=data_type,
    ).order_by(MarketData.fetched_at.desc()).first()
def update_symbol_timestamp(db: Session, symbol: str, data_type: str, refresh_time: datetime) -> None:
    """Create or update the SymbolTimestamp row for (symbol, data_type), then commit.

    Sets ``last_refresh_at`` to ``refresh_time`` and increments ``refresh_count``
    (a newly created row starts at 1).

    Note: this calls ``db.commit()``, so any other pending changes in ``db``
    are committed as well.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type the timestamp is tracked under.
        refresh_time: Time of the refresh being recorded.
    """
    stamp = db.query(SymbolTimestamp).filter_by(
        symbol=symbol,
        data_type=data_type,
    ).first()
    if stamp:
        stamp.last_refresh_at = refresh_time
        # Treat a NULL counter as 0 instead of raising TypeError on ``None + 1``.
        # NOTE(review): the column's nullability isn't visible here — confirm schema.
        stamp.refresh_count = (stamp.refresh_count or 0) + 1
    else:
        db.add(
            SymbolTimestamp(
                symbol=symbol,
                data_type=data_type,
                last_refresh_at=refresh_time,
                refresh_count=1,
            )
        )
    db.commit()
def get_symbol_timestamp(db: Session, symbol: str, data_type: str = "futures") -> Optional[datetime]:
    """Look up when (symbol, data_type) was last refreshed.

    Returns:
        The row's ``last_refresh_at``, or None when no SymbolTimestamp row exists.
    """
    row = (
        db.query(SymbolTimestamp)
        .filter_by(symbol=symbol, data_type=data_type)
        .first()
    )
    if not row:
        return None
    return row.last_refresh_at
def needs_refresh(db: Session, symbol: str, data_type: str = "futures", threshold_seconds: int = 300) -> bool:
    """Decide whether a symbol's data is stale enough to re-fetch.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type the timestamp is tracked under.
        threshold_seconds: Staleness threshold in seconds (default 300, i.e. 5 minutes).

    Returns:
        True if the symbol has never been refreshed, or if its last refresh is
        strictly older than ``threshold_seconds``; False if the data is still fresh.
    """
    last_refresh = get_symbol_timestamp(db, symbol, data_type)
    if last_refresh is not None:
        seconds_since = (datetime.now() - last_refresh).total_seconds()
        return seconds_since > threshold_seconds
    # No timestamp recorded yet: the symbol has never been refreshed.
    return True
def get_latest_cached(
    db: Session,
    symbol: str,
    data_type: str = "futures",
    period: Optional[str] = None,
) -> List[MarketData]:
    """Return cached MarketData rows for a symbol, newest ``fetched_at`` first.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type to filter on (default "futures").
        period: When truthy, restrict to this period; otherwise return all periods.
    """
    filters = {"symbol": symbol, "data_type": data_type}
    if period:
        filters["period"] = period
    return (
        db.query(MarketData)
        .filter_by(**filters)
        .order_by(MarketData.fetched_at.desc())
        .all()
    )
def get_cached_data(
    db: Session,
    symbol: str,
    data_type: str = "futures",
    periods: Optional[List[str]] = None,
    ttl_seconds: int = CACHE_TTL_SECONDS,
) -> Optional[Dict]:
    """Assemble the full multi-period payload for a symbol from the cache.

    Args:
        db: Database session.
        symbol: Instrument code.
        data_type: Data type to filter on (default "futures").
        periods: If non-empty, restrict to these periods; otherwise use every
            cached period.
        ttl_seconds: Freshness window used for "is_fresh" (defaults to
            CACHE_TTL_SECONDS, the same default as is_cache_valid).

    Returns:
        None if no rows match. Otherwise a dict in the same shape as the fetch
        script's output, with keys "symbol", "type", "current_price",
        "timestamp", "timeframes", "is_fresh" and "fetched_at". Both
        "timestamp" and "fetched_at" are the newest row's ``fetched_at`` in ISO
        format; "is_fresh" is True when that newest row is younger than
        ``ttl_seconds``.
    """
    query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type)
    if periods:
        query = query.filter(MarketData.period.in_(periods))
    # The query has no ORDER BY, so sort explicitly (oldest -> newest). That
    # makes the chosen price and duplicate-period handling deterministic
    # instead of depending on database row order.
    records = sorted(query.all(), key=lambda r: r.fetched_at)
    if not records:
        return None

    now = datetime.now()
    newest = records[-1].fetched_at
    is_fresh = (now - newest).total_seconds() < ttl_seconds

    # Iterating oldest -> newest means the newest row wins if a period repeats.
    timeframes = {r.period: json.loads(r.candles_json) for r in records}
    # Report the price from the most recently fetched row that has one.
    current_price = next(
        (r.current_price for r in reversed(records) if r.current_price is not None),
        None,
    )

    return {
        "symbol": symbol,
        "type": data_type,
        "current_price": current_price,
        "timestamp": newest.isoformat(),
        "timeframes": timeframes,
        "is_fresh": is_fresh,
        "fetched_at": newest.isoformat(),
    }
# ===== 定时任务管理 =====
def create_task(
    db: Session,
    symbol: str,
    data_type: str,
    periods: str,
    interval_seconds: int,
    task_type: str = "interval",
    run_time: Optional[str] = None,
) -> ScheduledTask:
    """Create or update (upsert) the scheduled-task config for (symbol, data_type).

    If a task already exists for the pair, its periods, interval, type and
    run time are overwritten. The task is re-enabled and ``updated_at`` is
    bumped. Otherwise a new enabled task is inserted. The session is committed
    and the returned task is refreshed from the database.
    """
    task = db.query(ScheduledTask).filter_by(
        symbol=symbol, data_type=data_type
    ).first()

    if not task:
        task = ScheduledTask(
            symbol=symbol,
            data_type=data_type,
            periods=periods,
            interval_seconds=interval_seconds,
            task_type=task_type,
            run_time=run_time,
            enabled=True,
        )
        db.add(task)
    else:
        task.periods = periods
        task.interval_seconds = interval_seconds
        task.task_type = task_type
        task.run_time = run_time
        task.enabled = True
        task.updated_at = datetime.now()

    db.commit()
    db.refresh(task)
    return task
def list_tasks(db: Session) -> List[ScheduledTask]:
    """Return every ScheduledTask, most recently created first."""
    newest_first = ScheduledTask.created_at.desc()
    return db.query(ScheduledTask).order_by(newest_first).all()
def get_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Fetch a single ScheduledTask by its ``id`` column, or None if no row matches."""
    matches = db.query(ScheduledTask).filter_by(id=task_id)
    return matches.first()
def _set_task_enabled(db: Session, task_id: int, enabled: bool) -> Optional[ScheduledTask]:
    """Set ``enabled`` on a task, bump ``updated_at``, commit, and refresh it.

    Returns the refreshed task, or None (without committing) when no task
    has ``id == task_id``.
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        return task
    task.enabled = enabled
    task.updated_at = datetime.now()
    db.commit()
    db.refresh(task)
    return task


def disable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Disable a task; return the updated task, or None if it does not exist."""
    return _set_task_enabled(db, task_id, False)


def enable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
    """Enable a task; return the updated task, or None if it does not exist."""
    return _set_task_enabled(db, task_id, True)
def delete_task(db: Session, task_id: int) -> bool:
    """Delete the task with ``id == task_id`` and commit.

    Returns:
        True if a task was found and deleted, False if none matched (nothing is committed).
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        return False
    db.delete(task)
    db.commit()
    return True
def update_task_status(
    db: Session, task_id: int, status: str
) -> None:
    """Record a task run by setting ``last_run`` to now and ``last_status`` to ``status``, then commit.

    Does nothing when no task has ``id == task_id``.
    """
    task = db.query(ScheduledTask).filter_by(id=task_id).first()
    if not task:
        return
    task.last_run = datetime.now()
    task.last_status = status
    db.commit()