|
|
|
|
|
# backend/app/websocket/connection_manager.py
|
|
|
|
|
|
"""
|
|
|
|
|
|
WebSocket 连接管理器
|
|
|
|
|
|
支持 1000+ 并发连接,心跳机制,订阅管理
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
import json
import logging
import time
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Dict, Set, Optional, List

from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConnectionManager:
    """Registry of live WebSocket connections and their symbol subscriptions.

    Responsibilities:
    - Track connections (a user may hold several at once) and clean them up
      on disconnect.
    - Heartbeats: ``check_heartbeat_timeout`` closes connections that have
      not sent a heartbeat within ``heartbeat_timeout`` seconds.
    - Per-user symbol subscriptions (subscribe / unsubscribe).
    - Message delivery: to one connection, to all of a user's connections,
      to every subscriber of a symbol, or to every connection.

    Concurrency: registry mutations happen under ``self._lock``.
    ``asyncio.Lock`` is not reentrant, and ``send_to_connection`` calls
    ``disconnect`` (which takes the lock) when a send fails. Therefore no
    message is ever sent while the lock is held.
    """

    # Used to record connections dropped after a failed send/close.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        # user_id -> open WebSockets of that user (multiple connections allowed)
        self.active_connections: Dict[int, Set["WebSocket"]] = defaultdict(set)

        # WebSocket -> user_id (reverse mapping)
        self.connection_users: Dict["WebSocket", int] = {}

        # user_id -> subscribed symbols
        self.subscriptions: Dict[int, Set[str]] = defaultdict(set)

        # symbol -> subscribed user_ids (reverse mapping used for broadcasts);
        # entries are removed once their set becomes empty
        self.symbol_subscribers: Dict[str, Set[int]] = defaultdict(set)

        # WebSocket -> connection_id (UUID4 string generated in connect())
        self.connection_ids: Dict["WebSocket", str] = {}

        # WebSocket -> time.time() of the last heartbeat (or of connect())
        self.heartbeat_times: Dict["WebSocket", float] = {}

        # Statistics: total_connections is the number of currently registered
        # connections (incremented in connect, decremented in disconnect).
        self.total_connections = 0
        self.total_messages_sent = 0

        # Seconds without a heartbeat after which a connection is closed
        self.heartbeat_timeout = 90

        # Guards all registry mutations; never hold it while sending.
        self._lock = asyncio.Lock()

    def _remove_subscriber(self, symbol: str, user_id: int) -> None:
        """Remove ``user_id`` from ``symbol``'s subscribers; caller holds the lock.

        Drops the symbol entry once nobody is subscribed, so symbols sent by
        clients do not accumulate forever.
        """
        subscribers = self.symbol_subscribers.get(symbol)
        if subscribers is None:
            return
        subscribers.discard(user_id)
        if not subscribers:
            del self.symbol_subscribers[symbol]

    async def connect(self, websocket: "WebSocket", user_id: int, client_ip: Optional[str] = None) -> str:
        """Accept ``websocket`` and register it for ``user_id``.

        Args:
            websocket: The incoming (not yet accepted) WebSocket.
            user_id: ID of the user owning the connection.
            client_ip: Client IP address; currently not used by this method.

        Returns:
            str: A freshly generated UUID4 string identifying the connection.
        """
        # Handshake outside the lock: one slow client must not stall every
        # other connect/disconnect/subscribe call.
        await websocket.accept()
        connection_id = str(uuid.uuid4())

        async with self._lock:
            self.active_connections[user_id].add(websocket)
            self.connection_users[websocket] = user_id
            self.connection_ids[websocket] = connection_id
            self.heartbeat_times[websocket] = time.time()
            self.total_connections += 1

        # Sent after releasing the lock: a failed send triggers disconnect(),
        # which needs the (non-reentrant) lock.
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "connected",
            "connection_id": connection_id,
            "time": datetime.now().isoformat(),
            "message": "WebSocket 连接成功"
        })

        return connection_id

    async def disconnect(self, websocket: "WebSocket"):
        """Unregister ``websocket``. Safe to call more than once.

        Subscriptions belong to the user, so they are only removed when the
        user's last connection goes away.

        Args:
            websocket: The WebSocket to remove.
        """
        async with self._lock:
            # Per-socket bookkeeping is dropped unconditionally, so a stray
            # heartbeat/id entry can never outlive the connection.
            self.heartbeat_times.pop(websocket, None)
            self.connection_ids.pop(websocket, None)
            user_id = self.connection_users.pop(websocket, None)
            if user_id is None:
                return  # already disconnected

            self.total_connections -= 1

            connections = self.active_connections.get(user_id, set())
            connections.discard(websocket)
            if connections:
                # The user still has live connections that must keep
                # receiving broadcasts for their subscriptions.
                return

            self.active_connections.pop(user_id, None)
            for symbol in self.subscriptions.pop(user_id, set()):
                self._remove_subscriber(symbol, user_id)

    async def subscribe(self, websocket: "WebSocket", symbols: List[str]):
        """Subscribe the user owning ``websocket`` to ``symbols``.

        Does nothing if ``websocket`` is not registered. Sends a
        ``subscribed`` confirmation echoing ``symbols``.

        Args:
            websocket: A registered WebSocket.
            symbols: Symbol codes to subscribe to.
        """
        async with self._lock:
            user_id = self.connection_users.get(websocket)
            if user_id is None:
                return
            for symbol in symbols:
                self.subscriptions[user_id].add(symbol)
                self.symbol_subscribers[symbol].add(user_id)

        # Outside the lock: see the class docstring.
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "subscribed",
            "symbols": symbols,
            "time": datetime.now().isoformat(),
            "message": f"已订阅 {len(symbols)} 个品种"
        })

    async def unsubscribe(self, websocket: "WebSocket", symbols: List[str]):
        """Unsubscribe the user owning ``websocket`` from ``symbols``.

        Does nothing if ``websocket`` is not registered. Unknown symbols are
        ignored. Sends an ``unsubscribed`` confirmation echoing ``symbols``.

        Args:
            websocket: A registered WebSocket.
            symbols: Symbol codes to unsubscribe from.
        """
        async with self._lock:
            user_id = self.connection_users.get(websocket)
            if user_id is None:
                return
            user_symbols = self.subscriptions.get(user_id)
            for symbol in symbols:
                if user_symbols is not None:
                    user_symbols.discard(symbol)
                self._remove_subscriber(symbol, user_id)

        # Outside the lock: see the class docstring.
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "unsubscribed",
            "symbols": symbols,
            "time": datetime.now().isoformat(),
            "message": f"已取消订阅 {len(symbols)} 个品种"
        })

    async def send_to_connection(self, websocket: "WebSocket", message: dict):
        """Send ``message`` as JSON to one connection.

        Any send failure is treated as a dead connection: it is logged at
        debug level and the socket is unregistered via ``disconnect``. Must
        therefore never be called while ``self._lock`` is held.

        Args:
            websocket: Target WebSocket.
            message: JSON-serializable message.
        """
        try:
            await websocket.send_json(message)
        except Exception:
            self._logger.debug(
                "Send failed; dropping connection %s",
                self.connection_ids.get(websocket),
                exc_info=True,
            )
            await self.disconnect(websocket)
        else:
            self.total_messages_sent += 1

    async def send_to_user(self, user_id: int, message: dict):
        """Send ``message`` to every connection of ``user_id``.

        Args:
            user_id: Target user ID.
            message: JSON-serializable message.
        """
        # Snapshot: failed sends remove sockets from the live set.
        for websocket in list(self.active_connections.get(user_id, set())):
            await self.send_to_connection(websocket, message)

    async def broadcast_to_symbol(self, symbol: str, message: dict):
        """Send ``message`` to all users subscribed to ``symbol``.

        Args:
            symbol: Symbol code.
            message: JSON-serializable message.
        """
        for user_id in list(self.symbol_subscribers.get(symbol, set())):
            await self.send_to_user(user_id, message)

    async def broadcast(self, message: dict):
        """Send ``message`` to every registered connection.

        Args:
            message: JSON-serializable message.
        """
        # Snapshot first: each await may let connect()/disconnect() (including
        # the one triggered by a failed send) mutate the registry, which would
        # raise "dictionary changed size during iteration" on a live view.
        for websocket in list(self.connection_users):
            await self.send_to_connection(websocket, message)

    async def handle_heartbeat(self, websocket: "WebSocket"):
        """Record a heartbeat from ``websocket`` and reply with one.

        Ignored for sockets that are no longer registered, so a late
        heartbeat cannot resurrect bookkeeping for a dropped connection.

        Args:
            websocket: The WebSocket that sent the heartbeat.
        """
        if websocket not in self.heartbeat_times:
            return
        self.heartbeat_times[websocket] = time.time()
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "heartbeat",
            "time": datetime.now().isoformat()
        })

    async def check_heartbeat_timeout(self):
        """Close and unregister connections silent for over ``heartbeat_timeout`` s.

        Each expired socket is closed with code 4003 (best effort), then
        removed via ``disconnect``.
        """
        async with self._lock:
            now = time.time()
            expired = [
                websocket
                for websocket, last_seen in self.heartbeat_times.items()
                if now - last_seen > self.heartbeat_timeout
            ]

        for websocket in expired:
            try:
                await websocket.close(code=4003, reason="心跳超时")
            except Exception:
                # Best effort: the socket may already be closed or broken.
                self._logger.debug("Closing timed-out connection failed", exc_info=True)
            await self.disconnect(websocket)

    def get_connection_count(self) -> int:
        """Return the number of currently registered connections."""
        return self.total_connections

    def get_user_count(self) -> int:
        """Return the number of users with at least one connection."""
        return len(self.active_connections)

    def get_subscription_count(self) -> int:
        """Return the total number of (symbol, user) subscription pairs."""
        return sum(len(users) for users in self.symbol_subscribers.values())

    def get_statistics(self) -> dict:
        """Return a snapshot of connection/subscription statistics."""
        return {
            "total_connections": self.total_connections,
            "active_users": self.get_user_count(),
            "total_subscriptions": self.get_subscription_count(),
            "total_messages_sent": self.total_messages_sent,
            "symbol_subscribers": {
                symbol: len(users)
                for symbol, users in self.symbol_subscribers.items()
            }
        }

    def get_user_subscriptions(self, user_id: int) -> List[str]:
        """Return the symbols ``user_id`` is subscribed to (unordered)."""
        return list(self.subscriptions.get(user_id, set()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Process-wide ConnectionManager shared by the WebSocket handler and the
# heartbeat checker defined in this module.
connection_manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============== WebSocket 路由处理 ==============
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)


def _parse_symbols(data: dict) -> Optional[List[str]]:
    """Return the ``symbols`` field of a client message if it is valid.

    A missing or null field yields ``[]``. A list of strings is returned as
    is. Anything else yields ``None``. Without this check a bare string would
    be iterated character by character, and an unhashable item would raise
    ``TypeError`` and tear down the whole connection.
    """
    symbols = data.get("symbols")
    if symbols is None:
        return []
    if isinstance(symbols, list) and all(isinstance(symbol, str) for symbol in symbols):
        return symbols
    return None


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send a ``system``/``error`` event carrying ``message`` to one connection."""
    await connection_manager.send_to_connection(websocket, {
        "type": "system",
        "event": "error",
        "message": message,
        "time": datetime.now().isoformat()
    })


