parent
7a2094b738
commit
4db22a6f77
@ -0,0 +1,33 @@
|
||||
"""
|
||||
WebSocket进度路由
|
||||
"""
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from app.core.progress_manager import progress_manager
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/progress/{task_id}")
|
||||
async def websocket_progress(
|
||||
websocket: WebSocket,
|
||||
task_id: str
|
||||
):
|
||||
"""WebSocket进度推送"""
|
||||
await progress_manager.connect(websocket, task_id)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
elif data == "close":
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket断开连接: task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
finally:
|
||||
await progress_manager.disconnect(websocket, task_id)
|
||||
@ -0,0 +1,105 @@
|
||||
"""
|
||||
进度管理器 - WebSocket实时进度推送
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Set, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressManager:
|
||||
"""进度管理器"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._connections: Dict[str, Set[WebSocket]] = {}
|
||||
cls._instance._progress_data: Dict[str, Dict] = {}
|
||||
return cls._instance
|
||||
|
||||
async def connect(self, websocket: WebSocket, task_id: str):
|
||||
"""连接WebSocket"""
|
||||
await websocket.accept()
|
||||
if task_id not in self._connections:
|
||||
self._connections[task_id] = set()
|
||||
self._connections[task_id].add(websocket)
|
||||
|
||||
if task_id in self._progress_data:
|
||||
await websocket.send_json(self._progress_data[task_id])
|
||||
|
||||
logger.info(f"WebSocket连接: task_id={task_id}")
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, task_id: str):
|
||||
"""断开WebSocket连接"""
|
||||
if task_id in self._connections:
|
||||
self._connections[task_id].discard(websocket)
|
||||
if not self._connections[task_id]:
|
||||
del self._connections[task_id]
|
||||
logger.info(f"WebSocket断开: task_id={task_id}")
|
||||
|
||||
async def update_progress(self, task_id: str, progress_data: Dict):
|
||||
"""更新进度并推送"""
|
||||
progress_data["timestamp"] = datetime.utcnow().isoformat()
|
||||
self._progress_data[task_id] = progress_data
|
||||
|
||||
if task_id in self._connections:
|
||||
disconnected = set()
|
||||
for websocket in self._connections[task_id]:
|
||||
try:
|
||||
await websocket.send_json(progress_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
for ws in disconnected:
|
||||
self._connections[task_id].discard(ws)
|
||||
|
||||
async def complete_task(self, task_id: str, result: Dict):
|
||||
"""完成任务"""
|
||||
result["status"] = "completed"
|
||||
result["timestamp"] = datetime.utcnow().isoformat()
|
||||
self._progress_data[task_id] = result
|
||||
|
||||
if task_id in self._connections:
|
||||
for websocket in self._connections[task_id]:
|
||||
try:
|
||||
await websocket.send_json(result)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
|
||||
async def fail_task(self, task_id: str, error: str):
|
||||
"""任务失败"""
|
||||
result = {
|
||||
"status": "failed",
|
||||
"error": error,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
self._progress_data[task_id] = result
|
||||
|
||||
if task_id in self._connections:
|
||||
for websocket in self._connections[task_id]:
|
||||
try:
|
||||
await websocket.send_json(result)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
|
||||
def get_progress(self, task_id: str) -> Optional[Dict]:
|
||||
"""获取进度数据"""
|
||||
return self._progress_data.get(task_id)
|
||||
|
||||
def clear_task(self, task_id: str):
|
||||
"""清除任务数据"""
|
||||
if task_id in self._progress_data:
|
||||
del self._progress_data[task_id]
|
||||
if task_id in self._connections:
|
||||
del self._connections[task_id]
|
||||
|
||||
|
||||
progress_manager = ProgressManager()
|
||||
Loading…
Reference in new issue