You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.6 KiB
105 lines
3.6 KiB
"""
Progress manager: pushes real-time task progress to clients over WebSocket.
(进度管理器 - WebSocket实时进度推送)
"""
|
|
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Dict, Optional, Set

from fastapi import WebSocket
|
|
|
|
# Logger named after this module, so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ProgressManager:
    """Registry of task progress that is pushed to subscribed WebSocket clients.

    ``ProgressManager()`` always returns the same instance (the singleton is
    created in ``__new__``). For every ``task_id`` the manager keeps:

    * ``_connections`` -- the set of WebSockets subscribed to that task;
    * ``_progress_data`` -- the most recent payload for that task, replayed to
      clients that connect after the update was sent.
    """

    _instance = None

    # NOTE(review): no method of this class acquires this lock; it is kept
    # only in case code elsewhere references ``ProgressManager._lock``.
    _lock = asyncio.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # task_id -> WebSockets currently subscribed to that task
            cls._instance._connections: Dict[str, Set[WebSocket]] = {}
            # task_id -> last payload pushed for that task
            cls._instance._progress_data: Dict[str, Dict] = {}
        return cls._instance

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string without offset.

        This is the same format the deprecated ``datetime.utcnow().isoformat()``
        produced (no ``+00:00`` suffix), so existing clients parse it unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async def _broadcast(self, task_id: str, payload: Dict) -> None:
        """Send ``payload`` to every WebSocket subscribed to ``task_id``.

        Sockets whose send raises are logged and unsubscribed. The task's
        subscriber entry is removed once no sockets remain.
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop, and
        # a concurrent connect()/disconnect()/clear_task() may mutate or delete
        # the live set, which would raise "Set changed size during iteration".
        targets = list(self._connections.get(task_id, ()))
        failed = []
        for websocket in targets:
            try:
                await websocket.send_json(payload)
            except Exception as exc:
                # Best-effort delivery: one broken client must not block the rest.
                logger.warning("WebSocket发送失败: %s", exc)
                failed.append(websocket)

        # Re-read the live set; it may have been removed while we were awaiting.
        live = self._connections.get(task_id)
        if live is None:
            return
        for websocket in failed:
            live.discard(websocket)
        if not live:
            # Same cleanup rule as disconnect(): no empty subscriber sets.
            del self._connections[task_id]

    async def connect(self, websocket: "WebSocket", task_id: str) -> None:
        """Accept ``websocket`` and subscribe it to ``task_id``.

        If a payload was already recorded for the task, it is sent right away
        so the client does not have to wait for the next update.
        """
        await websocket.accept()
        self._connections.setdefault(task_id, set()).add(websocket)

        latest = self._progress_data.get(task_id)
        if latest is not None:
            await websocket.send_json(latest)

        logger.info("WebSocket连接: task_id=%s", task_id)

    async def disconnect(self, websocket: "WebSocket", task_id: str) -> None:
        """Unsubscribe ``websocket`` from ``task_id``.

        The task's subscriber entry is dropped once its last socket leaves.
        This method does not close the socket.
        """
        sockets = self._connections.get(task_id)
        if sockets is not None:
            sockets.discard(websocket)
            if not sockets:
                del self._connections[task_id]
        logger.info("WebSocket断开: task_id=%s", task_id)

    async def update_progress(self, task_id: str, progress_data: Dict) -> None:
        """Record ``progress_data`` plus a UTC ``timestamp`` and push it.

        The caller's dict is not modified: a timestamped copy is stored as the
        task's latest payload and broadcast to its subscribers.
        """
        payload = {**progress_data, "timestamp": self._utc_timestamp()}
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    async def complete_task(self, task_id: str, result: Dict) -> None:
        """Mark ``task_id`` as completed with ``result`` and push it.

        The stored and broadcast payload is a copy of ``result`` in which
        ``status`` is forced to ``"completed"`` and ``timestamp`` is added.
        The caller's dict is left unchanged.
        """
        payload = {
            **result,
            "status": "completed",
            "timestamp": self._utc_timestamp(),
        }
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    async def fail_task(self, task_id: str, error: str) -> None:
        """Mark ``task_id`` as failed with message ``error`` and push it."""
        payload = {
            "status": "failed",
            "error": error,
            "timestamp": self._utc_timestamp(),
        }
        self._progress_data[task_id] = payload
        await self._broadcast(task_id, payload)

    def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the latest payload recorded for ``task_id``, or ``None``."""
        return self._progress_data.get(task_id)

    def clear_task(self, task_id: str) -> None:
        """Forget the stored payload and subscriber set for ``task_id``.

        Subscribed sockets are neither closed nor notified by this method.
        """
        self._progress_data.pop(task_id, None)
        self._connections.pop(task_id, None)
|
|
|
|
|
|
# Module-level instance shared by importers; ProgressManager() always
# returns this same object.
progress_manager = ProgressManager()