You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1255 lines
48 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
缓存管理服务
"""
import asyncio
import logging
from typing import List, Dict, Optional
from datetime import date, datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from app.models.cache import CacheTask, CacheTaskDetail
from app.models.stock import StockKlineDaily
from app.models.future import FutureKlineDaily
from app.services.base_data_service import BaseDataService
from app.services.stock_service import StockService
from app.services.future_service import FutureService
from app.services.sdk_manager import sdk_manager
from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.config import settings
from app.core.redis_client import redis_client
from app.core.progress_manager import progress_manager
logger = logging.getLogger(__name__)


class CacheService:
    """Cache management service.

    Caches SDK lookups (code lists, futures varieties, trading calendars,
    main contracts) in Redis, and runs the database-backed jobs that detect
    and back-fill missing K-line data, recording their progress in
    ``CacheTask`` / ``CacheTaskDetail`` rows.
    """

    # Redis TTL, in seconds, for the code lists cached by get_all_codes (12 h).
    CODE_LIST_CACHE_EXPIRE = 12 * 60 * 60

    # Strong references to fire-and-forget progress-push tasks. The event loop
    # only keeps weak references to tasks, so a pending update could otherwise
    # be garbage-collected before it runs. Shared by all instances on purpose.
    _progress_tasks: set = set()

    def __init__(self, db: Session):
        """Bind the service and its helper services to one SQLAlchemy session."""
        self.db = db
        self.base_service = BaseDataService(db)
        self.stock_service = StockService(db)
        self.future_service = FutureService(db)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_adapter():
        """Return the default SDK connection; raise RuntimeError if there is none."""
        adapter = sdk_manager.get_default_connection()
        if not adapter:
            raise RuntimeError("SDK连接失败")
        return adapter

    @staticmethod
    def _default_market(security_type: str) -> str:
        """Return the market whose trading calendar defines the expected days."""
        # Futures use the CFE calendar; every other security type uses SH.
        return "CFE" if security_type == "future" else "SH"

    @staticmethod
    def _daily_model(security_type: str, period_type: str):
        """Return the ORM model this service can count rows in, or None.

        Only daily stock and daily future K-lines have a local table that this
        service queries; every other combination returns None.
        """
        if period_type != "daily":
            return None
        if security_type == "stock":
            return StockKlineDaily
        if security_type == "future":
            return FutureKlineDaily
        return None

    def _count_records(
        self,
        security_type: str,
        period_type: str,
        code: str,
        start_date: date,
        end_date: date
    ) -> int:
        """Count stored rows for one code with trade_date in [start_date, end_date].

        Returns 0 when the security/period combination has no local table.
        """
        model = self._daily_model(security_type, period_type)
        if model is None:
            return 0
        return self.db.query(model).filter(
            and_(
                model.code == code,
                model.trade_date >= start_date,
                model.trade_date <= end_date
            )
        ).count()

    def _count_records_by_code(self, model, start_date: date, end_date: date) -> Dict[str, int]:
        """Map each code to its number of ``model`` rows within the date range."""
        rows = self.db.query(
            model.code,
            func.count(model.id).label('count')
        ).filter(
            and_(
                model.trade_date >= start_date,
                model.trade_date <= end_date
            )
        ).group_by(model.code).all()
        return {r.code: r.count for r in rows}

    def _count_records_by_date(self, model, start_date: date, end_date: date) -> Dict[str, int]:
        """Map each formatted trade date to its number of ``model`` rows in the range."""
        day = func.date(model.trade_date)
        rows = self.db.query(
            day.label('trade_date'),
            func.count(model.id).label('count')
        ).filter(
            and_(
                model.trade_date >= start_date,
                model.trade_date <= end_date
            )
        ).group_by(day).all()
        counts = {}
        for r in rows:
            # Date-like results are formatted; anything else is stringified as-is.
            key = format_date(r.trade_date) if hasattr(r.trade_date, 'strftime') else str(r.trade_date)
            counts[key] = r.count
        return counts

    def _fetch_kline(
        self,
        security_type: str,
        code: str,
        start_date: date,
        end_date: date,
        period_type: str
    ) -> None:
        """Fetch K-lines for one code through the stock or future service.

        The callers rely on the services persisting what they fetch, which is
        how data gets "cached". Unknown security types are a silent no-op.
        """
        if security_type == "stock":
            self.stock_service.get_kline([code], start_date, end_date, period_type)
        elif security_type == "future":
            self.future_service.get_kline([code], start_date, end_date, period_type)

    @staticmethod
    def _summarize_codes(codes: List[str], limit: int = 100) -> str:
        """Join at most ``limit`` codes with commas; append "..." only if truncated."""
        summary = ",".join(codes[:limit])
        return f"{summary}..." if len(codes) > limit else summary

    def _push_progress(self, ws_task_id: str, payload: Dict) -> None:
        """Best-effort push of a progress payload to WebSocket subscribers.

        The update is scheduled only when this thread has a running asyncio
        event loop; otherwise it is dropped on purpose (progress is optional).
        NOTE(review): when called from synchronous code running on the loop's
        own thread, the scheduled updates only run after that code returns.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread (e.g. a worker thread): skip.
            return
        push_task = loop.create_task(progress_manager.update_progress(ws_task_id, payload))
        self._progress_tasks.add(push_task)
        push_task.add_done_callback(self._progress_tasks.discard)

    def _mark_task_failed(self, task: CacheTask, error: Exception) -> None:
        """Persist a task failure.

        The session is rolled back first: after a failed flush/commit it
        refuses all further work until rolled back. This discards anything
        added since the last commit.
        """
        self.db.rollback()
        task.status = "failed"
        task.error_message = str(error)
        task.completed_at = datetime.utcnow()
        self.db.commit()

    # ------------------------------------------------------------------
    # Redis-cached lookups
    # ------------------------------------------------------------------

    def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]:
        """Return every code of a security type, cached in Redis for 12 hours.

        Args:
            security_type: "stock" or "future"; any other value yields [].
            contract_type: "all" or "main"; only used for futures, where
                "main" returns the values of the SDK's main-contract mapping.

        Returns:
            The code list (possibly empty). Empty results are not cached.

        Raises:
            RuntimeError: on a cache miss when no SDK connection is available.
        """
        cache_key = f"code_list:{security_type}:{contract_type}"
        cached_codes = redis_client.get(cache_key)
        if cached_codes is not None:
            logger.info(f"从Redis缓存获取代码列表: {security_type}/{contract_type}, 共{len(cached_codes)}")
            return cached_codes
        adapter = self._get_adapter()
        codes = []
        if security_type == "stock":
            codes = adapter.get_code_list("EXTRA_STOCK_A")
        elif security_type == "future":
            if contract_type == "main":
                codes = list((adapter.get_all_main_contracts() or {}).values())
            else:
                codes = adapter.get_code_list("EXTRA_FUTURE")
        if codes:
            redis_client.set(cache_key, codes, expire=self.CODE_LIST_CACHE_EXPIRE)
            logger.info(f"代码列表已缓存到Redis: {security_type}/{contract_type}, 共{len(codes)}个, 有效期12小时")
        # Normalise a None result from the SDK so callers can always len()/iterate.
        return codes or []

    def get_future_varieties(self) -> List[str]:
        """Return the futures variety list, cached in Redis for 24 hours.

        Raises:
            RuntimeError: on a cache miss when no SDK connection is available.
        """
        cache_key = "future_varieties"
        cached = redis_client.get(cache_key)
        if cached is not None:
            logger.info(f"从Redis缓存获取期货品种列表, 共{len(cached)}")
            return cached
        adapter = self._get_adapter()
        varieties = adapter.get_future_varieties()
        if varieties:
            redis_client.set(cache_key, varieties, expire=24 * 60 * 60)
            logger.info(f"期货品种列表已缓存到Redis, 共{len(varieties)}个, 有效期24小时")
        return varieties

    def get_trading_calendar_cached(self, market: str, start_date: date = None, end_date: date = None) -> List[date]:
        """Return the trading days of ``market``, optionally bounded to a range.

        The full calendar is cached in Redis for one year as integers (via
        ``date_to_int``) under one key per market. The optional inclusive
        bounds are applied after loading, so a cache hit and a cache miss
        return the same result. An empty calendar is not cached.
        """
        from app.utils.date_utils import date_to_int, int_to_date
        cache_key = f"trading_calendar:{market}"
        cached = redis_client.get(cache_key)
        if cached is not None:
            logger.info(f"从Redis缓存获取交易日历: {market}")
            all_dates = [int_to_date(d) for d in cached]
        else:
            all_dates = list(self.base_service.get_trading_calendar(market))
            if all_dates:
                redis_client.set(cache_key, [date_to_int(d) for d in all_dates], expire=365 * 24 * 60 * 60)
                logger.info(f"交易日历已缓存到Redis: {market}, 有效期1年")
        return [
            d for d in all_dates
            if (start_date is None or d >= start_date) and (end_date is None or d <= end_date)
        ]

    def get_main_contract_cached(self, variety: str) -> Optional[str]:
        """Return the main contract of a futures variety, cached in Redis for 1 hour.

        Raises:
            RuntimeError: on a cache miss when no SDK connection is available.
        """
        cache_key = f"main_contract:{variety}"
        cached = redis_client.get(cache_key)
        if cached is not None:
            logger.info(f"从Redis缓存获取主力合约: {variety} -> {cached}")
            return cached
        adapter = self._get_adapter()
        main_contract = adapter.get_main_contract(variety)
        if main_contract:
            redis_client.set(cache_key, main_contract, expire=1 * 60 * 60)
            logger.info(f"主力合约已缓存到Redis: {variety} -> {main_contract}, 有效期1小时")
        return main_contract

    def get_all_main_contracts_cached(self) -> Dict[str, str]:
        """Return the SDK's main-contract mapping, cached in Redis for 1 hour.

        Raises:
            RuntimeError: on a cache miss when no SDK connection is available.
        """
        cache_key = "all_main_contracts"
        cached = redis_client.get(cache_key)
        if cached is not None:
            logger.info(f"从Redis缓存获取所有主力合约, 共{len(cached)}")
            return cached
        adapter = self._get_adapter()
        main_contracts = adapter.get_all_main_contracts()
        if main_contracts:
            redis_client.set(cache_key, main_contracts, expire=1 * 60 * 60)
            logger.info(f"所有主力合约已缓存到Redis, 共{len(main_contracts)}个, 有效期1小时")
        return main_contracts

    # ------------------------------------------------------------------
    # Detection / caching jobs
    # ------------------------------------------------------------------

    def detect_all_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all",
        task_id: str = None
    ) -> Dict:
        """Check every code of a security type for missing data in one pass.

        A code counts as missing when its stored row count is below the number
        of trading days in the range. For combinations without a local table,
        every code counts as missing. A ``CacheTaskDetail`` is recorded for
        each missing code, and progress is pushed to WebSocket subscribers.

        Args:
            security_type: "stock" or "future".
            period_type: e.g. "daily"; only "daily" is counted from the database.
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).
            contract_type: "all" or "main"; only used for futures.
            task_id: WebSocket task id; generated from the current time if omitted.

        Returns:
            A summary dict. On failure, a reduced dict with ``status == "failed"``
            and ``error_message`` is returned instead of raising.

        Raises:
            ValueError: if no code list can be obtained.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        total = len(code_list)
        logger.info(f"获取到{total}个{security_type}代码")
        ws_task_id = task_id or f"detect_{security_type}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"

        def push_progress(progress, status, **kwargs):
            self._push_progress(ws_task_id, {
                "progress": progress,
                "status": status,
                "total_count": total,
                **kwargs
            })

        push_progress(0, "starting", message="开始检测...")
        task = CacheTask(
            task_name=f"一键检测所有数据 - {security_type} - {contract_type} - {total}个代码",
            task_type="detect_all_missing",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list, 100),
            status="running",
            total_count=total,
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        push_progress(5, "running", message="获取交易日历...")
        try:
            trading_days = self.base_service.get_trading_calendar(
                self._default_market(security_type), start_date, end_date
            )
            expected_count = len(trading_days)
            missing_codes = []
            complete_codes = []
            # Never incremented here; reported so the result shape matches other jobs.
            error_count = 0
            daily_stats = {
                format_date(td): {"expected": total, "actual": 0, "missing": 0}
                for td in trading_days
            }
            push_progress(10, "running", message="查询数据库统计...")
            model = self._daily_model(security_type, period_type)
            code_counts = {}
            if model is not None:
                # Two aggregate queries instead of one COUNT per code.
                code_counts = self._count_records_by_code(model, start_date, end_date)
                for date_key, count in self._count_records_by_date(model, start_date, end_date).items():
                    if date_key in daily_stats:
                        daily_stats[date_key]["actual"] = count
            push_progress(20, "running", message="分析数据完整性...")
            for done, code in enumerate(code_list, start=1):
                actual_count = code_counts.get(code, 0)
                # Without a local table nothing can be verified, so the code is missing.
                if model is None or actual_count < expected_count:
                    missing_count = expected_count - actual_count
                    missing_codes.append({
                        "code": code,
                        "actual_count": actual_count,
                        "expected_count": expected_count,
                        "missing_count": missing_count,
                        # actual_count < expected_count implies expected_count > 0 here.
                        "missing_ratio": 1.0 if model is None else missing_count / expected_count
                    })
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        expected_count=expected_count,
                        actual_count=actual_count,
                        is_missing=True,
                        status="pending"
                    ))
                else:
                    complete_codes.append(code)
                if done % 500 == 0 or done == total:
                    # Same meaning as the final value below: number of complete codes.
                    task.success_count = len(complete_codes)
                    task.error_count = error_count
                    task.progress = min(100, int(done / total * 100))
                    self.db.commit()
                    push_progress(
                        20 + int(done / total * 70),
                        "running",
                        processed=done,
                        missing=len(missing_codes),
                        complete=len(complete_codes)
                    )
            for stats in daily_stats.values():
                stats["missing"] = stats["expected"] - stats["actual"]
            missing_code_list = [m["code"] for m in missing_codes]
            task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else ""
            task.status = "completed"
            task.success_count = len(complete_codes)
            task.error_count = error_count
            task.completed_at = datetime.utcnow()
            self.db.commit()
            push_progress(100, "completed",
                message="检测完成",
                complete_count=len(complete_codes),
                missing_count=len(missing_codes),
                error_count=error_count
            )
            logger.info(f"检测完成: 完整{len(complete_codes)}个, 缺失{len(missing_codes)}个, 错误{error_count}")
            return {
                "task_id": task.id,
                "ws_task_id": ws_task_id,
                "task_name": task.task_name,
                "status": task.status,
                "progress": float(task.progress),
                "total_count": task.total_count,
                "complete_count": len(complete_codes),
                "missing_count": len(missing_codes),
                "error_count": error_count,
                "expected_days": expected_count,
                "start_date": format_date(start_date),
                "end_date": format_date(end_date),
                "security_type": security_type,
                "period_type": period_type,
                "daily_stats": daily_stats,
                "missing_codes": missing_codes[:100],
                "missing_code_list": missing_code_list
            }
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"一键检测缺失数据任务失败: {str(e)}")
            push_progress(100, "failed", error=str(e))
            return {
                "task_id": task.id,
                "ws_task_id": ws_task_id,
                "task_name": task.task_name,
                "status": task.status,
                "error_message": str(e)
            }

    def _create_cache_task(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ) -> CacheTask:
        """Create and persist a "pending" cache-all task covering every code.

        Raises:
            ValueError: if no code list can be obtained.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        task = CacheTask(
            task_name=f"一键缓存所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
            task_type="cache_all_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list, 100),
            status="pending",
            total_count=len(code_list),
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        return task

    def _create_fill_missing_task(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        missing_codes: List[str]
    ) -> CacheTask:
        """Create and persist a "pending" task that back-fills the given codes.

        Only the first 500 codes are stored in ``code_list``.

        Raises:
            ValueError: if ``missing_codes`` is empty.
        """
        if not missing_codes:
            raise ValueError("没有缺失代码需要补齐")
        task = CacheTask(
            task_name=f"一键补齐缺失数据 - {security_type} - {len(missing_codes)}个代码",
            task_type="fill_missing_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(missing_codes[:500]),
            status="pending",
            total_count=len(missing_codes),
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        return task

    def _execute_fill_missing_task(
        self,
        task_id: int,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        missing_codes: List[str]
    ):
        """Run a fill-missing task: fetch the full range for every given code.

        Per-code failures are logged and counted; any other error marks the
        task failed. Does nothing if the task does not exist.
        """
        task = self.get_task(task_id)
        if not task:
            return
        task.status = "running"
        self.db.commit()
        total = len(missing_codes)
        try:
            success_count = 0
            error_count = 0
            for done, code in enumerate(missing_codes, start=1):
                try:
                    self._fetch_kline(security_type, code, start_date, end_date, period_type)
                    success_count += 1
                except Exception as e:
                    logger.error(f"补齐{code}数据失败: {str(e)}")
                    error_count += 1
                if done % 10 == 0 or done == total:
                    task.success_count = success_count
                    task.error_count = error_count
                    task.progress = min(100, int(done / total * 100))
                    self.db.commit()
            task.status = "completed"
            task.success_count = success_count
            task.error_count = error_count
            task.completed_at = datetime.utcnow()
            self.db.commit()
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"补齐缺失数据任务失败: {str(e)}")

    def _execute_cache_task(
        self,
        task_id: int,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ):
        """Run a cache-all task: fetch every code whose stored data is incomplete.

        Codes whose row count already reaches the number of trading days are
        skipped. Per-code failures are logged and counted; any other error
        marks the task failed. Does nothing if the task does not exist.
        """
        task = self.get_task(task_id)
        if not task:
            return
        task.status = "running"
        self.db.commit()
        try:
            # Resolved inside the try so an SDK/Redis failure marks the task
            # failed instead of leaving it stuck in "running".
            code_list = self.get_all_codes(security_type, contract_type)
            total = len(code_list)
            trading_days = self.base_service.get_trading_calendar(
                self._default_market(security_type), start_date, end_date
            )
            expected_count = len(trading_days)
            success_count = 0
            error_count = 0
            for done, code in enumerate(code_list, start=1):
                try:
                    actual_count = self._count_records(security_type, period_type, code, start_date, end_date)
                    if actual_count < expected_count:
                        self._fetch_kline(security_type, code, start_date, end_date, period_type)
                        success_count += 1
                except Exception as e:
                    logger.error(f"缓存{code}数据失败: {str(e)}")
                    error_count += 1
                # Reached for skipped codes too, so progress always ends at 100.
                if done % 10 == 0 or done == total:
                    task.success_count = success_count
                    task.error_count = error_count
                    task.progress = min(100, int(done / total * 100))
                    self.db.commit()
            task.status = "completed"
            task.success_count = success_count
            task.error_count = error_count
            task.completed_at = datetime.utcnow()
            self.db.commit()
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"执行缓存任务失败: {str(e)}")

    def cache_all_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        contract_type: str = "all"
    ) -> CacheTask:
        """Fetch (and thereby cache) data for every code whose data is incomplete.

        Runs synchronously. Each code gets a ``CacheTaskDetail`` with status
        "skipped" (already complete), "success" or "failed".

        Args:
            security_type: "stock" or "future".
            period_type: e.g. "daily"; completeness is only checked for "daily".
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).
            contract_type: "all" or "main"; only used for futures.

        Returns:
            The task row, with status "completed" or "failed".

        Raises:
            ValueError: if no code list can be obtained.
        """
        code_list = self.get_all_codes(security_type, contract_type)
        if not code_list:
            raise ValueError(f"无法获取{security_type}代码列表")
        total = len(code_list)
        logger.info(f"获取到{total}个{security_type}代码,开始缓存")
        task = CacheTask(
            task_name=f"一键缓存所有数据 - {security_type} - {total}个代码",
            task_type="cache_all_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=self._summarize_codes(code_list, 100),
            status="running",
            total_count=total,
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        try:
            trading_days = self.base_service.get_trading_calendar(
                self._default_market(security_type), start_date, end_date
            )
            expected_count = len(trading_days)
            success_count = 0
            skipped_count = 0
            error_count = 0
            for done, code in enumerate(code_list, start=1):
                try:
                    actual_count = self._count_records(security_type, period_type, code, start_date, end_date)
                    if actual_count >= expected_count:
                        # Data already complete: record the skip, fetch nothing.
                        skipped_count += 1
                        self.db.add(CacheTaskDetail(
                            task_id=task.id,
                            code=code,
                            trade_date=start_date,
                            expected_count=expected_count,
                            actual_count=actual_count,
                            is_missing=False,
                            status="skipped"
                        ))
                    else:
                        self._fetch_kline(security_type, code, start_date, end_date, period_type)
                        success_count += 1
                        self.db.add(CacheTaskDetail(
                            task_id=task.id,
                            code=code,
                            trade_date=start_date,
                            expected_count=expected_count,
                            actual_count=actual_count,
                            is_missing=True,
                            status="success",
                            processed_at=datetime.utcnow()
                        ))
                except Exception as e:
                    logger.error(f"缓存{code}数据失败: {str(e)}")
                    error_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="failed",
                        error_message=str(e)
                    ))
                # Checkpoint every 50 codes; reached for skipped codes too.
                if done % 50 == 0 or done == total:
                    task.success_count = success_count
                    task.error_count = error_count
                    task.progress = min(100, int(done / total * 100))
                    self.db.commit()
                    logger.info(f"进度: {done}/{total}, 成功: {success_count}, 跳过: {skipped_count}, 错误: {error_count}")
            task.status = "completed"
            task.success_count = success_count
            task.error_count = error_count
            task.completed_at = datetime.utcnow()
            self.db.commit()
            logger.info(f"缓存完成: 成功{success_count}个, 跳过{skipped_count}个, 错误{error_count}")
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"一键缓存数据任务失败: {str(e)}")
        return task

    def detect_missing_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        code_list: List[str]
    ) -> CacheTask:
        """Detect missing data for the given codes.

        A code is flagged missing when its missing ratio (relative to the number
        of trading days) exceeds ``settings.CACHE_MISSING_THRESHOLD``. Note that
        ``success_count`` counts the codes flagged missing.

        Args:
            security_type: "stock" or "future".
            period_type: e.g. "daily"; only "daily" is counted from the database.
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).
            code_list: Codes to check.

        Returns:
            The task row, with status "completed" or "failed".
        """
        total = len(code_list)
        task = CacheTask(
            task_name=f"检测缺失数据 - {security_type} - {total}个代码",
            task_type="detect_missing",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(code_list),
            status="running",
            total_count=total,
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        try:
            trading_days = self.base_service.get_trading_calendar(
                self._default_market(security_type), start_date, end_date
            )
            expected_count = len(trading_days)
            success_count = 0
            error_count = 0
            for done, code in enumerate(code_list, start=1):
                try:
                    actual_count = self._count_records(security_type, period_type, code, start_date, end_date)
                    missing_ratio = 0
                    if expected_count > 0:
                        missing_ratio = (expected_count - actual_count) / expected_count
                    is_missing = missing_ratio > settings.CACHE_MISSING_THRESHOLD
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        expected_count=expected_count,
                        actual_count=actual_count,
                        is_missing=is_missing,
                        status="pending" if is_missing else "skipped"
                    ))
                    if is_missing:
                        success_count += 1
                except Exception as e:
                    logger.error(f"检测{code}缺失数据失败: {str(e)}")
                    error_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="failed",
                        error_message=str(e)
                    ))
                task.success_count = success_count
                task.error_count = error_count
                # Based on codes processed: success_count only counts missing
                # codes, so using it here left progress short of 100%.
                task.progress = min(100, int(done / total * 100))
                self.db.commit()
            task.status = "completed"
            task.completed_at = datetime.utcnow()
            self.db.commit()
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"检测缺失数据任务失败: {str(e)}")
        return task

    def batch_cache_data(
        self,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        code_list: List[str]
    ) -> CacheTask:
        """Fetch (and thereby cache) data for the given codes, unconditionally.

        Args:
            security_type: "stock" or "future".
            period_type: Period type passed through to the K-line services.
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).
            code_list: Codes to fetch.

        Returns:
            The task row, with status "completed" or "failed".
        """
        total = len(code_list)
        task = CacheTask(
            task_name=f"批量缓存数据 - {security_type} - {total}个代码",
            task_type="cache_data",
            security_type=security_type,
            period_type=period_type,
            start_date=start_date,
            end_date=end_date,
            code_list=",".join(code_list),
            status="running",
            total_count=total,
            started_at=datetime.utcnow()
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        try:
            success_count = 0
            error_count = 0
            for done, code in enumerate(code_list, start=1):
                try:
                    self._fetch_kline(security_type, code, start_date, end_date, period_type)
                    success_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="success",
                        processed_at=datetime.utcnow()
                    ))
                except Exception as e:
                    logger.error(f"缓存{code}数据失败: {str(e)}")
                    error_count += 1
                    self.db.add(CacheTaskDetail(
                        task_id=task.id,
                        code=code,
                        trade_date=start_date,
                        status="failed",
                        error_message=str(e)
                    ))
                task.success_count = success_count
                task.error_count = error_count
                task.progress = min(100, int(done / total * 100))
                self.db.commit()
            task.status = "completed"
            task.completed_at = datetime.utcnow()
            self.db.commit()
        except Exception as e:
            self._mark_task_failed(task, e)
            logger.exception(f"批量缓存数据任务失败: {str(e)}")
        return task

    # ------------------------------------------------------------------
    # Task queries
    # ------------------------------------------------------------------

    def get_tasks(
        self,
        page: int = 1,
        page_size: int = 20
    ) -> Dict:
        """Return one page of tasks, newest first, with pagination metadata.

        ``page`` and ``page_size`` are clamped to at least 1, which avoids a
        negative offset and a division by zero.
        """
        page = max(1, page)
        page_size = max(1, page_size)
        query = self.db.query(CacheTask).order_by(CacheTask.created_at.desc())
        total = query.count()
        tasks = query.offset((page - 1) * page_size).limit(page_size).all()
        return {
            "items": tasks,
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": (total + page_size - 1) // page_size
        }

    def get_task(self, task_id: int) -> Optional[CacheTask]:
        """Return the task with ``task_id``, or None if it does not exist."""
        return self.db.query(CacheTask).filter(CacheTask.id == task_id).first()

    def get_task_details(self, task_id: int) -> List[CacheTaskDetail]:
        """Return every detail row recorded for ``task_id``."""
        return self.db.query(CacheTaskDetail).filter(
            CacheTaskDetail.task_id == task_id
        ).all()

    def cancel_task(self, task_id: int) -> bool:
        """Mark a running task as cancelled; return True if its status changed.

        NOTE(review): this only updates the row. The execution loops in this
        class never re-read the status, so a job already running continues.
        """
        task = self.get_task(task_id)
        if task and task.status == "running":
            task.status = "cancelled"
            task.completed_at = datetime.utcnow()
            self.db.commit()
            return True
        return False

    def get_cache_status(self, code: str, security_type: str, period_type: str) -> Dict:
        """Return the cache status of one code from the matching service.

        Unknown security types get an empty status dict.
        """
        if security_type == "stock":
            return self.stock_service.get_cache_status(code, period_type)
        if security_type == "future":
            return self.future_service.get_cache_status(code, period_type)
        return {
            "code": code,
            "security_type": security_type,
            "period_type": period_type,
            "record_count": 0,
            "min_date": None,
            "max_date": None
        }

    # ------------------------------------------------------------------
    # Per-code gap inspection and back-fill
    # ------------------------------------------------------------------

    def get_missing_dates_for_code(
        self,
        code: str,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date
    ) -> Dict:
        """Return the trading days in the range that have no stored row for ``code``.

        The calendar market is derived from the code itself. For combinations
        without a local table, every trading day is reported missing.

        Args:
            code: Security code.
            security_type: "stock" or "future".
            period_type: e.g. "daily".
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).

        Returns:
            Expected/actual/missing counts and the sorted missing dates.
        """
        market = get_market_from_code(code)
        trading_days = self.base_service.get_trading_calendar(market, start_date, end_date)
        expected_dates = set(trading_days)
        actual_dates = set()
        model = self._daily_model(security_type, period_type)
        if model is not None:
            rows = self.db.query(model.trade_date).filter(
                and_(
                    model.code == code,
                    model.trade_date >= start_date,
                    model.trade_date <= end_date
                )
            ).all()
            actual_dates = {r.trade_date for r in rows}
        missing_dates = sorted(expected_dates - actual_dates)
        return {
            "code": code,
            "security_type": security_type,
            "period_type": period_type,
            "start_date": format_date(start_date),
            "end_date": format_date(end_date),
            "expected_count": len(expected_dates),
            "actual_count": len(actual_dates),
            "missing_count": len(missing_dates),
            "missing_dates": [
                {"date": format_date(d), "date_obj": d.isoformat()}
                for d in missing_dates
            ]
        }

    def fill_single_date_data(
        self,
        code: str,
        security_type: str,
        period_type: str,
        trade_date: date
    ):
        """Fetch one code's data for a single trading day.

        Raises:
            Exception: whatever the K-line service raised, after logging it.
        """
        logger.info(f"补齐单日数据: {code} - {format_date(trade_date)}")
        try:
            self._fetch_kline(security_type, code, trade_date, trade_date, period_type)
            logger.info(f"补齐单日数据成功: {code} - {format_date(trade_date)}")
        except Exception as e:
            logger.error(f"补齐单日数据失败: {code} - {format_date(trade_date)}, 错误: {str(e)}")
            raise

    def fill_all_dates_for_code(
        self,
        code: str,
        security_type: str,
        period_type: str,
        start_date: date,
        end_date: date,
        missing_dates: List[str] = None
    ):
        """Fetch one code's data day by day for every missing trading day.

        Args:
            code: Security code.
            security_type: "stock" or "future".
            period_type: Period type passed through to the K-line services.
            start_date: First day of the range (inclusive).
            end_date: Last day of the range (inclusive).
            missing_dates: Date strings to fill. If omitted or empty, they are
                detected with ``get_missing_dates_for_code``.

        Per-day failures are logged and counted, not raised.
        """
        logger.info(f"补齐所有数据: {code} - {format_date(start_date)} ~ {format_date(end_date)}")
        if missing_dates:
            dates_to_fill = [parse_date(d) for d in missing_dates]
        else:
            result = self.get_missing_dates_for_code(code, security_type, period_type, start_date, end_date)
            dates_to_fill = [parse_date(d["date"]) for d in result["missing_dates"]]
        if not dates_to_fill:
            logger.info(f"没有缺失数据需要补齐: {code}")
            return
        success_count = 0
        error_count = 0
        for trade_date in dates_to_fill:
            try:
                self._fetch_kline(security_type, code, trade_date, trade_date, period_type)
                success_count += 1
            except Exception as e:
                logger.error(f"补齐{code} - {format_date(trade_date)}失败: {str(e)}")
                error_count += 1
        logger.info(f"补齐完成: {code}, 成功{success_count}个, 失败{error_count}")