"""WebSocket服务 - 对应Go的internal/websocket/server.go""" import asyncio import json from datetime import datetime from typing import Dict, Set, Optional from dataclasses import dataclass, field from fastapi import WebSocket, WebSocketDisconnect from app.core.logger import info, error @dataclass class WSClient: """WebSocket客户端""" id: str websocket: WebSocket subscriptions: Set[str] = field(default_factory=set) async def send(self, message: dict): """发送消息""" try: await self.websocket.send_json(message) except Exception as e: error(f"Failed to send message to client {self.id}: {e}") class WebSocketManager: """WebSocket连接管理器""" def __init__(self): self.clients: Dict[str, WSClient] = {} self.subscriptions: Dict[str, Set[str]] = {} # symbol -> set of client_ids self.max_symbols_per_client = 100 self.lock = asyncio.Lock() async def connect(self, websocket: WebSocket, client_id: str) -> WSClient: """建立连接""" await websocket.accept() client = WSClient(id=client_id, websocket=websocket) async with self.lock: self.clients[client_id] = client info(f"WebSocket client connected: {client_id}, total: {len(self.clients)}") return client async def disconnect(self, client_id: str): """断开连接""" async with self.lock: if client_id in self.clients: client = self.clients.pop(client_id) # 清理订阅 for symbol in client.subscriptions: if symbol in self.subscriptions: self.subscriptions[symbol].discard(client_id) if not self.subscriptions[symbol]: del self.subscriptions[symbol] info(f"WebSocket client disconnected: {client_id}, total: {len(self.clients)}") async def subscribe(self, client_id: str, symbols: list) -> bool: """订阅标的""" async with self.lock: if client_id not in self.clients: return False client = self.clients[client_id] # 检查订阅数量限制 if len(client.subscriptions) + len(symbols) > self.max_symbols_per_client: return False for symbol in symbols: client.subscriptions.add(symbol) if symbol not in self.subscriptions: self.subscriptions[symbol] = set() self.subscriptions[symbol].add(client_id) return True async def unsubscribe(self, client_id: str, symbols: list): """取消订阅""" async with self.lock: if client_id not in self.clients: return client = self.clients[client_id] for symbol in symbols: client.subscriptions.discard(symbol) if symbol in self.subscriptions: self.subscriptions[symbol].discard(client_id) if not self.subscriptions[symbol]: del self.subscriptions[symbol] async def broadcast_to_symbol(self, symbol: str, message: dict): """向订阅了某标的的所有客户端广播""" client_ids = set() async with self.lock: if symbol in self.subscriptions: client_ids = self.subscriptions[symbol].copy() # 在锁外发送消息 for client_id in client_ids: if client_id in self.clients: try: await self.clients[client_id].send(message) except Exception as e: error(f"Failed to broadcast to {client_id}: {e}") def get_stats(self) -> dict: """获取统计信息""" return { "total_clients": len(self.clients), "total_subscriptions": len(self.subscriptions) } # 全局WebSocket管理器实例 ws_manager = WebSocketManager() class WebSocketServer: """WebSocket服务器""" def __init__(self): self.manager = ws_manager async def handle(self, websocket: WebSocket, client_id: str): """处理WebSocket连接""" client = await self.manager.connect(websocket, client_id) try: while True: # 接收消息 data = await websocket.receive_text() try: msg = json.loads(data) action = msg.get("action") symbols = msg.get("symbols", []) if action == "subscribe": success = await self.manager.subscribe(client_id, symbols) if success: await client.send({ "type": "ack", "action": "subscribe", "symbols": symbols, "ts": datetime.now().isoformat() }) else: await client.send({ "type": "error", "code": 1003, "message": "Too many subscriptions or subscription failed", "ts": datetime.now().isoformat() }) elif action == "unsubscribe": await self.manager.unsubscribe(client_id, symbols) await client.send({ "type": "ack", "action": "unsubscribe", "symbols": symbols, "ts": datetime.now().isoformat() }) else: await client.send({ "type": "error", "code": 1001, "message": "Unknown action", "ts": datetime.now().isoformat() }) except json.JSONDecodeError: await client.send({ "type": "error", "code": 1000, "message": "Invalid message format", "ts": datetime.now().isoformat() }) except WebSocketDisconnect: await self.manager.disconnect(client_id) except Exception as e: error(f"WebSocket error for client {client_id}: {e}") await self.manager.disconnect(client_id) async def send_heartbeat(self): """发送心跳(可由定时任务调用)""" message = { "type": "heartbeat", "ts": datetime.now().isoformat() } # 向所有客户端发送心跳 clients_copy = list(self.manager.clients.values()) for client in clients_copy: try: await client.send(message) except Exception: pass