You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 lines
3.6 KiB

"""
进度管理器 - WebSocket实时进度推送
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Dict, Set, Optional

from fastapi import WebSocket
# Per-module logger: its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
class ProgressManager:
    """Singleton that caches per-task progress and pushes it to WebSocket subscribers.

    Every ``ProgressManager()`` call returns the same instance (see ``__new__``).
    State is kept in two dicts keyed by ``task_id``:

    * ``_connections`` -- the set of WebSockets currently subscribed to the task;
    * ``_progress_data`` -- the most recent payload for the task, replayed to a
      subscriber as soon as it connects.
    """

    _instance = None
    # Not referenced by any method of this class; retained so existing references
    # to ``ProgressManager._lock`` keep working.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Create the shared instance on first use; later calls return it unchanged.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._connections: Dict[str, Set[WebSocket]] = {}
            instance._progress_data: Dict[str, Dict] = {}
            cls._instance = instance
        return cls._instance

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same text as the deprecated ``datetime.utcnow().isoformat()``
        (no ``+00:00`` suffix), so clients see an unchanged format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _discard(self, websocket: WebSocket, task_id: str) -> None:
        """Unsubscribe ``websocket`` from ``task_id``; drop the entry once it has no subscribers."""
        subscribers = self._connections.get(task_id)
        if subscribers is None:
            return
        subscribers.discard(websocket)
        if not subscribers:
            del self._connections[task_id]

    async def _broadcast(self, task_id: str, payload: Dict) -> None:
        """Send ``payload`` to every subscriber of ``task_id``, pruning sockets whose send fails.

        Iterates over a snapshot: ``send_json`` awaits, and while it is suspended
        another coroutine may connect, disconnect or clear the task. Mutating a set
        during iteration raises ``RuntimeError``, and the task key may disappear.
        """
        subscribers = tuple(self._connections.get(task_id, ()))
        failed = []
        for websocket in subscribers:
            try:
                await websocket.send_json(payload)
            except Exception as exc:
                # Best-effort push: one broken socket must not stop delivery to the rest.
                logger.warning("WebSocket发送失败: %s", exc)
                failed.append(websocket)
        for websocket in failed:
            self._discard(websocket, task_id)

    async def connect(self, websocket: WebSocket, task_id: str) -> None:
        """Accept ``websocket`` and subscribe it to ``task_id``.

        If a payload is already cached for the task, it is sent immediately so a
        late subscriber does not have to wait for the next update.
        """
        await websocket.accept()
        self._connections.setdefault(task_id, set()).add(websocket)
        cached = self._progress_data.get(task_id)
        if cached is not None:
            await websocket.send_json(cached)
        logger.info("WebSocket连接: task_id=%s", task_id)

    async def disconnect(self, websocket: WebSocket, task_id: str) -> None:
        """Unsubscribe ``websocket`` from ``task_id`` (the socket itself is not closed here)."""
        self._discard(websocket, task_id)
        logger.info("WebSocket断开: task_id=%s", task_id)

    async def update_progress(self, task_id: str, progress_data: Dict) -> None:
        """Timestamp ``progress_data``, cache it as the task's latest state, and push it.

        A shallow copy is stored and sent. The caller's dict is not mutated, and
        later edits to it do not leak into what :meth:`get_progress` returns.
        """
        payload = {**progress_data, "timestamp": self._utc_timestamp()}
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    async def complete_task(self, task_id: str, result: Dict) -> None:
        """Mark ``task_id`` completed with ``result``, cache it, and push it to subscribers.

        ``status`` is forced to ``"completed"`` and a fresh ``timestamp`` is added
        on a copy of ``result``; the caller's dict is left untouched.
        """
        payload = {**result, "status": "completed", "timestamp": self._utc_timestamp()}
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    async def fail_task(self, task_id: str, error: str) -> None:
        """Mark ``task_id`` failed with message ``error``, cache it, and push it to subscribers.

        The cached state is replaced entirely by ``{"status", "error", "timestamp"}``.
        """
        payload = {
            "status": "failed",
            "error": error,
            "timestamp": self._utc_timestamp(),
        }
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the latest cached payload for ``task_id``, or ``None`` if there is none."""
        return self._progress_data.get(task_id)

    def clear_task(self, task_id: str) -> None:
        """Forget the cached payload and subscriber set for ``task_id``.

        Subscribed sockets are only dropped from tracking; they are not closed.
        """
        self._progress_data.pop(task_id, None)
        self._connections.pop(task_id, None)


# Module-level handle to the shared instance (``__new__`` always returns the same object).
progress_manager = ProgressManager()