You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
7.3 KiB
211 lines
7.3 KiB
"""WebSocket服务 - 对应Go的internal/websocket/server.go"""
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Dict, Set, Optional
|
|
from dataclasses import dataclass, field
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from app.core.logger import info, error
|
|
|
|
|
|
@dataclass
class WSClient:
    """A single connected WebSocket peer.

    Holds the client's identifier, the underlying socket, and the set of
    symbols this client is currently subscribed to.
    """

    id: str
    websocket: WebSocket
    subscriptions: Set[str] = field(default_factory=set)

    async def send(self, message: dict):
        """Send ``message`` to the peer as JSON on a best-effort basis.

        An exception raised by the socket is logged and not re-raised.
        """
        failure = None
        try:
            await self.websocket.send_json(message)
        except Exception as exc:
            failure = exc

        # Report outside the handler; a successful send leaves ``failure`` unset.
        if failure is not None:
            error(f"Failed to send message to client {self.id}: {failure}")
|
|
|
|
|
|
class WebSocketManager:
    """Registry of connected WebSocket clients and their symbol subscriptions.

    ``clients`` maps a client id to its client object and ``subscriptions``
    maps a symbol to the set of client ids subscribed to it. Both maps, and
    each client's own ``subscriptions`` set, are only mutated while holding
    ``lock`` so that they stay consistent with one another.
    """

    def __init__(self):
        self.clients: Dict[str, "WSClient"] = {}
        self.subscriptions: Dict[str, Set[str]] = {}  # symbol -> set of client_ids
        self.max_symbols_per_client = 100
        self.lock = asyncio.Lock()

    def _drop_from_index(self, client_id: str, symbols) -> None:
        """Remove ``client_id`` from the symbol index for each of ``symbols``.

        Symbols left without any subscriber are deleted from the index.
        The caller must already hold ``self.lock``.
        """
        for symbol in symbols:
            subscribers = self.subscriptions.get(symbol)
            if subscribers is None:
                continue
            subscribers.discard(client_id)
            if not subscribers:
                del self.subscriptions[symbol]

    async def connect(self, websocket: "WebSocket", client_id: str) -> "WSClient":
        """Accept ``websocket`` and register it under ``client_id``.

        Returns the newly created client object. If ``client_id`` is already
        registered, the previous client's entries are removed from the symbol
        index first; otherwise they would be inherited by the new connection
        and never pruned, because the new client starts with no subscriptions.
        """
        await websocket.accept()

        client = WSClient(id=client_id, websocket=websocket)

        async with self.lock:
            previous = self.clients.get(client_id)
            if previous is not None:
                self._drop_from_index(client_id, previous.subscriptions)
            self.clients[client_id] = client

        info(f"WebSocket client connected: {client_id}, total: {len(self.clients)}")
        return client

    async def disconnect(self, client_id: str):
        """Unregister ``client_id`` and drop all of its subscriptions.

        Does nothing when the id is not registered.
        """
        async with self.lock:
            client = self.clients.pop(client_id, None)
            if client is not None:
                self._drop_from_index(client_id, client.subscriptions)

        if client is not None:
            info(f"WebSocket client disconnected: {client_id}, total: {len(self.clients)}")

    async def subscribe(self, client_id: str, symbols: list) -> bool:
        """Subscribe ``client_id`` to ``symbols``.

        Returns ``True`` on success. Returns ``False``, leaving all state
        untouched, when the client is unknown, when ``symbols`` is not a
        collection of hashable items (``None``, a bare string, nested
        lists, ...), or when the request would push the client past
        ``max_symbols_per_client``.

        Only symbols the client does not already hold count toward the
        limit, so duplicates in the request and re-subscriptions are free.
        """
        # A bare string would otherwise be subscribed one character at a time.
        if isinstance(symbols, (str, bytes)):
            return False

        try:
            requested = set(symbols)
        except TypeError:
            # Not iterable, or contains unhashable items.
            return False

        async with self.lock:
            client = self.clients.get(client_id)
            if client is None:
                return False

            added = requested - client.subscriptions
            if len(client.subscriptions) + len(added) > self.max_symbols_per_client:
                return False

            # Validation is complete, so the mutation below cannot fail midway.
            for symbol in requested:
                self.subscriptions.setdefault(symbol, set()).add(client_id)
            client.subscriptions |= requested

            return True

    async def unsubscribe(self, client_id: str, symbols: list):
        """Remove ``symbols`` from the subscriptions of ``client_id``.

        Unknown clients and symbols that are not subscribed are ignored.
        ``symbols`` is materialised before any state is touched, so invalid
        input raises ``TypeError`` without leaving a partial update behind.
        """
        targets = set(symbols)

        async with self.lock:
            client = self.clients.get(client_id)
            if client is None:
                return

            client.subscriptions -= targets
            self._drop_from_index(client_id, targets)

    async def broadcast_to_symbol(self, symbol: str, message: dict):
        """Send ``message`` to every client currently subscribed to ``symbol``."""
        async with self.lock:
            client_ids = set(self.subscriptions.get(symbol, ()))

        # Send outside the lock so a slow peer cannot block (un)subscribe calls.
        for client_id in client_ids:
            # The client may have disconnected since the snapshot was taken.
            client = self.clients.get(client_id)
            if client is None:
                continue
            try:
                await client.send(message)
            except Exception as e:
                error(f"Failed to broadcast to {client_id}: {e}")

    def get_stats(self) -> dict:
        """Return the number of connected clients and of subscribed symbols."""
        return {
            "total_clients": len(self.clients),
            "total_subscriptions": len(self.subscriptions),
        }
