diff --git a/backend/app/api/v1/cache.py b/backend/app/api/v1/cache.py index 75525f7..c08ca95 100644 --- a/backend/app/api/v1/cache.py +++ b/backend/app/api/v1/cache.py @@ -2,7 +2,7 @@ 缓存管理路由 """ from typing import List -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, BackgroundTasks from sqlalchemy.orm import Session from app.db.session import get_db @@ -20,17 +20,85 @@ from app.utils.date_utils import parse_date, format_date router = APIRouter() -@router.post("/detect-all-missing", response_model=ResponseModel) -async def detect_all_missing_data( +def run_cache_task( + security_type: str, + period_type: str, + start_date_str: str, + end_date_str: str, + contract_type: str, + task_id: int +): + """后台执行缓存任务""" + from app.db.session import SessionLocal + from app.utils.date_utils import parse_date + + db = SessionLocal() + try: + service = CacheService(db) + start_date = parse_date(start_date_str) + end_date = parse_date(end_date_str) + service._execute_cache_task( + task_id, + security_type, + period_type, + start_date, + end_date, + contract_type + ) + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.error(f"后台缓存任务失败: {str(e)}") + finally: + db.close() + + +def run_fill_missing_task( + security_type: str, + period_type: str, + start_date_str: str, + end_date_str: str, + missing_codes: List[str], + task_id: int +): + """后台执行补齐缺失数据任务""" + from app.db.session import SessionLocal + from app.utils.date_utils import parse_date + + db = SessionLocal() + try: + service = CacheService(db) + start_date = parse_date(start_date_str) + end_date = parse_date(end_date_str) + service._execute_fill_missing_task( + task_id, + security_type, + period_type, + start_date, + end_date, + missing_codes + ) + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.error(f"后台补齐任务失败: {str(e)}") + finally: + db.close() + + +@router.post("/fill-missing", response_model=ResponseModel) +async def fill_missing_data( request: AllDataRequest, + background_tasks: BackgroundTasks, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """一键检测所有数据的缺失情况""" + """一键补齐缺失数据(异步执行,只处理缺失代码)""" service = CacheService(db) start = parse_date(request.start_date) end = parse_date(request.end_date) + # 先执行检测获取缺失代码列表 result = service.detect_all_missing_data( request.security_type, request.period_type, @@ -39,21 +107,59 @@ async def detect_all_missing_data( request.contract_type ) - return ResponseModel(data=result) + missing_codes = result.get("missing_code_list", []) + + if not missing_codes: + return ResponseModel(data={ + "task_id": None, + "message": "没有缺失数据需要补齐", + "missing_count": 0 + }) + + # 创建补齐任务记录 + task = service._create_fill_missing_task( + request.security_type, + request.period_type, + start, + end, + missing_codes + ) + + # 在后台执行补齐任务 + background_tasks.add_task( + run_fill_missing_task, + request.security_type, + request.period_type, + request.start_date, + request.end_date, + missing_codes, + task.id + ) + + return ResponseModel(data={ + "task_id": task.id, + "task_name": task.task_name, + "status": task.status, + "total_count": task.total_count, + "progress": float(task.progress), + "missing_count": len(missing_codes) + }) @router.post("/cache-all-missing", response_model=ResponseModel) async def cache_all_missing_data( request: AllDataRequest, + background_tasks: BackgroundTasks, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """一键缓存所有缺失数据""" + """一键缓存所有缺失数据(异步执行)""" service = CacheService(db) start = parse_date(request.start_date) end = parse_date(request.end_date) - task = service.cache_all_missing_data( + # 创建任务记录 + task = service._create_cache_task( request.security_type, request.period_type, start, @@ -61,15 +167,48 @@ async def cache_all_missing_data( request.contract_type ) + # 在后台执行缓存任务 + background_tasks.add_task( + run_cache_task, + request.security_type, + request.period_type, + request.start_date, + request.end_date, + request.contract_type, + task.id + ) + return ResponseModel(data={ "task_id": task.id, "task_name": task.task_name, "status": task.status, "total_count": task.total_count, - "progress": task.progress + "progress": float(task.progress) }) +@router.post("/detect-all-missing", response_model=ResponseModel) +async def detect_all_missing_data( + request: AllDataRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """一键检测所有数据的缺失情况""" + service = CacheService(db) + start = parse_date(request.start_date) + end = parse_date(request.end_date) + + result = service.detect_all_missing_data( + request.security_type, + request.period_type, + start, + end, + request.contract_type + ) + + return ResponseModel(data=result) + + @router.post("/detect-missing", response_model=ResponseModel) async def detect_missing_data( request: DetectMissingRequest, diff --git a/backend/app/core/redis_client.py b/backend/app/core/redis_client.py new file mode 100644 index 0000000..05753e2 --- /dev/null +++ b/backend/app/core/redis_client.py @@ -0,0 +1,92 @@ +""" +Redis客户端模块 +""" +import redis +import json +import logging +from typing import Optional, List, Any +from app.config import settings + +logger = logging.getLogger(__name__) + + +class RedisClient: + """Redis客户端""" + + _instance: Optional['RedisClient'] = None + _client: Optional[redis.Redis] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if self._client is None: + try: + self._client = redis.from_url( + settings.REDIS_URL, + decode_responses=True + ) + logger.info("Redis连接成功") + except Exception as e: + logger.warning(f"Redis连接失败: {str(e)}") + self._client = None + + def is_connected(self) -> bool: + """检查Redis是否连接""" + if self._client is None: + return False + try: + self._client.ping() + return True + except: + return False + + def get(self, key: str) -> Optional[Any]: + """获取缓存""" + if not self.is_connected(): + return None + try: + value = self._client.get(key) + if value: + return json.loads(value) + return None + except Exception as e: + logger.error(f"Redis获取失败: {str(e)}") + return None + + def set(self, key: str, value: Any, expire: int = None) -> bool: + """设置缓存""" + if not self.is_connected(): + return False + try: + self._client.set(key, json.dumps(value), ex=expire) + return True + except Exception as e: + logger.error(f"Redis设置失败: {str(e)}") + return False + + def delete(self, key: str) -> bool: + """删除缓存""" + if not self.is_connected(): + return False + try: + self._client.delete(key) + return True + except Exception as e: + logger.error(f"Redis删除失败: {str(e)}") + return False + + def exists(self, key: str) -> bool: + """检查键是否存在""" + if not self.is_connected(): + return False + try: + return self._client.exists(key) > 0 + except Exception as e: + logger.error(f"Redis检查失败: {str(e)}") + return False + + +redis_client = RedisClient() \ No newline at end of file diff --git a/backend/app/services/cache_service.py b/backend/app/services/cache_service.py index 71716c5..63f9d55 100644 --- a/backend/app/services/cache_service.py +++ b/backend/app/services/cache_service.py @@ -16,6 +16,7 @@ from app.services.future_service import FutureService from app.services.sdk_manager import sdk_manager from app.utils.date_utils import parse_date, format_date, get_market_from_code from app.config import settings +from app.core.redis_client import redis_client logger = logging.getLogger(__name__) @@ -23,6 +24,8 @@ logger = logging.getLogger(__name__) class CacheService: """缓存服务""" + CODE_LIST_CACHE_EXPIRE = 12 * 60 * 60 + def __init__(self, db: Session): self.db = db self.base_service = BaseDataService(db) @@ -31,7 +34,7 @@ class CacheService: def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]: """ - 获取所有代码列表 + 获取所有代码列表(带Redis缓存) Args: security_type: 证券类型 (stock, future) @@ -40,21 +43,32 @@ class CacheService: Returns: 代码列表 """ + cache_key = f"code_list:{security_type}:{contract_type}" + + cached_codes = redis_client.get(cache_key) + if cached_codes is not None: + logger.info(f"从Redis缓存获取代码列表: {security_type}/{contract_type}, 共{len(cached_codes)}个") + return cached_codes + adapter = sdk_manager.get_default_connection() if not adapter: raise RuntimeError("SDK连接失败") + codes = [] if security_type == "stock": - return adapter.get_code_list("EXTRA_STOCK_A") + codes = adapter.get_code_list("EXTRA_STOCK_A") elif security_type == "future": if contract_type == "main": - # 只获取主力合约 main_contracts = adapter.get_all_main_contracts() - return list(main_contracts.values()) + codes = list(main_contracts.values()) else: - return adapter.get_code_list("EXTRA_FUTURE") - else: - return [] + codes = adapter.get_code_list("EXTRA_FUTURE") + + if codes: + redis_client.set(cache_key, codes, expire=self.CODE_LIST_CACHE_EXPIRE) + logger.info(f"代码列表已缓存到Redis: {security_type}/{contract_type}, 共{len(codes)}个, 有效期12小时") + + return codes def get_future_varieties(self) -> List[str]: """获取期货品种列表""" @@ -214,6 +228,10 @@ class CacheService: for date_key in daily_stats: daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"] + # 保存缺失代码列表到任务记录 + missing_code_list = [m["code"] for m in missing_codes] + task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else "" + task.status = "completed" task.success_count = len(complete_codes) task.error_count = error_count @@ -237,7 +255,8 @@ class CacheService: "security_type": security_type, "period_type": period_type, "daily_stats": daily_stats, - "missing_codes": missing_codes[:100] # 只返回前100个缺失代码 + "missing_codes": missing_codes[:100], + "missing_code_list": missing_code_list } except Exception as e: @@ -254,6 +273,202 @@ class CacheService: "error_message": str(e) } + def _create_cache_task( + self, + security_type: str, + period_type: str, + start_date: date, + end_date: date, + contract_type: str = "all" + ) -> CacheTask: + """创建缓存任务记录""" + code_list = self.get_all_codes(security_type, contract_type) + + if not code_list: + raise ValueError(f"无法获取{security_type}代码列表") + + task = CacheTask( + task_name=f"一键缓存所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码", + task_type="cache_all_data", + security_type=security_type, + period_type=period_type, + start_date=start_date, + end_date=end_date, + code_list=",".join(code_list[:100]) + "...", + status="pending", + total_count=len(code_list), + started_at=datetime.utcnow() + ) + self.db.add(task) + self.db.commit() + self.db.refresh(task) + + return task + + def _create_fill_missing_task( + self, + security_type: str, + period_type: str, + start_date: date, + end_date: date, + missing_codes: List[str] + ) -> CacheTask: + """创建补齐缺失数据任务记录""" + if not missing_codes: + raise ValueError("没有缺失代码需要补齐") + + task = CacheTask( + task_name=f"一键补齐缺失数据 - {security_type} - {len(missing_codes)}个代码", + task_type="fill_missing_data", + security_type=security_type, + period_type=period_type, + start_date=start_date, + end_date=end_date, + code_list=",".join(missing_codes[:500]), + status="pending", + total_count=len(missing_codes), + started_at=datetime.utcnow() + ) + self.db.add(task) + self.db.commit() + self.db.refresh(task) + + return task + + def _execute_fill_missing_task( + self, + task_id: int, + security_type: str, + period_type: str, + start_date: date, + end_date: date, + missing_codes: List[str] + ): + """执行补齐缺失数据任务""" + task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first() + if not task: + return + + task.status = "running" + self.db.commit() + + try: + success_count = 0 + error_count = 0 + + for i, code in enumerate(missing_codes): + try: + if security_type == "stock": + self.stock_service.get_kline([code], start_date, end_date, period_type) + elif security_type == "future": + self.future_service.get_kline([code], start_date, end_date, period_type) + + success_count += 1 + + except Exception as e: + logger.error(f"补齐{code}数据失败: {str(e)}") + error_count += 1 + + if (i + 1) % 10 == 0 or i == len(missing_codes) - 1: + task.success_count = success_count + task.error_count = error_count + task.progress = min(100, int((i + 1) / len(missing_codes) * 100)) + self.db.commit() + + task.status = "completed" + task.success_count = success_count + task.error_count = error_count + task.completed_at = datetime.utcnow() + self.db.commit() + + except Exception as e: + task.status = "failed" + task.error_message = str(e) + task.completed_at = datetime.utcnow() + self.db.commit() + + def _execute_cache_task( + self, + task_id: int, + security_type: str, + period_type: str, + start_date: date, + end_date: date, + contract_type: str = "all" + ): + """执行缓存任务""" + task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first() + if not task: + return + + task.status = "running" + self.db.commit() + + code_list = self.get_all_codes(security_type, contract_type) + + try: + market = "CFE" if security_type == "future" else "SH" + trading_days = self.base_service.get_trading_calendar(market, start_date, end_date) + expected_count = len(trading_days) + + success_count = 0 + skipped_count = 0 + error_count = 0 + + for i, code in enumerate(code_list): + try: + if security_type == "stock" and period_type == "daily": + actual_count = self.db.query(StockKlineDaily).filter( + and_( + StockKlineDaily.code == code, + StockKlineDaily.trade_date >= start_date, + StockKlineDaily.trade_date <= end_date + ) + ).count() + elif security_type == "future" and period_type == "daily": + actual_count = self.db.query(FutureKlineDaily).filter( + and_( + FutureKlineDaily.code == code, + FutureKlineDaily.trade_date >= start_date, + FutureKlineDaily.trade_date <= end_date + ) + ).count() + else: + actual_count = 0 + + if actual_count >= expected_count: + skipped_count += 1 + continue + + if security_type == "stock": + self.stock_service.get_kline([code], start_date, end_date, period_type) + elif security_type == "future": + self.future_service.get_kline([code], start_date, end_date, period_type) + + success_count += 1 + + except Exception as e: + logger.error(f"缓存{code}数据失败: {str(e)}") + error_count += 1 + + if (i + 1) % 10 == 0 or i == len(code_list) - 1: + task.success_count = success_count + task.error_count = error_count + task.progress = min(100, int((i + 1) / len(code_list) * 100)) + self.db.commit() + + task.status = "completed" + task.success_count = success_count + task.error_count = error_count + task.completed_at = datetime.utcnow() + self.db.commit() + + except Exception as e: + task.status = "failed" + task.error_message = str(e) + task.completed_at = datetime.utcnow() + self.db.commit() + def cache_all_missing_data( self, security_type: str, diff --git a/frontend/src/api/cache.ts b/frontend/src/api/cache.ts index 60230e0..e83bde1 100644 --- a/frontend/src/api/cache.ts +++ b/frontend/src/api/cache.ts @@ -66,3 +66,13 @@ export const getFutureVarieties = () => { export const getMainContracts = () => { return request.get('/cache/main-contracts') } + +export const fillMissingData = (data: { + security_type: string + period_type: string + contract_type?: string + start_date: string + end_date: string +}) => { + return request.post('/cache/fill-missing', data) +} diff --git a/frontend/src/views/CacheManager/DetectMissing.vue b/frontend/src/views/CacheManager/DetectMissing.vue index c58e1c8..8ce8ef1 100644 --- a/frontend/src/views/CacheManager/DetectMissing.vue +++ b/frontend/src/views/CacheManager/DetectMissing.vue @@ -72,7 +72,17 @@ @@ -120,6 +130,22 @@ + + + + + + {{ cacheTask.task_name }} + + {{ cacheTask.status }} + + {{ cacheTask.total_count }} + {{ cacheTask.success_count || 0 }} + + +