"""
Real-time market quote API routes.

Provides a WebSocket endpoint that pushes quotes for subscribed symbols, plus
REST endpoints for querying the latest quote(s) and subscription statistics.
"""
import json
import logging
from typing import Annotated, Any, List, Optional
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, Query, HTTPException
from sqlalchemy.orm import Session

from app.schemas import (
    RealtimeQuoteItem,
    SubscribeRequest,
    UnsubscribeRequest,
    ResponseData,
)
from app.services.realtime_service import realtime_service
from app.services.auth_service import decode_token
from app.db.init_db import get_sqlite_db

router = APIRouter()
logger = logging.getLogger(__name__)

# WebSocket connection limits.
MAX_CONNECTIONS_PER_USER = 5
MAX_TOTAL_CONNECTIONS = 100

# Error frame sent when a client frame is not a JSON object.
_INVALID_MESSAGE = {"type": "error", "message": "Invalid message format"}


async def _send_json(websocket: WebSocket, payload: dict) -> None:
    """Serialize ``payload`` to JSON and send it as one text frame."""
    await websocket.send_text(json.dumps(payload))


async def _reject(websocket: WebSocket, message: str, code: int, reason: str) -> None:
    """Send an error frame carrying ``message``, then close with ``code``/``reason``."""
    await _send_json(websocket, {"type": "error", "message": message})
    await websocket.close(code=code, reason=reason)


def _parse_message(raw: str) -> Optional[dict]:
    """Decode a client frame; return ``None`` unless it is a JSON object.

    Returning ``None`` (instead of raising) lets the caller answer a malformed
    frame with an error rather than terminating the whole connection.
    """
    try:
        message = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return message if isinstance(message, dict) else None


def _normalize_symbols(raw: Any) -> List[str]:
    """Coerce the client-supplied ``symbols`` field into a list of symbol strings.

    A bare string is treated as a single symbol (iterating it directly would
    subscribe to every character). Any other non-list value yields ``[]``, and
    non-string or empty entries are dropped.
    """
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, list):
        return []
    return [symbol for symbol in raw if isinstance(symbol, str) and symbol]


def _user_id_from_token(token: str) -> Optional[int]:
    """Decode ``token`` and return its ``sub`` claim, or ``None`` if it is falsy.

    Exceptions raised by ``decode_token`` propagate to the caller.
    NOTE(review): JWT ``sub`` claims are often strings; confirm the type
    ``decode_token`` actually returns.
    """
    payload = decode_token(token)
    # Collapse falsy ids to None so "has a user" always means a truthy id,
    # matching the truthiness test used by ``_unregister``.
    return payload.get("sub") or None


def _unregister(websocket: WebSocket, user_id: Optional[int]) -> None:
    """Undo ``realtime_service.register_connection(websocket, user_id)``."""
    if user_id:
        realtime_service.unregister_connection(websocket, user_id)
    else:
        realtime_service.unregister_anonymous_connection(websocket)


def _user_at_limit(user_id: Optional[int]) -> bool:
    """Return True if ``user_id`` already holds ``MAX_CONNECTIONS_PER_USER`` connections.

    Anonymous connections (``user_id`` falsy) have no per-user cap.
    """
    if not user_id:
        return False
    return len(realtime_service.get_user_connections(user_id)) >= MAX_CONNECTIONS_PER_USER


async def _handle_subscribe(websocket: WebSocket, symbols: List[str], subscribed_symbols: set) -> None:
    """Subscribe ``websocket`` to ``symbols``, acknowledge, then send current quotes."""
    for symbol in symbols:
        await realtime_service.subscribe_symbol(symbol, websocket)
        subscribed_symbols.add(symbol)
    await _send_json(websocket, {
        "type": "subscribed",
        "symbols": list(subscribed_symbols),
    })
    # Send the latest quote for each requested symbol right away, when the service has one.
    for symbol in symbols:
        quote = await realtime_service.get_latest_quote(symbol)
        if quote:
            await _send_json(websocket, {"type": "quote", "symbol": symbol, "data": quote})


async def _handle_unsubscribe(websocket: WebSocket, symbols: List[str], subscribed_symbols: set) -> None:
    """Unsubscribe ``websocket`` from ``symbols`` and acknowledge."""
    for symbol in symbols:
        await realtime_service.unsubscribe_symbol(symbol, websocket)
        subscribed_symbols.discard(symbol)
    await _send_json(websocket, {"type": "unsubscribed", "symbols": symbols})


async def _dispatch(websocket: WebSocket, message: dict, subscribed_symbols: set) -> None:
    """Handle one non-auth client command: subscribe, unsubscribe or ping."""
    action = message.get("action")
    symbols = _normalize_symbols(message.get("symbols", []))

    if action == "subscribe":
        await _handle_subscribe(websocket, symbols, subscribed_symbols)
    elif action == "unsubscribe":
        await _handle_unsubscribe(websocket, symbols, subscribed_symbols)
    elif action == "ping":
        # Same naive-UTC ISO string as datetime.utcnow(), without its 3.12 deprecation.
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        await _send_json(websocket, {"type": "pong", "timestamp": timestamp})
    else:
        await _send_json(websocket, {"type": "error", "message": f"Unknown action: {action}"})


