You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

211 lines
7.3 KiB

"""WebSocket服务 - 对应Go的internal/websocket/server.go"""
import asyncio
import json
from datetime import datetime
from typing import Dict, Set, Optional
from dataclasses import dataclass, field
from fastapi import WebSocket, WebSocketDisconnect
from app.core.logger import info, error
@dataclass
class WSClient:
    """One connected WebSocket peer.

    Attributes:
        id: Identifier this connection is registered under.
        websocket: The FastAPI WebSocket used to talk to the peer.
        subscriptions: Symbols this peer is currently subscribed to.
    """

    id: str
    websocket: WebSocket
    subscriptions: Set[str] = field(default_factory=set)

    async def send(self, message: dict):
        """Push *message* to the peer as a JSON frame.

        Any ``Exception`` raised while sending is logged and swallowed, so
        callers do not need their own error handling for ordinary send
        failures.
        """
        socket = self.websocket
        try:
            await socket.send_json(message)
        except Exception as exc:
            error(f"Failed to send message to client {self.id}: {exc}")
class WebSocketManager:
    """Registry of connected WebSocket clients and their symbol subscriptions.

    Two views of the same data are kept in sync:
      * ``WSClient.subscriptions``  -- client -> symbols
      * ``self.subscriptions``      -- symbol -> set of client ids
    Every mutation of either view happens while holding ``self.lock``.
    """

    def __init__(self, max_symbols_per_client: int = 100):
        """Create an empty manager.

        Args:
            max_symbols_per_client: Maximum number of distinct symbols a
                single client may be subscribed to at the same time.
        """
        self.clients: Dict[str, WSClient] = {}
        self.subscriptions: Dict[str, Set[str]] = {}  # symbol -> set of client_ids
        self.max_symbols_per_client = max_symbols_per_client
        self.lock = asyncio.Lock()

    def _drop_from_index(self, client_id: str, symbols) -> None:
        """Remove *client_id* from the symbol index for each symbol in *symbols*.

        Symbols left with no subscribers are deleted from the index.
        The caller must already hold ``self.lock``.
        """
        for symbol in symbols:
            subscribers = self.subscriptions.get(symbol)
            if subscribers is None:
                continue
            subscribers.discard(client_id)
            if not subscribers:
                del self.subscriptions[symbol]

    async def connect(self, websocket: WebSocket, client_id: str) -> WSClient:
        """Accept *websocket* and register it under *client_id*.

        If a client is already registered under the same id, it is replaced.
        Its subscriptions are first removed from the symbol index. This way the
        new connection, whose ``WSClient.subscriptions`` starts empty, does not
        receive broadcasts it never asked for. It also keeps ``disconnect`` able
        to clean up every index entry for this id.

        Returns:
            The newly registered WSClient.
        """
        await websocket.accept()
        client = WSClient(id=client_id, websocket=websocket)
        async with self.lock:
            replaced = self.clients.get(client_id)
            if replaced is not None:
                # NOTE(review): the replaced socket is not closed here -- confirm
                # whether duplicate ids should close the older connection.
                self._drop_from_index(client_id, replaced.subscriptions)
            self.clients[client_id] = client
        if replaced is not None:
            info(f"WebSocket client {client_id} replaced an existing connection")
        info(f"WebSocket client connected: {client_id}, total: {len(self.clients)}")
        return client

    async def disconnect(self, client_id: str):
        """Unregister *client_id* and remove all of its subscriptions.

        Safe to call for an id that is not (or no longer) registered.
        """
        async with self.lock:
            client = self.clients.pop(client_id, None)
            if client is not None:
                self._drop_from_index(client_id, client.subscriptions)
        info(f"WebSocket client disconnected: {client_id}, total: {len(self.clients)}")

    async def subscribe(self, client_id: str, symbols: list) -> bool:
        """Subscribe *client_id* to every symbol in *symbols*.

        The request is all-or-nothing. Only symbols the client does not
        already hold count toward ``max_symbols_per_client``, and duplicates
        in *symbols* are counted once.

        Returns:
            True on success. False if the client is unknown or the request
            would exceed the per-client limit; in that case nothing changes.
        """
        async with self.lock:
            client = self.clients.get(client_id)
            if client is None:
                return False
            # Building the set first de-duplicates the request. It also makes an
            # unhashable entry raise before any state has been mutated.
            requested = set(symbols)
            added = requested - client.subscriptions
            if len(client.subscriptions) + len(added) > self.max_symbols_per_client:
                return False
            client.subscriptions.update(requested)
            for symbol in requested:
                self.subscriptions.setdefault(symbol, set()).add(client_id)
            return True

    async def unsubscribe(self, client_id: str, symbols: list):
        """Remove *client_id*'s subscriptions to *symbols*; unknown ids are ignored."""
        async with self.lock:
            client = self.clients.get(client_id)
            if client is None:
                return
            client.subscriptions.difference_update(symbols)
            self._drop_from_index(client_id, symbols)

    async def broadcast_to_symbol(self, symbol: str, message: dict):
        """Send *message* to every client currently subscribed to *symbol*."""
        async with self.lock:
            client_ids = set(self.subscriptions.get(symbol, ()))
        # Send outside the lock so one slow client cannot stall (un)subscribe,
        # connect or disconnect for everyone else.
        for client_id in client_ids:
            client = self.clients.get(client_id)
            if client is None:
                continue  # disconnected after the snapshot was taken
            try:
                await client.send(message)
            except Exception as e:
                error(f"Failed to broadcast to {client_id}: {e}")

    def get_stats(self) -> dict:
        """Return connected-client count and the number of symbols with subscribers."""
        return {
            "total_clients": len(self.clients),
            "total_subscriptions": len(self.subscriptions),
        }
