You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
8.7 KiB
257 lines
8.7 KiB
|
2 months ago
|
"""
|
||
|
|
Real-time market quote API routes (WebSocket streaming and REST quote lookups).
|
||
|
|
"""
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from typing import Annotated, List, Optional
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, Query, HTTPException
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.schemas import (
|
||
|
|
RealtimeQuoteItem,
|
||
|
|
SubscribeRequest,
|
||
|
|
UnsubscribeRequest,
|
||
|
|
ResponseData
|
||
|
|
)
|
||
|
|
from app.services.realtime_service import realtime_service
|
||
|
|
from app.services.auth_service import decode_token
|
||
|
|
from app.db.init_db import get_sqlite_db
|
||
|
|
|
||
|
|
# Router on which this module's WebSocket and REST endpoints are registered.
router = APIRouter()
|
||
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# WebSocket connection limit configuration
|
||
|
|
# Maximum simultaneous WebSocket connections for one authenticated user;
# further connections from that user are refused with close code 4003.
MAX_CONNECTIONS_PER_USER = 5
|
||
|
|
# Server-wide cap on WebSocket connections; once reached, new connections
# are refused with close code 4002.
MAX_TOTAL_CONNECTIONS = 100
|
||
|
|
|
||
|
|
|
||
|
|
async def _send_error_and_close(websocket: WebSocket, message: str, code: int, reason: str) -> None:
    """Send an ``error`` frame to the client, then close the socket."""
    await websocket.send_text(json.dumps({
        "type": "error",
        "message": message
    }))
    await websocket.close(code=code, reason=reason)