|
|
|
|
|
|
# Global WebSocket manager instance
|
|
ws_manager = WebSocketManager()  # module-level instance; WebSocketServer.__init__ binds to it
|
|
|
|
|
|
class WebSocketServer:
    """WebSocket endpoint logic: receive loop, request dispatch and heartbeat."""

    def __init__(self):
        self.manager = ws_manager

    @staticmethod
    def _ack(action: str, symbols: list) -> dict:
        """Build the acknowledgement payload for a completed ``action``."""
        return {
            "type": "ack",
            "action": action,
            "symbols": symbols,
            "ts": datetime.now().isoformat(),
        }

    @staticmethod
    def _error(code: int, message: str) -> dict:
        """Build an error payload carrying ``code`` and ``message``."""
        return {
            "type": "error",
            "code": code,
            "message": message,
            "ts": datetime.now().isoformat(),
        }

    async def _dispatch(self, client, client_id: str, data: str) -> None:
        """Parse one raw text frame and answer it on ``client``.

        Replies with error 1000 for malformed input, 1001 for an unknown
        action and 1003 when the manager refuses a subscription; otherwise
        replies with an ack.
        """
        try:
            msg = json.loads(data)
        except json.JSONDecodeError:
            msg = None

        # Valid JSON that is not an object (list, string, number) has no
        # ``action``/``symbols`` fields; report it like any other bad frame
        # instead of letting ``msg.get`` raise and end the connection.
        if not isinstance(msg, dict):
            await client.send(self._error(1000, "Invalid message format"))
            return

        action = msg.get("action")
        if action not in ("subscribe", "unsubscribe"):
            await client.send(self._error(1001, "Unknown action"))
            return

        symbols = msg.get("symbols", [])
        # Reject anything but a list of strings before it reaches the manager;
        # a bare string would otherwise be treated as a sequence of characters.
        if not isinstance(symbols, list) or not all(isinstance(s, str) for s in symbols):
            await client.send(self._error(1000, "Invalid message format"))
            return

        if action == "subscribe":
            if await self.manager.subscribe(client_id, symbols):
                await client.send(self._ack("subscribe", symbols))
            else:
                await client.send(
                    self._error(1003, "Too many subscriptions or subscription failed")
                )
        else:
            await self.manager.unsubscribe(client_id, symbols)
            await client.send(self._ack("unsubscribe", symbols))

    async def handle(self, websocket: WebSocket, client_id: str):
        """Serve one WebSocket connection until it closes.

        Registers the connection with the manager, processes incoming text
        frames, and always unregisters the client on the way out, including
        when the handler task is cancelled.
        """
        client = await self.manager.connect(websocket, client_id)

        try:
            while True:
                data = await websocket.receive_text()
                await self._dispatch(client, client_id, data)
        except WebSocketDisconnect:
            # Normal close by the peer; cleanup happens in ``finally``.
            pass
        except Exception as e:
            error(f"WebSocket error for client {client_id}: {e}")
        finally:
            # ``finally`` also covers cancellation, which ``except Exception``
            # does not catch, so the registration is never leaked.
            await self.manager.disconnect(client_id)

    async def send_heartbeat(self):
        """Send a heartbeat message to every connected client.

        Intended to be invoked by a periodic task.
        """
        message = {
            "type": "heartbeat",
            "ts": datetime.now().isoformat(),
        }

        # Iterate over a snapshot: clients may connect or leave while awaiting.
        for client in list(self.manager.clients.values()):
            try:
                await client.send(message)
            except Exception as e:
                # Keep going so one bad client does not block the others.
                error(f"Failed to send heartbeat to client {client.id}: {e}")
|