"""
Progress manager - pushes real-time task progress to clients over WebSocket.
"""
import asyncio
import json
import logging
from typing import Dict, Set, Optional
from datetime import datetime, timezone

from fastapi import WebSocket

logger = logging.getLogger(__name__)


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same format as ``datetime.utcnow().isoformat()`` (no
    ``+00:00`` suffix) without calling the deprecated ``utcnow``.
    """
    # NOTE(review): the string carries no UTC offset, so consumers must know it
    # is UTC (e.g. JavaScript's Date parses offset-less date-times as local).
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class ProgressManager:
    """Singleton that records the latest payload per task and pushes every
    update to the WebSockets subscribed to that task."""

    _instance = None
    # NOTE(review): not referenced anywhere in this module.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Singleton: every ProgressManager() call returns the same object and
        # the state below is initialised only on the first call.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # task_id -> sockets currently subscribed to that task
            cls._instance._connections: Dict[str, Set[WebSocket]] = {}
            # task_id -> most recent payload recorded for that task
            cls._instance._progress_data: Dict[str, Dict] = {}
        return cls._instance

    async def connect(self, websocket: WebSocket, task_id: str):
        """Accept *websocket* and subscribe it to updates for *task_id*.

        If a payload has already been recorded for the task, it is sent
        immediately so a late subscriber sees the current state.
        """
        await websocket.accept()
        self._connections.setdefault(task_id, set()).add(websocket)
        if task_id in self._progress_data:
            await websocket.send_json(self._progress_data[task_id])
        logger.info(f"WebSocket连接: task_id={task_id}")

    async def disconnect(self, websocket: WebSocket, task_id: str):
        """Unsubscribe *websocket* from *task_id*; drop the task's entry once
        no subscribers remain."""
        sockets = self._connections.get(task_id)
        if sockets is not None:
            sockets.discard(websocket)
            if not sockets:
                del self._connections[task_id]
        logger.info(f"WebSocket断开: task_id={task_id}")

    async def _broadcast(self, task_id: str, payload: Dict):
        """Send *payload* to every subscriber of *task_id*.

        Sockets whose send raises are logged and unsubscribed; the task's
        entry is removed if that leaves it with no subscribers.
        """
        sockets = self._connections.get(task_id)
        if not sockets:
            return
        dead = []
        # Iterate over a snapshot: each send is an await point, during which
        # connect()/disconnect() may mutate the live set, which would raise
        # "RuntimeError: Set changed size during iteration".
        for websocket in list(sockets):
            try:
                await websocket.send_json(payload)
            except Exception as e:
                logger.warning(f"WebSocket发送失败: {e}")
                dead.append(websocket)
        if dead:
            # Re-read the entry: it may have been deleted or replaced while we
            # were awaiting sends.
            current = self._connections.get(task_id)
            if current is not None:
                current.difference_update(dead)
                if not current:
                    del self._connections[task_id]

    async def update_progress(self, task_id: str, progress_data: Dict):
        """Record *progress_data* as the task's latest state and push it.

        Note: the caller's dict is modified in place (a ``"timestamp"`` key is
        added) and stored by reference.
        """
        progress_data["timestamp"] = _utc_timestamp()
        self._progress_data[task_id] = progress_data
        await self._broadcast(task_id, progress_data)

    async def complete_task(self, task_id: str, result: Dict):
        """Mark the task completed and push *result* to its subscribers.

        Note: *result* is modified in place (``"status"`` and ``"timestamp"``
        keys are set) and stored by reference.
        """
        result["status"] = "completed"
        result["timestamp"] = _utc_timestamp()
        self._progress_data[task_id] = result
        await self._broadcast(task_id, result)

    async def fail_task(self, task_id: str, error: str):
        """Mark the task failed with message *error* and push it to subscribers."""
        result = {
            "status": "failed",
            "error": error,
            "timestamp": _utc_timestamp(),
        }
        self._progress_data[task_id] = result
        await self._broadcast(task_id, result)

    def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the latest recorded payload for *task_id*, or None."""
        return self._progress_data.get(task_id)

    def clear_task(self, task_id: str):
        """Forget the task's recorded payload and its subscriber set.

        Sockets are only dropped from the registry; they are not closed here.
        """
        self._progress_data.pop(task_id, None)
        self._connections.pop(task_id, None)


progress_manager = ProgressManager()