You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

383 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# backend/app/websocket/connection_manager.py
"""
WebSocket 连接管理器
支持 1000+ 并发连接,心跳机制,订阅管理
"""
import asyncio
import json
import time
from typing import Dict, Set, Optional, List
from datetime import datetime
from fastapi import WebSocket, WebSocketDisconnect
from collections import defaultdict
import uuid
class ConnectionManager:
    """
    WebSocket connection manager.

    Responsibilities:
    - Connection bookkeeping (registration, cleanup on disconnect)
    - Heartbeat tracking (connections silent for longer than
      ``heartbeat_timeout`` seconds are closed by ``check_heartbeat_timeout``)
    - Per-user symbol subscriptions (subscribe / unsubscribe)
    - Message delivery (single connection, one user, one symbol, everyone)

    Subscriptions are keyed by user_id, not by connection: every connection of
    a user receives that user's symbol broadcasts, and the user's
    subscriptions are dropped only when the user's last connection goes away.

    Concurrency: bookkeeping mutations are serialized with an ``asyncio.Lock``.
    The lock is never held across a network send, because a failed send calls
    ``disconnect``, which needs the same (non-reentrant) lock.
    """

    def __init__(self):
        # user_id -> set of live WebSocket objects (a user may have several)
        self.active_connections: Dict[int, Set[WebSocket]] = defaultdict(set)
        # WebSocket -> user_id (reverse mapping)
        self.connection_users: Dict[WebSocket, int] = {}
        # user_id -> set of subscribed symbols
        self.subscriptions: Dict[int, Set[str]] = defaultdict(set)
        # symbol -> set of subscribed user_ids (reverse mapping for broadcasts)
        self.symbol_subscribers: Dict[str, Set[int]] = defaultdict(set)
        # WebSocket -> connection_id (UUID4 string)
        self.connection_ids: Dict[WebSocket, str] = {}
        # WebSocket -> time.time() of the last heartbeat
        self.heartbeat_times: Dict[WebSocket, float] = {}
        # Number of currently registered connections (decremented on disconnect)
        self.total_connections = 0
        # Number of messages successfully sent
        self.total_messages_sent = 0
        # Seconds without a heartbeat before a connection is closed
        self.heartbeat_timeout = 90
        # Serializes mutations of the bookkeeping maps above
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, user_id: int,
                      client_ip: Optional[str] = None) -> str:
        """
        Accept a WebSocket and register it for ``user_id``.

        Args:
            websocket: WebSocket connection object.
            user_id: ID of the user owning the connection.
            client_ip: client IP address (accepted for API compatibility;
                not used by this method).

        Returns:
            str: the newly generated connection_id (UUID4 string).
        """
        # The handshake is network I/O: do it before taking the lock so that
        # a slow client does not block every other coroutine.
        await websocket.accept()
        connection_id = str(uuid.uuid4())

        async with self._lock:
            self.active_connections[user_id].add(websocket)
            self.connection_users[websocket] = user_id
            self.connection_ids[websocket] = connection_id
            self.heartbeat_times[websocket] = time.time()
            self.total_connections += 1

        # Sent outside the lock: on failure send_to_connection() calls
        # disconnect(), which would deadlock on the non-reentrant lock.
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "connected",
            "connection_id": connection_id,
            "time": datetime.now().isoformat(),
            "message": "WebSocket 连接成功"
        })
        return connection_id

    async def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a connection. Calling it for an unknown socket is a no-op.

        The user's subscriptions are removed only when this was the user's
        last connection; otherwise the remaining connections keep receiving
        the user's symbol broadcasts.

        Args:
            websocket: WebSocket connection object.
        """
        async with self._lock:
            user_id = self.connection_users.pop(websocket, None)
            if user_id is None:
                return

            self.connection_ids.pop(websocket, None)
            self.heartbeat_times.pop(websocket, None)
            self.total_connections -= 1

            connections = self.active_connections.get(user_id)
            if connections is not None:
                connections.discard(websocket)
                if connections:
                    # Other connections of this user are still alive.
                    return
                del self.active_connections[user_id]

            # Last connection of the user is gone: drop its subscriptions.
            for symbol in self.subscriptions.pop(user_id, set()):
                self._remove_symbol_subscriber(symbol, user_id)

    def _remove_symbol_subscriber(self, symbol: str, user_id: int) -> None:
        """Remove ``user_id`` from ``symbol``'s subscribers; drop empty sets.

        Caller must hold ``self._lock``.
        """
        subscribers = self.symbol_subscribers.get(symbol)
        if subscribers is None:
            return
        subscribers.discard(user_id)
        if not subscribers:
            del self.symbol_subscribers[symbol]

    async def subscribe(self, websocket: WebSocket, symbols: List[str]) -> None:
        """
        Subscribe the connection's user to ``symbols`` and send a confirmation.

        Does nothing (and sends nothing) if the socket is not registered.

        Args:
            websocket: WebSocket connection object.
            symbols: symbol codes to subscribe to.
        """
        async with self._lock:
            user_id = self.connection_users.get(websocket)
            if user_id is None:
                return
            user_symbols = self.subscriptions[user_id]
            for symbol in symbols:
                user_symbols.add(symbol)
                self.symbol_subscribers[symbol].add(user_id)

        # Outside the lock (see connect()).
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "subscribed",
            "symbols": symbols,
            "time": datetime.now().isoformat(),
            "message": f"已订阅 {len(symbols)} 个品种"
        })

    async def unsubscribe(self, websocket: WebSocket, symbols: List[str]) -> None:
        """
        Unsubscribe the connection's user from ``symbols`` and confirm.

        Does nothing (and sends nothing) if the socket is not registered.

        Args:
            websocket: WebSocket connection object.
            symbols: symbol codes to unsubscribe from.
        """
        async with self._lock:
            user_id = self.connection_users.get(websocket)
            if user_id is None:
                return
            user_symbols = self.subscriptions.get(user_id)
            for symbol in symbols:
                if user_symbols is not None:
                    user_symbols.discard(symbol)
                # Uses .get() internally, so unknown symbols do not create
                # empty entries in symbol_subscribers.
                self._remove_symbol_subscriber(symbol, user_id)
            if user_symbols is not None and not user_symbols:
                del self.subscriptions[user_id]

        # Outside the lock (see connect()).
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "unsubscribed",
            "symbols": symbols,
            "time": datetime.now().isoformat(),
            "message": f"已取消订阅 {len(symbols)} 个品种"
        })

    async def send_to_connection(self, websocket: WebSocket, message: dict) -> None:
        """
        Send ``message`` as JSON to a single connection.

        Any send failure is treated as a dead connection and the socket is
        unregistered via ``disconnect``. Must not be called while holding
        ``self._lock``.

        Args:
            websocket: WebSocket connection object.
            message: JSON-serializable message.
        """
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: the peer is most likely gone; clean up and move on.
            await self.disconnect(websocket)
        else:
            self.total_messages_sent += 1

    async def send_to_user(self, user_id: int, message: dict) -> None:
        """
        Send ``message`` to every connection of ``user_id``.

        Args:
            user_id: user ID.
            message: JSON-serializable message.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        # Snapshot: failed sends call disconnect(), which mutates the set.
        for websocket in list(connections):
            await self.send_to_connection(websocket, message)

    async def broadcast_to_symbol(self, symbol: str, message: dict) -> None:
        """
        Send ``message`` to every user subscribed to ``symbol``.

        Args:
            symbol: symbol code.
            message: JSON-serializable message.
        """
        # Snapshot: disconnects during sending mutate the subscriber set.
        for user_id in list(self.symbol_subscribers.get(symbol, ())):
            await self.send_to_user(user_id, message)

    async def broadcast(self, message: dict) -> None:
        """
        Send ``message`` to every registered connection.

        Args:
            message: JSON-serializable message.
        """
        # Snapshot all sockets first: a failed send calls disconnect(), which
        # deletes keys from active_connections and would otherwise raise
        # "dictionary changed size during iteration".
        targets = [
            websocket
            for connections in self.active_connections.values()
            for websocket in connections
        ]
        for websocket in targets:
            await self.send_to_connection(websocket, message)

    async def handle_heartbeat(self, websocket: WebSocket) -> None:
        """
        Record a heartbeat from ``websocket`` and reply with a heartbeat event.

        Args:
            websocket: WebSocket connection object.
        """
        # Only refresh registered sockets; re-inserting an unregistered one
        # would leave a stale entry that disconnect() never removes.
        if websocket in self.heartbeat_times:
            self.heartbeat_times[websocket] = time.time()
        await self.send_to_connection(websocket, {
            "type": "system",
            "event": "heartbeat",
            "time": datetime.now().isoformat()
        })

    async def check_heartbeat_timeout(self) -> None:
        """
        Close and unregister connections whose last heartbeat is older than
        ``heartbeat_timeout`` seconds.
        """
        now = time.time()
        async with self._lock:
            expired = [
                websocket
                for websocket, last_seen in self.heartbeat_times.items()
                if now - last_seen > self.heartbeat_timeout
            ]

        # Close outside the lock: close() is network I/O and disconnect()
        # takes the lock itself.
        for websocket in expired:
            try:
                await websocket.close(code=4003, reason="心跳超时")
            except Exception:
                # The socket may already be closed; cleanup below still runs.
                pass
            await self.disconnect(websocket)

    def get_connection_count(self) -> int:
        """Return the number of currently registered connections."""
        return self.total_connections

    def get_user_count(self) -> int:
        """Return the number of users with at least one connection."""
        return len(self.active_connections)

    def get_subscription_count(self) -> int:
        """Return the total number of (symbol, user) subscription pairs."""
        return sum(len(users) for users in self.symbol_subscribers.values())

    def get_statistics(self) -> dict:
        """Return a snapshot of connection and subscription statistics."""
        return {
            "total_connections": self.total_connections,
            "active_users": self.get_user_count(),
            "total_subscriptions": self.get_subscription_count(),
            "total_messages_sent": self.total_messages_sent,
            "symbol_subscribers": {
                symbol: len(users)
                for symbol, users in self.symbol_subscribers.items()
            }
        }

    def get_user_subscriptions(self, user_id: int) -> List[str]:
        """Return the symbols ``user_id`` is subscribed to (unordered)."""
        return list(self.subscriptions.get(user_id, set()))
