"""Real-time market quote service.

Tracks WebSocket clients and their per-symbol subscriptions, caches the latest
quote for each symbol in Redis, publishes updates on Redis Pub/Sub, and pushes
them to subscribed WebSocket clients.
"""

import json
import logging
from datetime import datetime, timezone
from typing import ClassVar, Dict, Set, Optional
from collections import defaultdict

import redis.asyncio as redis

from app.config import settings

logger = logging.getLogger(__name__)

# WebSocket connection limits
MAX_CONNECTIONS_PER_USER = 5  # max connections per user
MAX_CONNECTIONS_PER_SYMBOL = 100  # max subscribers per symbol
# Overall connection cap. NOTE(review): not enforced anywhere in this module;
# presumably callers compare it against get_total_connections() -- confirm.
MAX_TOTAL_CONNECTIONS = 100  # total connection limit

# Default lifetime of the cached "quote:<symbol>" Redis key, in seconds (5 min).
QUOTE_TTL_SECONDS = 300


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces exactly what the deprecated ``datetime.utcnow().isoformat()``
    produced (no ``+00:00`` suffix), so the wire format seen by clients is
    unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class RealtimeService:
    """Real-time quote service: connection bookkeeping plus Redis-backed quotes."""

    # WebSocket connection bookkeeping.
    # NOTE(review): these are class attributes, so every RealtimeService
    # instance shares the same state. All code below only mutates them (never
    # rebinds them on ``self``), preserving that sharing for existing callers.
    _active_connections: ClassVar[Dict[str, Set]] = {}  # symbol -> subscribed websockets
    _user_connections: ClassVar[Dict[int, Set]] = defaultdict(set)  # user_id -> websockets
    _websocket_user_map: ClassVar[Dict[int, int]] = {}  # id(websocket) -> user_id
    _anonymous_connections: ClassVar[Set] = set()  # websockets registered without a user

    def __init__(self):
        self.redis: Optional[redis.Redis] = None

    async def connect_redis(self):
        """Create the Redis client used for quote caching and Pub/Sub."""
        self.redis = redis.from_url(settings.REDIS_URL, decode_responses=True)
        logger.info("Redis connected for realtime service")

    async def disconnect_redis(self):
        """Close the Redis client, if any, and forget it.

        ``self.redis`` is cleared before awaiting the close, so quote reads and
        writes issued afterwards take their "no Redis" early-return path
        instead of using a closed client.
        """
        if self.redis is None:
            return
        client, self.redis = self.redis, None
        await client.close()

    def register_connection(self, websocket, user_id: Optional[int] = None):
        """Register a WebSocket for ``user_id``, or as anonymous when it is None.

        No limits are checked here; the per-user and per-symbol limits are
        applied in ``subscribe_symbol``.
        """
        if user_id is None:
            self._anonymous_connections.add(websocket)
        else:
            self._websocket_user_map[id(websocket)] = user_id
            self._user_connections[user_id].add(websocket)
        logger.info("WebSocket registered (user=%s)", user_id)

    def unregister_connection(self, websocket, user_id: int):
        """Unregister a user's WebSocket and drop all of its symbol subscriptions."""
        mapped_user = self._websocket_user_map.pop(id(websocket), None)
        # Clean up both the caller-supplied user and the one recorded at
        # registration (normally identical), so the socket is never left
        # counted against a user's connection limit.
        for owner in {user_id, mapped_user} - {None}:
            self._discard_user_connection(owner, websocket)
        self._remove_from_all_symbols(websocket)
        logger.info("WebSocket unregistered (user=%s)", user_id)

    def unregister_anonymous_connection(self, websocket):
        """Unregister an anonymous WebSocket and drop all of its symbol subscriptions."""
        self._anonymous_connections.discard(websocket)
        self._remove_from_all_symbols(websocket)
        logger.info("Anonymous WebSocket unregistered")

    def _discard_user_connection(self, user_id, websocket):
        """Remove ``websocket`` from ``user_id``'s set, deleting the set once empty."""
        # .get() avoids the defaultdict creating an entry for unknown users.
        connections = self._user_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self._user_connections[user_id]

    def _remove_from_all_symbols(self, websocket):
        """Unsubscribe ``websocket`` everywhere, deleting symbols left with no subscribers."""
        for symbol in list(self._active_connections):
            subscribers = self._active_connections[symbol]
            subscribers.discard(websocket)
            if not subscribers:
                del self._active_connections[symbol]

    def get_total_connections(self) -> int:
        """Return the number of registered user plus anonymous connections."""
        total_user_connections = sum(len(conns) for conns in self._user_connections.values())
        return total_user_connections + len(self._anonymous_connections)

    def get_user_connections(self, user_id: int) -> Set:
        """Return the set of websockets registered for ``user_id`` (empty if none)."""
        return self._user_connections.get(user_id, set())

    async def subscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None) -> bool:
        """Subscribe ``websocket`` to quotes for ``symbol``.

        Args:
            symbol: Symbol code.
            websocket: WebSocket connection object.
            user_id: User ID, used for the per-user connection limit.

        Returns:
            bool: True if the socket is (now) subscribed, False if a limit
            rejected it.
        """
        subscribers = self._active_connections.get(symbol)

        # Re-subscribing is a no-op success, even if the symbol is now full.
        if subscribers is not None and websocket in subscribers:
            return True

        if subscribers is not None and len(subscribers) >= MAX_CONNECTIONS_PER_SYMBOL:
            logger.warning(
                "Symbol %s reached max connections limit (%s)",
                symbol, MAX_CONNECTIONS_PER_SYMBOL,
            )
            return False

        if user_id is not None:
            user_conns = self._user_connections.get(user_id, set())
            # A registered socket is already in user_conns; count only the
            # user's *other* connections so exactly MAX_CONNECTIONS_PER_USER
            # sockets are admitted (counting itself allowed only MAX - 1).
            other_conns = len(user_conns) - (1 if websocket in user_conns else 0)
            if other_conns >= MAX_CONNECTIONS_PER_USER:
                logger.warning(
                    "User %s reached max connections limit (%s)",
                    user_id, MAX_CONNECTIONS_PER_USER,
                )
                return False

        subscribers = self._active_connections.setdefault(symbol, set())
        subscribers.add(websocket)
        logger.info("Client subscribed to %s, total: %d", symbol, len(subscribers))
        return True

    async def unsubscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None):
        """Unsubscribe ``websocket`` from ``symbol``; ``user_id`` is accepted but unused."""
        subscribers = self._active_connections.get(symbol)
        if subscribers is not None:
            subscribers.discard(websocket)
            if not subscribers:
                del self._active_connections[symbol]
        logger.info("Client unsubscribed from %s", symbol)

    async def broadcast_quote(self, symbol: str, quote: dict):
        """Send ``quote`` to every subscriber of ``symbol``.

        Sockets whose send raises are removed from this symbol's subscribers.
        """
        subscribers = self._active_connections.get(symbol)
        if not subscribers:
            return

        message = json.dumps({
            "type": "quote",
            "symbol": symbol,
            "data": quote,
            "timestamp": _utc_now_iso(),
        })

        disconnected = set()
        # Iterate over a snapshot: each await yields to the event loop, where
        # other coroutines may (un)subscribe and mutate the live set.
        for websocket in list(subscribers):
            try:
                await websocket.send_text(message)
            except Exception as e:
                logger.error("Failed to send to websocket: %s", e)
                disconnected.add(websocket)

        if disconnected:
            # Re-read the entry: it may have been deleted while we were sending.
            remaining = self._active_connections.get(symbol)
            if remaining is not None:
                remaining.difference_update(disconnected)
                if not remaining:
                    del self._active_connections[symbol]

    async def get_latest_quote(self, symbol: str) -> Optional[dict]:
        """Return the cached quote for ``symbol`` from Redis, or None.

        None is returned when Redis is not connected, the key is missing or
        empty, or the read/decode fails (the failure is logged).
        """
        if self.redis is None:
            return None
        try:
            data = await self.redis.get(f"quote:{symbol}")
            if data:
                return json.loads(data)
        except Exception as e:
            logger.error("Failed to get quote from Redis: %s", e)
        return None

    async def update_quote(self, symbol: str, quote: dict, *, ttl: int = QUOTE_TTL_SECONDS):
        """Cache, publish and broadcast a new quote for ``symbol``.

        A ``"timestamp"`` field is added to a copy of ``quote`` (the caller's
        dict is not modified). The copy is stored under ``quote:<symbol>`` with
        an expiry of ``ttl`` seconds, published on channel ``quotes:<symbol>``,
        and broadcast to WebSocket subscribers. Does nothing when Redis is not
        connected; failures are logged, not raised.
        """
        if self.redis is None:
            return
        try:
            payload = {**quote, "timestamp": _utc_now_iso()}
            serialized = json.dumps(payload)
            await self.redis.set(f"quote:{symbol}", serialized, ex=ttl)
            # Publish to Redis Pub/Sub
            await self.redis.publish(f"quotes:{symbol}", serialized)
            # Push to WebSocket clients
            await self.broadcast_quote(symbol, payload)
        except Exception as e:
            logger.error("Failed to update quote: %s", e)

    def get_active_subscriptions(self) -> Dict[str, int]:
        """Return a ``symbol -> subscriber count`` snapshot."""
        return {
            symbol: len(connections)
            for symbol, connections in self._active_connections.items()
        }


# Global real-time quote service instance
realtime_service = RealtimeService()