async def _reauthenticate(websocket: WebSocket, token: Any, user_id: Optional[int]) -> Optional[int]:
    """Handle an in-session ``auth`` frame and return the connection's resulting user id.

    On success with a different identity, the connection's registration is
    moved from ``user_id`` to the new id, so per-user limits and final cleanup
    stay consistent. A missing token is ignored. An invalid token, or a target
    user already at its connection limit, leaves the identity unchanged.
    """
    if not token:
        return user_id
    try:
        new_user_id = _user_id_from_token(token)
    except Exception as e:  # decode_token's failure types are not visible here
        logger.warning(f"WebSocket re-auth failed: {e}")
        # Do not echo internal exception details to the client.
        await _send_json(websocket, {"type": "error", "message": "认证失败"})
        return user_id

    if new_user_id != user_id:
        if _user_at_limit(new_user_id):
            await _send_json(websocket, {
                "type": "error",
                "message": f"单用户最大连接数限制 ({MAX_CONNECTIONS_PER_USER})",
            })
            return user_id
        _unregister(websocket, user_id)
        realtime_service.register_connection(websocket, new_user_id)

    logger.info(f"WebSocket user {new_user_id} re-authenticated")
    await _send_json(websocket, {"type": "authenticated", "user_id": new_user_id})
    return new_user_id


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket connection for real-time quote push.

    The first frame may authenticate the connection:
        {"action": "auth", "token": "<token>"}
    Otherwise (and at any later point) the client may send:
        subscribe:   {"action": "subscribe", "symbols": ["IF2406", "IC2406"]}
        unsubscribe: {"action": "unsubscribe", "symbols": ["IF2406"]}
        heartbeat:   {"action": "ping"}
        re-auth:     {"action": "auth", "token": "<token>"}

    Close codes: 4001 authentication failed, 4002 server-wide connection cap
    reached, 4003 per-user connection cap reached.
    """
    await websocket.accept()
    subscribed_symbols: set = set()
    user_id: Optional[int] = None
    # Only undo registration in ``finally`` if it actually happened; sockets that
    # were rejected or disconnected during the handshake were never registered.
    registered = False

    try:
        message = _parse_message(await websocket.receive_text())
        is_auth = message is not None and message.get("action") == "auth"

        if is_auth and message.get("token"):
            try:
                user_id = _user_id_from_token(message["token"])
            except Exception as e:  # decode_token's failure types are not visible here
                logger.warning(f"WebSocket auth failed: {e}")
                await _reject(websocket, "认证失败", 4001, "认证失败")
                return
            logger.info(f"WebSocket user {user_id} authenticated")
        if user_id is None:
            # Anonymous connections are subject only to the global cap below.
            logger.info("WebSocket connected without authentication")

        if realtime_service.get_total_connections() >= MAX_TOTAL_CONNECTIONS:
            await _reject(websocket, "服务器连接数已达上限", 4002, "连接数超限")
            return
        if _user_at_limit(user_id):
            await _reject(
                websocket,
                f"单用户最大连接数限制 ({MAX_CONNECTIONS_PER_USER})",
                4003,
                "连接数超限",
            )
            return

        realtime_service.register_connection(websocket, user_id)
        registered = True

        # A first frame that is not "auth" is an ordinary command: handle it like any other.
        if message is None:
            await _send_json(websocket, _INVALID_MESSAGE)
        elif not is_auth:
            await _dispatch(websocket, message, subscribed_symbols)

        # Main message loop.
        while True:
            message = _parse_message(await websocket.receive_text())
            if message is None:
                await _send_json(websocket, _INVALID_MESSAGE)
            elif message.get("action") == "auth":
                user_id = await _reauthenticate(websocket, message.get("token"), user_id)
            else:
                await _dispatch(websocket, message, subscribed_symbols)

    except WebSocketDisconnect:
        logger.info(f"WebSocket client disconnected (user={user_id})")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        # Best-effort cleanup: one failing unsubscribe must not skip the rest
        # or the unregistration below.
        for symbol in subscribed_symbols:
            try:
                await realtime_service.unsubscribe_symbol(symbol, websocket)
            except Exception:
                logger.exception(f"Failed to unsubscribe {symbol} during WebSocket cleanup")
        if registered:
            _unregister(websocket, user_id)


@router.get("/quote", response_model=ResponseData)
async def get_latest_quote(
    symbol: Annotated[str, Query(description="品种代码")]
):
    """Return the latest quote for one symbol (body ``code=404`` when none exists)."""
    quote = await realtime_service.get_latest_quote(symbol)
    if not quote:
        return ResponseData(
            code=404,
            message="Quote not found",
            data=None,
        )
    return ResponseData(
        code=0,
        message="success",
        data=quote,
    )


@router.get("/quotes", response_model=ResponseData)
async def get_multiple_quotes(
    symbols: Annotated[str, Query(description="品种代码列表,逗号分隔")]
):
    """Return the latest quotes for a comma-separated list of symbols.

    Blank entries (e.g. from ``"A,,B"`` or a trailing comma) are skipped and
    duplicate symbols are fetched once. Symbols without a quote are omitted
    from ``quotes``, and ``count`` is the number of quotes returned.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    symbol_list = list(dict.fromkeys(s.strip() for s in symbols.split(",") if s.strip()))

    quotes = {}
    for symbol in symbol_list:
        quote = await realtime_service.get_latest_quote(symbol)
        if quote:
            quotes[symbol] = quote

    return ResponseData(
        code=0,
        message="success",
        data={"quotes": quotes, "count": len(quotes)},
    )


@router.get("/subscriptions", response_model=ResponseData)
async def get_active_subscriptions():
    """Return active-subscription statistics from the realtime service."""
    stats = realtime_service.get_active_subscriptions()
    return ResponseData(
        code=0,
        message="success",
        data={
            "subscriptions": stats,
            "total_symbols": len(stats),
            # NOTE(review): this sums the values of ``stats``. If those are per-symbol
            # subscriber counts, a connection subscribed to N symbols is counted N
            # times; realtime_service.get_total_connections() may be the intended
            # figure. Confirm before relying on it.
            "total_connections": sum(stats.values()),
        },
    )