async def websocket_handler(websocket: WebSocket, user_id: int):
    """Serve one client WebSocket until it disconnects.

    Registers the socket with ``connection_manager``, then processes JSON
    messages of the form ``{"action": ..., ...}``:

    - ``subscribe`` / ``unsubscribe``: ``symbols`` must be a list of strings.
    - ``heartbeat``: handled by ``connection_manager.handle_heartbeat``.
    - ``query``: replies with the user's current subscriptions.

    Any other action gets an ``error`` event. Malformed messages (invalid
    JSON, non-object payloads, invalid ``symbols``) are also answered with an
    ``error`` event rather than closing the connection. The socket is always
    unregistered on exit via ``finally``, including on task cancellation
    (``asyncio.CancelledError`` is not an ``Exception`` subclass).

    Args:
        websocket: WebSocket connection object (not yet accepted).
        user_id: User ID.
    """
    await connection_manager.connect(websocket, user_id)

    try:
        # Stop once the manager has dropped this socket (e.g. a send to it
        # failed); there is no point serving a connection it no longer tracks.
        while websocket in connection_manager.connection_users:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # The frame was consumed; the connection itself is still usable.
                await _send_error(websocket, "无效的 JSON 消息")
                continue

            if not isinstance(data, dict):
                await _send_error(websocket, "消息必须是 JSON 对象")
                continue

            action = data.get("action")

            if action == "subscribe":
                symbols = _parse_symbols(data)
                if symbols is None:
                    await _send_error(websocket, "symbols 必须是字符串列表")
                elif symbols:
                    await connection_manager.subscribe(websocket, symbols)

            elif action == "unsubscribe":
                symbols = _parse_symbols(data)
                if symbols is None:
                    await _send_error(websocket, "symbols 必须是字符串列表")
                elif symbols:
                    await connection_manager.unsubscribe(websocket, symbols)

            elif action == "heartbeat":
                await connection_manager.handle_heartbeat(websocket)

            elif action == "query":
                subscriptions = connection_manager.get_user_subscriptions(user_id)
                await connection_manager.send_to_connection(websocket, {
                    "type": "system",
                    "event": "query_result",
                    "subscriptions": subscriptions,
                    "time": datetime.now().isoformat()
                })

            else:
                await _send_error(websocket, f"未知操作: {action}")

    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass

    except Exception:
        # Unexpected failure: record it instead of swallowing it, and close
        # the socket (1011 = internal error) so the client is not left hanging.
        logger.exception("WebSocket handler failed for user_id=%s", user_id)
        try:
            await websocket.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed.
            pass

    finally:
        await connection_manager.disconnect(websocket)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============== 心跳检查任务 ==============
|
|
|
|
|
|
|
|
|
|
|
|
async def heartbeat_checker(interval: float = 30):
    """Background task that periodically drops connections with expired heartbeats.

    Every ``interval`` seconds it calls
    ``connection_manager.check_heartbeat_timeout()``. A failing check is
    logged and the loop keeps running, so one error does not permanently stop
    heartbeat enforcement. Task cancellation still ends the loop, because
    ``asyncio.CancelledError`` is not an ``Exception`` subclass.

    Args:
        interval: Seconds to sleep between checks (default 30).
    """
    while True:
        await asyncio.sleep(interval)
        try:
            await connection_manager.check_heartbeat_timeout()
        except Exception:
            logging.getLogger(__name__).exception("Heartbeat timeout check failed")
|