# Module-level WebSocketManager; WebSocketServer instances use this shared object.
ws_manager: WebSocketManager = WebSocketManager()
class WebSocketServer:
    """Per-connection WebSocket handler backed by the shared ``ws_manager``.

    Clients send JSON text frames shaped like
    ``{"action": "subscribe" | "unsubscribe", "symbols": ["...", ...]}``.
    Each frame is answered with an ``ack`` or ``error`` message.

    Error codes: 1000 malformed request, 1001 unknown action,
    1003 subscription rejected by the manager.
    """

    def __init__(self):
        self.manager = ws_manager

    @staticmethod
    def _now() -> str:
        """Return the current local time as an ISO-8601 string (the ``ts`` field)."""
        return datetime.now().isoformat()

    @classmethod
    def _error_message(cls, code: int, message: str) -> dict:
        """Build an ``error`` payload."""
        return {"type": "error", "code": code, "message": message, "ts": cls._now()}

    @classmethod
    def _ack_message(cls, action: str, symbols: list) -> dict:
        """Build an ``ack`` payload echoing the requested symbols."""
        return {"type": "ack", "action": action, "symbols": symbols, "ts": cls._now()}

    async def handle(self, websocket: WebSocket, client_id: str):
        """Serve one WebSocket connection until it ends.

        Registers the client, processes incoming text frames in a loop, and
        unregisters the client on every exit path.
        """
        client = await self.manager.connect(websocket, client_id)
        try:
            while True:
                data = await websocket.receive_text()
                await self._handle_message(client, client_id, data)
        except WebSocketDisconnect:
            pass  # peer closed the connection normally
        except Exception as e:
            error(f"WebSocket error for client {client_id}: {e}")
        finally:
            # Also runs on asyncio.CancelledError (a BaseException), which the
            # except clauses above do not catch; otherwise the client would stay
            # registered in the manager forever.
            await self.manager.disconnect(client_id)

    async def _handle_message(self, client: "WSClient", client_id: str, data: str) -> None:
        """Parse one text frame, run the requested action, and reply on *client*."""
        try:
            msg = json.loads(data)
        except json.JSONDecodeError:
            await client.send(self._error_message(1000, "Invalid message format"))
            return
        # Valid JSON that is not an object (a list, number, string...) has no
        # .get(). Reject it as malformed instead of letting it kill the connection.
        if not isinstance(msg, dict):
            await client.send(self._error_message(1000, "Invalid message format"))
            return

        action = msg.get("action")
        if action not in ("subscribe", "unsubscribe"):
            await client.send(self._error_message(1001, "Unknown action"))
            return

        symbols = msg.get("symbols", [])
        # symbols is client-controlled. A bare string would be iterated character
        # by character, and null or nested objects would raise inside the manager.
        if not isinstance(symbols, list) or not all(isinstance(s, str) for s in symbols):
            await client.send(self._error_message(1000, "Invalid message format"))
            return

        if action == "subscribe":
            if await self.manager.subscribe(client_id, symbols):
                await client.send(self._ack_message("subscribe", symbols))
            else:
                await client.send(
                    self._error_message(1003, "Too many subscriptions or subscription failed")
                )
        else:
            await self.manager.unsubscribe(client_id, symbols)
            await client.send(self._ack_message("unsubscribe", symbols))

    async def send_heartbeat(self):
        """Send a heartbeat message to every connected client.

        Can be invoked periodically by a scheduled task.
        """
        message = {"type": "heartbeat", "ts": self._now()}
        # Iterate over a snapshot: clients may (dis)connect while we await sends.
        for client in list(self.manager.clients.values()):
            try:
                await client.send(message)
            except Exception as e:
                # Best-effort: one failing client must not abort the sweep, but
                # the failure is logged rather than silently discarded.
                error(f"Failed to send heartbeat to client {client.id}: {e}")