You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
6.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
实时行情服务
"""
import json
import logging
from datetime import datetime
from typing import Dict, Set, Optional
from collections import defaultdict
import redis.asyncio as redis
from app.config import settings
logger = logging.getLogger(name=__name__)

# --- WebSocket connection limits ---------------------------------------------
# Maximum number of connections a single user may hold.
MAX_CONNECTIONS_PER_USER: int = 5
# Maximum number of WebSocket subscribers for a single symbol.
MAX_CONNECTIONS_PER_SYMBOL: int = 100
# Cap on all connections combined. NOTE(review): nothing in this module reads
# it; presumably the caller compares it with get_total_connections() — confirm.
MAX_TOTAL_CONNECTIONS: int = 100
class RealtimeService:
    """Real-time quote service.

    Tracks WebSocket connections (per user, anonymous, and per subscribed
    symbol), caches the latest quote for each symbol in Redis under
    ``quote:<symbol>``, publishes updates on the Redis channel
    ``quotes:<symbol>`` and pushes them to subscribed WebSocket clients.
    """

    def __init__(self):
        self.redis: Optional[redis.Redis] = None
        # Connection bookkeeping lives on the instance. As class attributes
        # these mutable containers were silently shared by every instance.
        # symbol -> set of websockets subscribed to that symbol
        self._active_connections: Dict[str, Set] = {}
        # user_id -> set of that user's registered websockets
        self._user_connections: Dict[int, Set] = defaultdict(set)
        # id(websocket) -> user_id the websocket was registered with
        self._websocket_user_map: Dict[int, int] = {}
        # websockets registered without a user_id
        self._anonymous_connections: Set = set()

    async def connect_redis(self):
        """Create the Redis client from ``settings.REDIS_URL`` (responses decoded to str)."""
        self.redis = redis.from_url(settings.REDIS_URL, decode_responses=True)
        logger.info("Redis connected for realtime service")

    async def disconnect_redis(self):
        """Close the Redis client, if one exists, and forget it.

        ``self.redis`` is reset to ``None`` so that ``get_latest_quote`` and
        ``update_quote`` take their "no Redis" path instead of using a
        closed client.
        """
        client, self.redis = self.redis, None
        if client is not None:
            await client.close()

    def register_connection(self, websocket, user_id: Optional[int] = None):
        """Register a WebSocket connection.

        Args:
            websocket: WebSocket connection object.
            user_id: Owning user's ID, or ``None`` for an anonymous connection.
        """
        if user_id is None:
            self._anonymous_connections.add(websocket)
        else:
            self._websocket_user_map[id(websocket)] = user_id
            self._user_connections[user_id].add(websocket)
        logger.info("WebSocket registered (user=%s)", user_id)

    def unregister_connection(self, websocket, user_id: int):
        """Unregister a user's WebSocket connection.

        The websocket is removed from its user's connection set and from
        every symbol subscription. The user recorded by
        ``register_connection`` takes precedence over *user_id*, so a
        mismatched argument cannot leave the websocket counted against its
        real owner.

        Args:
            websocket: WebSocket connection object.
            user_id: ID of the user the connection belongs to.
        """
        owner = self._websocket_user_map.pop(id(websocket), user_id)
        user_conns = self._user_connections.get(owner)
        if user_conns is not None:
            user_conns.discard(websocket)
            if not user_conns:
                # Drop empty sets so the mapping does not keep an entry for
                # every user that has ever connected.
                del self._user_connections[owner]
        self._discard_from_all_symbols(websocket)
        logger.info("WebSocket unregistered (user=%s)", user_id)

    def unregister_anonymous_connection(self, websocket):
        """Unregister an anonymous WebSocket connection and drop its subscriptions."""
        self._anonymous_connections.discard(websocket)
        self._discard_from_all_symbols(websocket)
        logger.info("Anonymous WebSocket unregistered")

    def _discard_from_all_symbols(self, websocket):
        """Remove *websocket* from every symbol subscription, pruning symbols left empty."""
        for symbol in list(self._active_connections):
            subscribers = self._active_connections[symbol]
            subscribers.discard(websocket)
            if not subscribers:
                del self._active_connections[symbol]

    def get_total_connections(self) -> int:
        """Return the number of registered connections (user-owned plus anonymous)."""
        user_total = sum(len(conns) for conns in self._user_connections.values())
        return user_total + len(self._anonymous_connections)

    def get_user_connections(self, user_id: int) -> Set:
        """Return the set of websockets registered for *user_id* (empty set if none)."""
        return self._user_connections.get(user_id, set())

    async def subscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None) -> bool:
        """Subscribe *websocket* to quote updates for *symbol*.

        Args:
            symbol: Symbol (instrument) code.
            websocket: WebSocket connection object.
            user_id: User ID used for the per-user limit; ``None`` skips it.

        Returns:
            bool: ``True`` if the websocket is subscribed (including when it
            already was), ``False`` if a connection limit rejected it.
        """
        subscribers = self._active_connections.get(symbol)
        if subscribers is not None:
            if websocket in subscribers:
                # Already subscribed: re-checking the limits could wrongly
                # reject a connection that is already counted.
                return True
            if len(subscribers) >= MAX_CONNECTIONS_PER_SYMBOL:
                logger.warning(
                    "Symbol %s reached max connections limit (%s)",
                    symbol, MAX_CONNECTIONS_PER_SYMBOL,
                )
                return False
        if user_id is not None:
            # .get() avoids creating empty entries in the defaultdict.
            user_conns = self._user_connections.get(user_id, set())
            # Count only the user's *other* connections. A websocket registered
            # via register_connection() is already in user_conns. Counting it
            # too meant a user holding exactly MAX_CONNECTIONS_PER_USER
            # connections could not subscribe on any of them.
            others = len(user_conns) - (websocket in user_conns)
            if others >= MAX_CONNECTIONS_PER_USER:
                logger.warning(
                    "User %s reached max connections limit (%s)",
                    user_id, MAX_CONNECTIONS_PER_USER,
                )
                return False
        self._active_connections.setdefault(symbol, set()).add(websocket)
        logger.info(
            "Client subscribed to %s, total: %s",
            symbol, len(self._active_connections[symbol]),
        )
        return True

    async def unsubscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None):
        """Unsubscribe *websocket* from *symbol*, removing the symbol once it has no subscribers.

        *user_id* is currently unused.
        """
        subscribers = self._active_connections.get(symbol)
        if subscribers is None:
            return
        subscribers.discard(websocket)
        if not subscribers:
            del self._active_connections[symbol]
        logger.info("Client unsubscribed from %s", symbol)

    async def broadcast_quote(self, symbol: str, quote: dict):
        """Send *quote* as a ``"quote"`` JSON message to every subscriber of *symbol*.

        Websockets whose ``send_text`` raises are dropped from the symbol's
        subscriber set.
        """
        subscribers = self._active_connections.get(symbol)
        if not subscribers:
            return
        message = json.dumps({
            "type": "quote",
            "symbol": symbol,
            "data": quote,
            "timestamp": datetime.utcnow().isoformat(),
        })
        disconnected = set()
        # Iterate over a snapshot: each send_text() awaits, and another task
        # may (un)subscribe meanwhile, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        for websocket in list(subscribers):
            try:
                await websocket.send_text(message)
            except Exception as e:
                logger.error("Failed to send to websocket: %s", e)
                disconnected.add(websocket)
        if not disconnected:
            return
        # Re-fetch: the symbol may have been unsubscribed/pruned while sending.
        current = self._active_connections.get(symbol)
        if current is not None:
            current.difference_update(disconnected)
            if not current:
                del self._active_connections[symbol]

    async def get_latest_quote(self, symbol: str) -> Optional[dict]:
        """Return the cached quote for *symbol* from Redis.

        Returns ``None`` when Redis is not connected, the key is missing, or
        the read/decode fails (the failure is logged).
        """
        if self.redis is None:
            return None
        try:
            data = await self.redis.get(f"quote:{symbol}")
            if data:
                return json.loads(data)
        except Exception as e:
            logger.error("Failed to get quote from Redis: %s", e)
        return None

    async def update_quote(self, symbol: str, quote: dict):
        """Store, publish and broadcast a new quote for *symbol*.

        A copy of *quote* stamped with the current UTC time under
        ``"timestamp"`` is cached in Redis as ``quote:<symbol>`` (300 s TTL),
        published on channel ``quotes:<symbol>`` and pushed to WebSocket
        subscribers. The caller's dict is not modified. Does nothing when
        Redis is not connected; errors are logged, not raised.
        """
        if self.redis is None:
            return
        try:
            # Copy instead of mutating the caller's dict in place.
            stamped = {**quote, "timestamp": datetime.utcnow().isoformat()}
            payload = json.dumps(stamped)
            await self.redis.set(f"quote:{symbol}", payload, ex=300)  # 5-minute expiry
            await self.redis.publish(f"quotes:{symbol}", payload)
            await self.broadcast_quote(symbol, stamped)
        except Exception as e:
            logger.error("Failed to update quote: %s", e)

    def get_active_subscriptions(self) -> Dict[str, int]:
        """Return a mapping of symbol -> number of subscribed websockets."""
        return {
            symbol: len(connections)
            for symbol, connections in self._active_connections.items()
        }
# Module-wide shared instance of the real-time quote service.
realtime_service: RealtimeService = RealtimeService()