You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

896 lines
34 KiB

"""
缓存管理服务
"""
import logging
from typing import List, Dict, Optional
from datetime import date, datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from app.models.cache import CacheTask, CacheTaskDetail
from app.models.stock import StockKlineDaily
from app.models.future import FutureKlineDaily
from app.services.base_data_service import BaseDataService
from app.services.stock_service import StockService
from app.services.future_service import FutureService
from app.services.sdk_manager import sdk_manager
from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.config import settings
from app.core.redis_client import redis_client
# Module-level logger named after this module; used by CacheService for progress and error reporting.
logger = logging.getLogger(__name__)
class CacheService:
    """Manage K-line cache tasks: code-list lookup, gap detection and backfill.

    Long-running operations are tracked by a ``CacheTask`` row and, for some
    operations, per-code ``CacheTaskDetail`` rows, all persisted through the
    SQLAlchemy session supplied at construction time.
    """

    # Lifetime of a cached code list in Redis, in seconds (12 hours).
    CODE_LIST_CACHE_EXPIRE = 12 * 60 * 60

    def __init__(self, db: Session):
        """Bind the service to a database session and build its collaborators.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db
        self.base_service = BaseDataService(db)
        self.stock_service = StockService(db)
        self.future_service = FutureService(db)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _summarize_codes(codes: List[str], limit: int = 100) -> str:
        """Join the first ``limit`` codes with commas.

        ``"..."`` is appended only when the list was actually truncated, so a
        short list never ends up with a corrupted last code.
        """
        preview = ",".join(codes[:limit])
        if len(codes) > limit:
            return preview + "..."
        return preview

    @staticmethod
    def _progress_percent(done: int, total: int) -> int:
        """Return ``done / total`` as an integer percentage capped at 100."""
        if total <= 0:
            return 100
        return min(100, int(done / total * 100))

    @staticmethod
    def _daily_model(security_type: str, period_type: str):
        """Return the daily K-line model for the combination, or ``None``.

        Only ``stock``/``future`` with ``period_type == "daily"`` have a model
        in this module; every other combination is treated as "no cached rows".
        """
        if period_type != "daily":
            return None
        if security_type == "stock":
            return StockKlineDaily
        if security_type == "future":
            return FutureKlineDaily
        return None

    def _count_daily_records(
        self,
        security_type: str,
        period_type: str,
        code: str,
        start_date: date,
        end_date: date
    ) -> int:
        """Count cached daily rows for ``code`` within the inclusive date range."""
        model = self._daily_model(security_type, period_type)
        if model is None:
            return 0
        return self.db.query(model).filter(
            and_(
                model.code == code,
                model.trade_date >= start_date,
                model.trade_date <= end_date
            )
        ).count()

    def _daily_trade_dates(
        self,
        security_type: str,
        period_type: str,
        code: str,
        start_date: date,
        end_date: date
    ) -> list:
        """Return the trade dates of cached daily rows for ``code`` in the range.

        Only the ``trade_date`` column is selected; the callers need nothing
        else, so full ORM rows are not materialized.
        """
        model = self._daily_model(security_type, period_type)
        if model is None:
            return []
        rows = self.db.query(model.trade_date).filter(
            and_(
                model.code == code,
                model.trade_date >= start_date,
                model.trade_date <= end_date
            )
        ).all()
        return [row[0] for row in rows]

    def _get_trading_days(self, security_type: str, start_date: date, end_date: date):
        """Fetch the trading calendar; futures use market "CFE", everything else "SH"."""
        market = "CFE" if security_type == "future" else "SH"
        return self.base_service.get_trading_calendar(market, start_date, end_date)

    def _fetch_kline(
        self,
        security_type: str,
        code: str,
        start_date: date,
        end_date: date,
        period_type: str
    ) -> None:
        """Request K-line data for one code from the matching service.

        The original inline comments state that fetching also caches the data.
        Unknown security types are a no-op, as in the original code.
        """
        if security_type == "stock":
            self.stock_service.get_kline([code], start_date, end_date, period_type)
        elif security_type == "future":
            self.future_service.get_kline([code], start_date, end_date, period_type)

    def _create_task(
        self,
        *,
        task_name: str,
        task_type: str,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        code_list: str,
        status: str,
        total_count: int
    ) -> CacheTask:
        """Insert a ``CacheTask`` row, commit it and return the refreshed object."""
        task = CacheTask(
            task_name=task_name,
            task_type=task_type,
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=code_list,
            status=status,
            total_count=total_count,
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        return task

    def _mark_task_completed(self, task: CacheTask, success_count: int, error_count: int) -> None:
        """Persist the final counters and the "completed" status."""
        task.status = "completed"
        task.success_count = success_count
        task.error_count = error_count
        task.completed_at = datetime.utcnow()
        self.db.commit()

    def _mark_task_failed(self, task: CacheTask, error: Exception) -> None:
        """Persist the "failed" status and the error message.

        The session is rolled back first: after a failed flush/commit the
        session refuses further commits until ``rollback()`` is called, which
        would otherwise prevent the failure from being recorded.
        """
        self.db.rollback()
        task.status = "failed"
        task.error_message = str(error)
        task.completed_at = datetime.utcnow()
        self.db.commit()

    def _fetch_codes(
        self,
        task: CacheTask,
        codes: List[str],
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        *,
        action: str = "缓存",
        progress_every: int = 10,
        record_details: bool = False,
        skip_complete: bool = False,
        log_progress: bool = False
    ) -> tuple:
        """Fetch K-line data for each code and keep ``task`` progress updated.

        Args:
            task: Task row whose counters and progress are updated.
            codes: Codes to process.
            security_type: "stock" or "future".
            period_type: Period passed through to the K-line services.
            start_date: Inclusive range start.
            end_date: Inclusive range end.
            action: Verb used in the per-code error log message.
            progress_every: Commit progress every N codes (and on the last one).
            record_details: Add one ``CacheTaskDetail`` row per code.
            skip_complete: Skip codes whose cached row count already reaches
                the number of trading days in the range.
            log_progress: Emit an info log at every progress checkpoint.

        Returns:
            Tuple ``(success_count, skipped_count, error_count)``.
        """
        expected_count = None
        if skip_complete:
            expected_count = len(self._get_trading_days(security_type, start_date, end_date))

        success_count = 0
        skipped_count = 0
        error_count = 0
        total = len(codes)

        for position, code in enumerate(codes, start=1):
            detail_fields = {}
            try:
                is_complete = False
                if skip_complete:
                    actual_count = self._count_daily_records(
                        security_type, period_type, code, start_date, end_date
                    )
                    is_complete = actual_count >= expected_count
                    detail_fields.update(
                        expected_count=expected_count,
                        actual_count=actual_count,
                        is_missing=not is_complete
                    )
                if is_complete:
                    skipped_count += 1
                    detail_fields["status"] = "skipped"
                else:
                    self._fetch_kline(security_type, code, start_date, end_date, period_type)
                    success_count += 1
                    detail_fields["status"] = "success"
                    detail_fields["processed_at"] = datetime.utcnow()
            except Exception as e:
                logger.error(f"{action}{code}数据失败: {str(e)}")
                error_count += 1
                detail_fields = {"status": "failed", "error_message": str(e)}

            if record_details:
                self.db.add(CacheTaskDetail(
                    task_id=task.id,
                    code=code,
                    trade_date=start_date,
                    **detail_fields
                ))

            # Skipped codes fall through to this checkpoint too, so the last
            # code always brings the task to 100% even when it was skipped.
            if position % progress_every == 0 or position == total:
                task.success_count = success_count
                task.error_count = error_count
                task.progress = self._progress_percent(position, total)
                self.db.commit()
                if log_progress:
                    logger.info(f"进度: {position}/{total}, 成功: {success_count}, 跳过: {skipped_count}, 错误: {error_count}")

        return success_count, skipped_count, error_count

    # ------------------------------------------------------------------
    # Code lists
    # ------------------------------------------------------------------

    def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]:
        """Return the full code list for a security type, cached in Redis.

        Args:
            security_type: "stock" or "future"; any other value yields ``[]``.
            contract_type: "all" or "main"; only consulted for futures.

        Returns:
            The code list, from Redis when present, otherwise from the SDK.

        Raises:
            RuntimeError: If no default SDK connection is available.
        """
        cache_key = f"code_list:{security_type}:{contract_type}"
        cached_codes = redis_client.get(cache_key)
        if cached_codes is not None:
            logger.info(f"从Redis缓存获取代码列表: {security_type}/{contract_type}, 共{len(cached_codes)}")
            return cached_codes

        adapter = sdk_manager.get_default_connection()
        if not adapter:
            raise RuntimeError("SDK连接失败")

        codes = []
        if security_type == "stock":
            codes = adapter.get_code_list("EXTRA_STOCK_A")
        elif security_type == "future":
            if contract_type == "main":
                main_contracts = adapter.get_all_main_contracts()
                # Guard against an empty/None mapping coming back from the SDK.
                codes = list((main_contracts or {}).values())
            else:
                codes = adapter.get_code_list("EXTRA_FUTURE")

        # Empty results are not cached, so the next call asks the SDK again.
        if codes:
            redis_client.set(cache_key, codes, expire=self.CODE_LIST_CACHE_EXPIRE)
            logger.info(f"代码列表已缓存到Redis: {security_type}/{contract_type}, 共{len(codes)}个, 有效期12小时")
        return codes

    def get_future_varieties(self) -> List[str]:
        """Return the future varieties reported by the default SDK connection.

        Raises:
            RuntimeError: If no default SDK connection is available.
        """
        adapter = sdk_manager.get_default_connection()
        if not adapter:
            raise RuntimeError("SDK连接失败")
        return adapter.get_future_varieties()

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------

    def detect_all_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ) -> Dict:
        """Check every known code for missing cached data in a date range.

        A code is "missing" when it has fewer cached rows than there are
        trading days in the range. Missing and failed codes each get a
        ``CacheTaskDetail`` row.

        Args:
            security_type: "stock" or "future".
            period_type: Period type; only "daily" is backed by a table here,
                other values count as zero cached rows.
            start_date: Inclusive range start.
            end_date: Inclusive range end.
            contract_type: "all" or "main"; only consulted for futures.

        Returns:
            A result dictionary with task metadata, counters, per-day
            statistics and the missing codes. On failure it contains only
            ``task_id``, ``task_name``, ``status`` and ``error_message``.

        Raises:
            ValueError: If the code list is empty.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        logger.info(f"获取到{len(code_list)}{security_type}代码")

        task = self._create_task(
            task_name=f"一键检测所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
            task_type="detect_all_missing",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list),
            status="running",
            total_count=len(code_list)
        )

        try:
            trading_days = self._get_trading_days(security_type, start_date, end_date)
            expected_count = len(trading_days)
            total = len(code_list)

            # Per-trading-day coverage: how many codes have a row on that day.
            daily_stats = {
                format_date(day): {"expected": total, "actual": 0, "missing": 0}
                for day in trading_days
            }
            missing_codes = []
            complete_codes = []
            error_count = 0

            for position, code in enumerate(code_list, start=1):
                try:
                    trade_dates = self._daily_trade_dates(
                        security_type, period_type, code, start_date, end_date
                    )
                    actual_count = len(trade_dates)
                    for trade_date in trade_dates:
                        stats = daily_stats.get(format_date(trade_date))
                        if stats is not None:
                            stats["actual"] += 1

                    if actual_count < expected_count:
                        # expected_count > actual_count >= 0, so the division is safe.
                        shortfall = expected_count - actual_count
                        missing_codes.append({
                            "code": code,
                            "actual_count": actual_count,
                            "expected_count": expected_count,
                            "missing_count": shortfall,
                            "missing_ratio": shortfall / expected_count
                        })
                        self.db.add(CacheTaskDetail(
                            task_id=task.id,
                            code=code,
                            trade_date=start_date,
                            expected_count=expected_count,
                            actual_count=actual_count,
                            is_missing=True,
                            status="pending"
                        ))
                    else:
                        complete_codes.append(code)
                except Exception as e:
                    logger.error(f"检测{code}缺失数据失败: {str(e)}")
                    error_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="failed",
                        error_message=str(e)
                    ))

                # Commit progress every 100 codes and on the last one. While
                # running, success_count means "codes checked without error".
                if position % 100 == 0 or position == total:
                    task.success_count = len(missing_codes) + len(complete_codes)
                    task.error_count = error_count
                    task.progress = self._progress_percent(position, total)
                    self.db.commit()

            for stats in daily_stats.values():
                stats["missing"] = stats["expected"] - stats["actual"]

            # Store the (capped) missing codes on the task; the final
            # success_count is the number of complete codes.
            missing_code_list = [entry["code"] for entry in missing_codes]
            task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else ""
            self._mark_task_completed(task, len(complete_codes), error_count)
            logger.info(f"检测完成: 完整{len(complete_codes)}个, 缺失{len(missing_codes)}个, 错误{error_count}")

            return {
                "task_id": task.id,
                "task_name": task.task_name,
                "status": task.status,
                "progress": float(task.progress),
                "total_count": task.total_count,
                "complete_count": len(complete_codes),
                "missing_count": len(missing_codes),
                "error_count": error_count,
                "expected_days": expected_count,
                "start_date": format_date(start_date),
                "end_date": format_date(end_date),
                "security_type": security_type,
                "period_type": period_type,
                "daily_stats": daily_stats,
                "missing_codes": missing_codes[:100],
                "missing_code_list": missing_code_list
            }
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.error(f"一键检测缺失数据任务失败: {str(e)}")
            return {
                "task_id": task.id,
                "task_name": task.task_name,
                "status": task.status,
                "error_message": str(e)
            }

    def detect_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        code_list: List[str]
    ) -> CacheTask:
        """Check the given codes for missing cached data and record the result.

        A code is flagged as missing when its missing ratio exceeds
        ``settings.CACHE_MISSING_THRESHOLD``. Every code gets a
        ``CacheTaskDetail`` row ("pending", "skipped" or "failed").

        Args:
            security_type: "stock" or "future".
            period_type: Period type; only "daily" is backed by a table here.
            start_date: Inclusive range start.
            end_date: Inclusive range end.
            code_list: Codes to inspect.

        Returns:
            The task row; ``success_count`` holds the number of missing codes.
        """
        task = self._create_task(
            task_name=f"检测缺失数据 - {security_type} - {len(code_list)}个代码",
            task_type="detect_missing",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(code_list),
            status="running",
            total_count=len(code_list)
        )

        try:
            expected_count = len(self._get_trading_days(security_type, start_date, end_date))
            total = len(code_list)
            success_count = 0
            error_count = 0

            for position, code in enumerate(code_list, start=1):
                try:
                    actual_count = self._count_daily_records(
                        security_type, period_type, code, start_date, end_date
                    )
                    missing_ratio = 0
                    if expected_count > 0:
                        missing_ratio = (expected_count - actual_count) / expected_count
                    is_missing = missing_ratio > settings.CACHE_MISSING_THRESHOLD

                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        expected_count=expected_count,
                        actual_count=actual_count,
                        is_missing=is_missing,
                        status="pending" if is_missing else "skipped"
                    ))
                    if is_missing:
                        success_count += 1
                except Exception as e:
                    logger.error(f"检测{code}缺失数据失败: {str(e)}")
                    error_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="failed",
                        error_message=str(e)
                    ))

                # Progress follows the number of codes processed. Deriving it
                # from success_count + error_count would stall below 100%
                # whenever some codes are not missing.
                task.success_count = success_count
                task.error_count = error_count
                task.progress = self._progress_percent(position, total)
                self.db.commit()

            self._mark_task_completed(task, success_count, error_count)
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.error(f"检测缺失数据任务失败: {str(e)}")
        return task

    # ------------------------------------------------------------------
    # Task creation / execution split (create first, execute later)
    # ------------------------------------------------------------------

    def _create_cache_task(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ) -> CacheTask:
        """Create a "pending" cache-all task row for the full code list.

        Raises:
            ValueError: If the code list is empty.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        return self._create_task(
            task_name=f"一键缓存所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
            task_type="cache_all_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list),
            status="pending",
            total_count=len(code_list)
        )

    def _create_fill_missing_task(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        missing_codes: List[str]
    ) -> CacheTask:
        """Create a "pending" fill-missing task row for the given codes.

        Raises:
            ValueError: If ``missing_codes`` is empty.
        """
        if not missing_codes:
            raise ValueError("没有缺失代码需要补齐")
        return self._create_task(
            task_name=f"一键补齐缺失数据 - {security_type} - {len(missing_codes)}个代码",
            task_type="fill_missing_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(missing_codes[:500]),
            status="pending",
            total_count=len(missing_codes)
        )

    def _execute_fill_missing_task(
        self,
        task_id: int,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        missing_codes: List[str]
    ):
        """Run a previously created fill-missing task.

        Fetches K-line data for every code in ``missing_codes`` and commits
        progress every 10 codes. Returns silently if the task does not exist.
        """
        task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
        if not task:
            return
        task.status = "running"
        self.db.commit()

        try:
            success_count, _, error_count = self._fetch_codes(
                task, missing_codes, security_type, period_type, start_date, end_date,
                action="补齐", progress_every=10
            )
            self._mark_task_completed(task, success_count, error_count)
        except Exception as e:
            self._mark_task_failed(task, e)

    def _execute_cache_task(
        self,
        task_id: int,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ):
        """Run a previously created cache-all task.

        Codes whose cached row count already covers the trading calendar are
        skipped; progress is committed every 10 codes. Returns silently if
        the task does not exist.
        """
        task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
        if not task:
            return
        task.status = "running"
        self.db.commit()

        try:
            # Looked up inside the try so an SDK/Redis failure marks the task
            # as failed instead of leaving it in "running" forever.
            code_list = self.get_all_codes(security_type, contract_type)
            success_count, _, error_count = self._fetch_codes(
                task, code_list, security_type, period_type, start_date, end_date,
                action="缓存", progress_every=10, skip_complete=True
            )
            self._mark_task_completed(task, success_count, error_count)
        except Exception as e:
            self._mark_task_failed(task, e)

    # ------------------------------------------------------------------
    # Synchronous cache operations
    # ------------------------------------------------------------------

    def cache_all_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ) -> CacheTask:
        """Fetch data for every known code that lacks complete cached data.

        Args:
            security_type: "stock" or "future".
            period_type: Period type passed to the K-line services.
            start_date: Inclusive range start.
            end_date: Inclusive range end.
            contract_type: "all" or "main"; only consulted for futures.

        Returns:
            The task row, with status "completed" or "failed".

        Raises:
            ValueError: If the code list is empty.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        logger.info(f"获取到{len(code_list)}{security_type}代码,开始缓存")

        task = self._create_task(
            task_name=f"一键缓存所有数据 - {security_type} - {len(code_list)}个代码",
            task_type="cache_all_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list),
            status="running",
            total_count=len(code_list)
        )

        try:
            success_count, skipped_count, error_count = self._fetch_codes(
                task, code_list, security_type, period_type, start_date, end_date,
                action="缓存", progress_every=50, record_details=True,
                skip_complete=True, log_progress=True
            )
            self._mark_task_completed(task, success_count, error_count)
            logger.info(f"缓存完成: 成功{success_count}个, 跳过{skipped_count}个, 错误{error_count}")
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.error(f"一键缓存数据任务失败: {str(e)}")
        return task

    def batch_cache_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        code_list: List[str]
    ) -> CacheTask:
        """Fetch data for each given code, recording one detail row per code.

        Progress is committed after every code.

        Args:
            security_type: "stock" or "future".
            period_type: Period type passed to the K-line services.
            start_date: Inclusive range start.
            end_date: Inclusive range end.
            code_list: Codes to fetch.

        Returns:
            The task row, with status "completed" or "failed".
        """
        task = self._create_task(
            task_name=f"批量缓存数据 - {security_type} - {len(code_list)}个代码",
            task_type="cache_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(code_list),
            status="running",
            total_count=len(code_list)
        )

        try:
            success_count, _, error_count = self._fetch_codes(
                task, code_list, security_type, period_type, start_date, end_date,
                action="缓存", progress_every=1, record_details=True
            )
            self._mark_task_completed(task, success_count, error_count)
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.error(f"批量缓存数据任务失败: {str(e)}")
        return task

    # ------------------------------------------------------------------
    # Task queries
    # ------------------------------------------------------------------

    def get_tasks(
        self,
        page: int = 1,
        page_size: int = 20
    ) -> Dict:
        """Return one page of cache tasks, newest first.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Items per page; values below 1 are treated as 1.

        Returns:
            Dict with ``items``, ``total``, ``page``, ``page_size`` and
            ``total_pages``.
        """
        # Clamp to avoid a negative OFFSET and a division by zero below.
        page = max(1, page)
        page_size = max(1, page_size)

        query = self.db.query(CacheTask).order_by(CacheTask.created_at.desc())
        total = query.count()
        tasks = query.offset((page - 1) * page_size).limit(page_size).all()
        return {
            "items": tasks,
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": (total + page_size - 1) // page_size
        }

    def get_task(self, task_id: int) -> Optional[CacheTask]:
        """Return the task with the given id, or ``None`` if it does not exist."""
        return self.db.query(CacheTask).filter(CacheTask.id == task_id).first()

    def get_task_details(self, task_id: int) -> List[CacheTaskDetail]:
        """Return all detail rows that belong to the given task."""
        return self.db.query(CacheTaskDetail).filter(
            CacheTaskDetail.task_id == task_id
        ).all()

    def cancel_task(self, task_id: int) -> bool:
        """Mark a running task as cancelled.

        Returns:
            ``True`` if the task existed and was "running", otherwise ``False``.
        """
        # NOTE(review): this only flips the stored status. The processing
        # loops in this class never check for "cancelled" and set
        # "completed"/"failed" when they finish.
        task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
        if task and task.status == "running":
            task.status = "cancelled"
            task.completed_at = datetime.utcnow()
            self.db.commit()
            return True
        return False

    def get_cache_status(self, code: str, security_type: str, period_type: str) -> Dict:
        """Return the cache status of one code.

        Delegates to the stock or future service; any other security type
        yields an empty status (zero records, no dates).
        """
        if security_type == "stock":
            return self.stock_service.get_cache_status(code, period_type)
        if security_type == "future":
            return self.future_service.get_cache_status(code, period_type)
        return {
            "code": code,
            "security_type": security_type,
            "period_type": period_type,
            "record_count": 0,
            "min_date": None,
            "max_date": None
        }