# Single shared manager instance for this module; the handler and the
# heartbeat task below both operate on it.
connection_manager: ConnectionManager = ConnectionManager()
# ============== WebSocket route handling ==============
def _extract_symbols(data: dict) -> List[str]:
    """Return the string entries of ``data["symbols"]``.

    Returns an empty list when the key is missing or is not a list/tuple, so
    a bare string is not iterated character by character and unhashable
    items never reach the subscription sets.
    """
    raw = data.get("symbols", [])
    if not isinstance(raw, (list, tuple)):
        return []
    return [symbol for symbol in raw if isinstance(symbol, str)]


async def websocket_handler(websocket: WebSocket, user_id: int):
    """
    Serve one client WebSocket until it disconnects.

    Registers the socket with the module-level ``connection_manager``, then
    reads JSON messages and dispatches on their ``"action"`` field:
    ``subscribe`` / ``unsubscribe`` (with a ``"symbols"`` list),
    ``heartbeat`` and ``query``. Any other action, or a message that is not a
    JSON object, is answered with an ``error`` system event.

    The socket is always unregistered when the loop ends, whether the client
    disconnected, an unexpected error occurred, or the task was cancelled.

    Args:
        websocket: WebSocket connection object.
        user_id: ID of the user owning the connection.
    """
    await connection_manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_json()
            # The payload is untrusted: a non-object message has no action
            # and is reported via the "unknown action" branch below instead
            # of crashing on data.get().
            action = data.get("action") if isinstance(data, dict) else None

            if action == "subscribe":
                symbols = _extract_symbols(data)
                if symbols:
                    await connection_manager.subscribe(websocket, symbols)
            elif action == "unsubscribe":
                symbols = _extract_symbols(data)
                if symbols:
                    await connection_manager.unsubscribe(websocket, symbols)
            elif action == "heartbeat":
                await connection_manager.handle_heartbeat(websocket)
            elif action == "query":
                subscriptions = connection_manager.get_user_subscriptions(user_id)
                await connection_manager.send_to_connection(websocket, {
                    "type": "system",
                    "event": "query_result",
                    "subscriptions": subscriptions,
                    "time": datetime.now().isoformat()
                })
            else:
                await connection_manager.send_to_connection(websocket, {
                    "type": "system",
                    "event": "error",
                    "message": f"未知操作: {action}",
                    "time": datetime.now().isoformat()
                })
    except WebSocketDisconnect:
        # Normal client-side close; cleanup happens in finally.
        pass
    except Exception:
        # Keep the original best-effort behaviour (the error is not re-raised)
        # but make it visible instead of swallowing it silently.
        import logging
        logging.getLogger(__name__).exception(
            "WebSocket handler error (user_id=%s)", user_id
        )
    finally:
        # disconnect() is a no-op for sockets that are already unregistered.
        await connection_manager.disconnect(websocket)
# ============== Heartbeat checking task ==============
async def heartbeat_checker(interval: float = 30):
    """
    Background task: periodically close connections whose heartbeat expired.

    Runs forever; cancel the task to stop it.

    Args:
        interval: seconds to sleep between checks (default 30).
    """
    while True:
        await asyncio.sleep(interval)
        try:
            await connection_manager.check_heartbeat_timeout()
        except Exception:
            # One failed sweep must not kill the task and stop all future
            # sweeps. Cancellation still propagates: asyncio.CancelledError
            # is a BaseException (Python 3.8+), not caught here.
            import logging
            logging.getLogger(__name__).exception("Heartbeat check failed")