|
|
|
|
|
"""
|
|
|
|
|
|
实时行情服务
|
|
|
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from typing import Dict, Set, Optional
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
|
|
import redis.asyncio as redis
|
|
|
|
|
|
|
|
|
|
|
|
from app.config import settings
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger.
logger = logging.getLogger(__name__)

# WebSocket connection limits.
MAX_CONNECTIONS_PER_USER = 5  # max WebSocket connections per user (checked in RealtimeService.subscribe_symbol)
MAX_CONNECTIONS_PER_SYMBOL = 100  # max subscribers per symbol (checked in RealtimeService.subscribe_symbol)
MAX_TOTAL_CONNECTIONS = 100  # overall connection limit
# NOTE(review): MAX_TOTAL_CONNECTIONS is not referenced by RealtimeService in
# this module; presumably a caller enforces it (e.g. against
# RealtimeService.get_total_connections()) — confirm.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RealtimeService:
    """Real-time market quote service.

    Tracks WebSocket connections (per user, anonymous, and per subscribed
    symbol), caches the latest quote per symbol in Redis, publishes quote
    updates on Redis Pub/Sub and pushes them to subscribed WebSocket clients.
    """

    # WebSocket connection bookkeeping.
    # NOTE(review): these containers are class attributes, so every
    # RealtimeService instance shares the same registries. The module exposes
    # one global instance, so in practice this is process-wide state.
    _active_connections: Dict[str, Set] = {}  # symbol -> set of subscribed websockets
    _user_connections: Dict[int, Set] = defaultdict(set)  # user_id -> set of websockets
    _websocket_user_map: Dict[int, int] = {}  # id(websocket) -> user_id
    _anonymous_connections: Set = set()  # websockets registered without a user_id

    def __init__(self):
        # The Redis client is created by connect_redis(); None means "not connected".
        self.redis: Optional[redis.Redis] = None

    async def connect_redis(self):
        """Create the Redis client from ``settings.REDIS_URL`` (responses decoded to str)."""
        self.redis = redis.from_url(settings.REDIS_URL, decode_responses=True)
        logger.info("Redis connected for realtime service")

    async def disconnect_redis(self):
        """Close the Redis client, if any, and forget it.

        ``self.redis`` is reset to ``None`` so the Redis-backed methods take
        their "not connected" path instead of using a closed client.
        """
        if self.redis:
            try:
                await self.redis.close()
            finally:
                self.redis = None

    def register_connection(self, websocket, user_id: Optional[int] = None):
        """Register a WebSocket connection.

        Args:
            websocket: WebSocket connection object.
            user_id: Owning user ID; ``None`` registers the connection as anonymous.

        This only records the connection. No per-user or total limit is
        checked here; the per-user limit is checked in ``subscribe_symbol``.
        """
        if user_id is not None:
            self._websocket_user_map[id(websocket)] = user_id
            self._user_connections[user_id].add(websocket)
        else:
            self._anonymous_connections.add(websocket)
        logger.info(f"WebSocket registered (user={user_id})")

    def _remove_from_all_symbols(self, websocket):
        """Drop ``websocket`` from every symbol's subscriber set, pruning sets that become empty."""
        for symbol in list(self._active_connections):
            subscribers = self._active_connections[symbol]
            subscribers.discard(websocket)
            if not subscribers:
                del self._active_connections[symbol]

    def unregister_connection(self, websocket, user_id: int):
        """Unregister an authenticated user's WebSocket connection.

        Removes ``websocket`` from the id->user map, from ``user_id``'s
        connection set and from every symbol subscription. Containers that
        become empty are deleted so entries for departed clients don't pile up.
        """
        self._websocket_user_map.pop(id(websocket), None)

        # .get() avoids the defaultdict creating an entry for an unknown user.
        user_conns = self._user_connections.get(user_id)
        if user_conns is not None:
            user_conns.discard(websocket)
            if not user_conns:
                del self._user_connections[user_id]

        self._remove_from_all_symbols(websocket)
        logger.info(f"WebSocket unregistered (user={user_id})")

    def unregister_anonymous_connection(self, websocket):
        """Unregister an anonymous WebSocket connection and drop all its subscriptions."""
        self._anonymous_connections.discard(websocket)
        self._remove_from_all_symbols(websocket)
        logger.info("Anonymous WebSocket unregistered")

    def get_total_connections(self) -> int:
        """Return the number of registered connections (user-owned plus anonymous)."""
        total_user_connections = sum(len(conns) for conns in self._user_connections.values())
        return total_user_connections + len(self._anonymous_connections)

    def get_user_connections(self, user_id: int) -> Set:
        """Return the user's connection set (the live internal set, or a new empty set)."""
        return self._user_connections.get(user_id, set())

    @staticmethod
    def _exceeds_limit(connections: Set, websocket, limit: int) -> bool:
        """Return True if ``connections`` plus ``websocket`` would hold more than ``limit`` items.

        A websocket that is already a member is not counted a second time.
        """
        projected = len(connections) if websocket in connections else len(connections) + 1
        return projected > limit

    async def subscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None) -> bool:
        """Subscribe ``websocket`` to quotes for ``symbol``.

        Args:
            symbol: Instrument symbol.
            websocket: WebSocket connection object.
            user_id: User ID, used for the per-user connection limit (skipped when ``None``).

        Returns:
            bool: True if subscribed. False if the symbol would exceed
            MAX_CONNECTIONS_PER_SYMBOL subscribers, or the user would exceed
            MAX_CONNECTIONS_PER_USER connections.

        The websocket being subscribed is counted once. An already-registered
        connection (or an already-subscribed one) does not count against its
        own limit a second time.
        """
        subscribers = self._active_connections.get(symbol, set())
        if self._exceeds_limit(subscribers, websocket, MAX_CONNECTIONS_PER_SYMBOL):
            logger.warning(f"Symbol {symbol} reached max connections limit ({MAX_CONNECTIONS_PER_SYMBOL})")
            return False

        if user_id is not None:
            # .get() avoids the defaultdict creating an entry for an unknown user.
            user_conns = self._user_connections.get(user_id, set())
            if self._exceeds_limit(user_conns, websocket, MAX_CONNECTIONS_PER_USER):
                logger.warning(f"User {user_id} reached max connections limit ({MAX_CONNECTIONS_PER_USER})")
                return False

        self._active_connections.setdefault(symbol, set()).add(websocket)
        logger.info(f"Client subscribed to {symbol}, total: {len(self._active_connections[symbol])}")
        return True

    async def unsubscribe_symbol(self, symbol: str, websocket, user_id: Optional[int] = None):
        """Unsubscribe ``websocket`` from ``symbol``; drops the symbol entry when it becomes empty.

        ``user_id`` is accepted for interface symmetry but is not used.
        """
        subscribers = self._active_connections.get(symbol)
        if subscribers is not None:
            subscribers.discard(websocket)
            if not subscribers:
                del self._active_connections[symbol]
        logger.info(f"Client unsubscribed from {symbol}")

    async def broadcast_quote(self, symbol: str, quote: dict):
        """Send a quote message to every current subscriber of ``symbol``.

        The message is a JSON object with keys ``type`` ("quote"), ``symbol``,
        ``data`` (the quote) and ``timestamp`` (``datetime.utcnow().isoformat()``).
        Subscribers whose send raises are removed from the symbol's subscriber set.
        """
        subscribers = self._active_connections.get(symbol)
        if subscribers is None:
            return

        message = json.dumps({
            "type": "quote",
            "symbol": symbol,
            "data": quote,
            "timestamp": datetime.utcnow().isoformat()
        })

        disconnected = set()
        # Iterate over a snapshot. Each await yields to the event loop, where
        # other coroutines may (un)subscribe and mutate the live set, which
        # would raise "RuntimeError: Set changed size during iteration".
        for websocket in list(subscribers):
            try:
                await websocket.send_text(message)
            except Exception as e:
                logger.error(f"Failed to send to websocket: {e}")
                disconnected.add(websocket)

        if not disconnected:
            return

        # Clean up failed connections. The symbol entry may have been removed
        # (or recreated) while we were awaiting, so look it up again.
        current = self._active_connections.get(symbol)
        if current is not None:
            current.difference_update(disconnected)
            if not current:
                del self._active_connections[symbol]

    async def get_latest_quote(self, symbol: str) -> Optional[dict]:
        """Return the cached quote for ``symbol`` from Redis key ``quote:{symbol}``.

        Returns None if Redis is not connected, the key is missing/empty, or
        the read/JSON decode fails (the failure is logged).
        """
        if not self.redis:
            return None

        try:
            data = await self.redis.get(f"quote:{symbol}")
            if data:
                return json.loads(data)
        except Exception as e:
            logger.error(f"Failed to get quote from Redis: {e}")

        return None

    async def update_quote(self, symbol: str, quote: dict):
        """Stamp, cache, publish and broadcast a quote for ``symbol``.

        ``quote`` is mutated in place: ``quote["timestamp"]`` is set to
        ``datetime.utcnow().isoformat()``. The JSON-encoded quote is stored at
        ``quote:{symbol}`` with a 300 s TTL, published on channel
        ``quotes:{symbol}``, then pushed to WebSocket subscribers via
        ``broadcast_quote``. Does nothing when Redis is not connected; any
        exception is logged rather than raised.
        """
        if not self.redis:
            return

        try:
            quote["timestamp"] = datetime.utcnow().isoformat()
            # Serialise once; the same payload is cached and published.
            payload = json.dumps(quote)
            await self.redis.set(
                f"quote:{symbol}",
                payload,
                ex=300  # expire after 5 minutes
            )

            # Publish to Redis Pub/Sub.
            await self.redis.publish(f"quotes:{symbol}", payload)

            # Push to WebSocket clients.
            await self.broadcast_quote(symbol, quote)
        except Exception as e:
            logger.error(f"Failed to update quote: {e}")

    def get_active_subscriptions(self) -> Dict[str, int]:
        """Return a mapping of symbol -> number of current subscribers."""
        return {
            symbol: len(connections)
            for symbol, connections in self._active_connections.items()
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level real-time quote service instance.
# Connection registries live in RealtimeService class attributes, so any other
# instance would share the same state as this one.
realtime_service = RealtimeService()
|