async def _subscribe_and_push(websocket: WebSocket, symbols, subscribed_symbols: set) -> None:
    """Subscribe the socket to ``symbols`` and push an initial snapshot.

    Each symbol is subscribed through ``realtime_service`` and recorded in
    ``subscribed_symbols`` (mutated in place). The client then receives one
    ``subscribed`` frame listing every symbol it is subscribed to, followed
    by a ``quote`` frame for each requested symbol that has a latest quote.
    """
    for symbol in symbols:
        await realtime_service.subscribe_symbol(symbol, websocket)
        subscribed_symbols.add(symbol)

    await websocket.send_text(json.dumps({
        "type": "subscribed",
        "symbols": list(subscribed_symbols)
    }))

    for symbol in symbols:
        quote = await realtime_service.get_latest_quote(symbol)
        if quote:
            await websocket.send_text(json.dumps({
                "type": "quote",
                "symbol": symbol,
                "data": quote
            }))


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint that pushes real-time quotes.

    The first frame may authenticate the connection:
    {"action": "auth", "token": "<token>"}

    Subscribe (also accepted as the first frame, for anonymous use):
    {"action": "subscribe", "symbols": ["IF2406", "IC2406"]}

    Unsubscribe:
    {"action": "unsubscribe", "symbols": ["IF2406"]}

    Heartbeat:
    {"action": "ping"}

    The connection is refused with close code 4001 when the first-frame token
    cannot be decoded, 4002 when the server-wide connection cap is reached and
    4003 when the per-user cap is reached. A frame that is not a JSON object
    raises inside the handler, is logged by the generic ``except`` and ends
    the connection.
    """
    await websocket.accept()
    subscribed_symbols = set()
    user_id: Optional[int] = None
    token: Optional[str] = None
    # Cleanup must undo exactly what was registered: a later re-auth can change
    # ``user_id``, and rejected connections are never registered at all.
    registered = False
    registered_user_id: Optional[int] = None

    try:
        # The first frame is either an auth message or a regular request.
        data = await websocket.receive_text()
        message = json.loads(data)
        first_action = message.get("action")

        if first_action == "auth":
            token = message.get("token")
            if token:
                try:
                    payload = decode_token(token)
                    user_id = payload.get("sub")
                    logger.info(f"WebSocket user {user_id} authenticated")
                except Exception as e:
                    logger.warning(f"WebSocket auth failed: {e}")
                    await _send_error_and_close(websocket, "认证失败", 4001, "认证失败")
                    return

        if not user_id:
            # No identity was established, so the connection proceeds anonymously.
            logger.info("WebSocket connected without authentication")

        # Enforce the server-wide connection cap.
        total_connections = realtime_service.get_total_connections()
        if total_connections >= MAX_TOTAL_CONNECTIONS:
            await _send_error_and_close(websocket, "服务器连接数已达上限", 4002, "连接数超限")
            return

        # Enforce the per-user connection cap (authenticated users only).
        if user_id:
            user_connections = realtime_service.get_user_connections(user_id)
            if len(user_connections) >= MAX_CONNECTIONS_PER_USER:
                await _send_error_and_close(
                    websocket,
                    f"单用户最大连接数限制 ({MAX_CONNECTIONS_PER_USER})",
                    4003,
                    "连接数超限",
                )
                return

        # Register the connection and remember the identity used to do so.
        realtime_service.register_connection(websocket, user_id)
        registered = True
        registered_user_id = user_id

        # A non-auth first frame is only acted on when it is a subscribe request.
        if first_action == "subscribe":
            await _subscribe_and_push(websocket, message.get("symbols", []), subscribed_symbols)

        # Main message loop.
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            action = message.get("action")
            symbols = message.get("symbols", [])

            if action == "subscribe":
                await _subscribe_and_push(websocket, symbols, subscribed_symbols)

            elif action == "unsubscribe":
                for symbol in symbols:
                    await realtime_service.unsubscribe_symbol(symbol, websocket)
                    subscribed_symbols.discard(symbol)

                await websocket.send_text(json.dumps({
                    "type": "unsubscribed",
                    "symbols": symbols
                }))

            elif action == "ping":
                await websocket.send_text(json.dumps({
                    "type": "pong",
                    "timestamp": datetime.utcnow().isoformat()
                }))

            elif action == "auth":
                # Re-authentication: updates the identity used for logging and
                # the reply, but not the identity the connection is registered under.
                token = message.get("token")
                if token:
                    try:
                        payload = decode_token(token)
                        user_id = payload.get("sub")
                        logger.info(f"WebSocket user {user_id} re-authenticated")
                        await websocket.send_text(json.dumps({
                            "type": "authenticated",
                            "user_id": user_id
                        }))
                    except Exception as e:
                        await websocket.send_text(json.dumps({
                            "type": "error",
                            "message": f"认证失败:{str(e)}"
                        }))

            else:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": f"Unknown action: {action}"
                }))

    except WebSocketDisconnect:
        logger.info(f"WebSocket client disconnected (user={user_id})")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Drop subscriptions first; the inner ``finally`` guarantees the
        # connection slot is released even if an unsubscribe call fails.
        try:
            for symbol in subscribed_symbols:
                await realtime_service.unsubscribe_symbol(symbol, websocket)
        finally:
            if registered:
                if registered_user_id:
                    realtime_service.unregister_connection(websocket, registered_user_id)
                else:
                    realtime_service.unregister_anonymous_connection(websocket)
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/quote", response_model=ResponseData)
async def get_latest_quote(
    symbol: Annotated[str, Query(description="品种代码")]
):
    """Return the latest quote for a single symbol.

    The envelope carries ``code=0`` and the quote when one is available;
    otherwise it carries ``code=404`` with ``data=None``.
    """
    latest = await realtime_service.get_latest_quote(symbol)

    if latest:
        return ResponseData(code=0, message="success", data=latest)

    # No quote available for this symbol: report it inside the envelope.
    return ResponseData(code=404, message="Quote not found", data=None)
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/quotes", response_model=ResponseData)
async def get_multiple_quotes(
    symbols: Annotated[str, Query(description="品种代码列表,逗号分隔")]
):
    """Return the latest quotes for a comma-separated list of symbols.

    Whitespace around each entry is stripped. Blank entries (for example
    from ``"a,,b"`` or a trailing comma) are ignored, and repeated symbols
    are looked up only once. Symbols without an available quote are omitted
    from the result, so ``count`` is the number of quotes actually returned.
    """
    stripped = (part.strip() for part in symbols.split(","))
    # dict.fromkeys de-duplicates while preserving the caller's order.
    symbol_list = list(dict.fromkeys(s for s in stripped if s))

    quotes = {}
    for symbol in symbol_list:
        quote = await realtime_service.get_latest_quote(symbol)
        if quote:
            quotes[symbol] = quote

    return ResponseData(
        code=0,
        message="success",
        data={"quotes": quotes, "count": len(quotes)}
    )
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/subscriptions", response_model=ResponseData)
async def get_active_subscriptions():
    """Return statistics about the currently active subscriptions.

    The payload holds the mapping returned by the realtime service, the
    number of entries in it, and the sum of its values.
    """
    subscription_stats = realtime_service.get_active_subscriptions()

    summary = {
        "subscriptions": subscription_stats,
        "total_symbols": len(subscription_stats),
        "total_connections": sum(subscription_stats.values()),
    }
    return ResponseData(code=0, message="success", data=summary)
|