|
|
|
|
|
"""
|
|
|
|
|
|
实时数据服务
|
|
|
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from typing import Dict, Set, List, Callable, Optional
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from fastapi import WebSocket
|
|
|
|
|
|
|
|
|
|
|
|
from app.models.realtime import RealtimeSnapshot
|
|
|
|
|
|
from app.services.base_data_service import BaseDataService
|
|
|
|
|
|
from app.config import settings
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RealtimeManager:
    """Real-time data manager (singleton).

    Keeps a mapping from security code to the WebSocket connections
    subscribed to it, and pushes incoming SDK data to those connections.
    """

    _instance = None

    def __new__(cls):
        # Every ``RealtimeManager()`` call returns the same object.
        # ``_initialized`` lets ``__init__`` (which Python runs on every call)
        # skip re-initialising the shared state.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # code -> WebSocket connections currently subscribed to that code
        self.subscribers: Dict[str, Set[WebSocket]] = {}
        # Not used inside this class; kept for external callers.
        self.code_callbacks: Dict[str, List[Callable]] = {}
        self._adapter = None
        # Event loop serving the WebSocket connections. Captured in
        # ``subscribe`` so SDK callbacks that arrive on another thread can
        # hand the sends back to it.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to in-flight send tasks/futures. asyncio keeps
        # only weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending_sends: set = set()
        self._initialized = True
        self._lock = asyncio.Lock()

    async def subscribe(self, websocket: WebSocket, codes: List[str]):
        """
        Accept a WebSocket connection and subscribe it to real-time data.

        Args:
            websocket: WebSocket connection (``accept()`` is called here).
            codes: codes to subscribe to.
        """
        await websocket.accept()
        # Remember which loop owns the connections, for off-loop SDK callbacks.
        self._loop = asyncio.get_running_loop()

        async with self._lock:
            for code in codes:
                if code not in self.subscribers:
                    # Start the SDK subscription *before* registering the
                    # code. If the start fails, no empty entry is left behind
                    # that would suppress a retry on the next subscribe.
                    await self._start_sdk_subscription(code)
                    self.subscribers[code] = set()
                self.subscribers[code].add(websocket)

        logger.info(f"WebSocket {id(websocket)} 订阅了: {codes}")

    async def unsubscribe(self, websocket: WebSocket, codes: Optional[List[str]] = None):
        """
        Remove a WebSocket connection's subscriptions.

        Args:
            websocket: WebSocket connection.
            codes: codes to unsubscribe from. ``None`` means every code; an
                explicit empty list unsubscribes from nothing.
        """
        async with self._lock:
            # Copy the keys when unsubscribing from everything, because
            # entries are deleted from ``self.subscribers`` inside the loop.
            codes_to_remove = list(self.subscribers) if codes is None else codes

            for code in codes_to_remove:
                watchers = self.subscribers.get(code)
                if watchers is None:
                    continue
                watchers.discard(websocket)

                # Last subscriber gone: drop the entry and stop the SDK feed.
                if not watchers:
                    del self.subscribers[code]
                    await self._stop_sdk_subscription(code)

        logger.info(f"WebSocket {id(websocket)} 取消订阅")

    async def _start_sdk_subscription(self, code: str):
        """Start the SDK subscription for ``code`` (currently only logs)."""
        # TODO: implement the actual SDK subscription. The SDK's real-time
        # subscription uses synchronous callbacks and needs to run on a
        # background thread; those callbacks land in ``on_sdk_data``.
        logger.info(f"开始SDK订阅: {code}")

    async def _stop_sdk_subscription(self, code: str):
        """Stop the SDK subscription for ``code`` (currently only logs)."""
        logger.info(f"停止SDK订阅: {code}")

    def on_sdk_data(self, code: str, data: dict):
        """
        SDK data callback: push ``data`` to every subscriber of ``code``.

        Safe to call from the event-loop thread or from another thread. When
        called off the loop that ``subscribe`` ran on, the sends are handed
        to that loop via ``asyncio.run_coroutine_threadsafe``.

        Args:
            code: security code.
            data: data dictionary, embedded as-is in the pushed message.
        """
        # TODO: persist via self._save_snapshot(db, code, data) once a
        # Session is available here.

        watchers = self.subscribers.get(code)
        if not watchers:
            return

        message = {
            "type": "snapshot",
            "code": code,
            "data": data,
            "timestamp": datetime.utcnow().isoformat()
        }

        # Iterate a snapshot: subscribe/unsubscribe may mutate the set.
        targets = list(watchers)

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread (e.g. an SDK background thread).
            running_loop = None

        target_loop = self._loop if self._loop is not None else running_loop
        if target_loop is None or target_loop.is_closed():
            logger.warning(f"无可用事件循环,丢弃 {code} 的实时推送")
            return

        for ws in targets:
            coro = self._send_to_ws(ws, message)
            if running_loop is target_loop:
                pending = target_loop.create_task(coro)
            else:
                pending = asyncio.run_coroutine_threadsafe(coro, target_loop)
            # Hold a reference until the send finishes.
            self._pending_sends.add(pending)
            pending.add_done_callback(self._pending_sends.discard)

    async def _send_to_ws(self, websocket: WebSocket, message: dict):
        """Send ``message`` as JSON; on failure, drop the connection's subscriptions."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"发送WebSocket消息失败: {str(e)}")
            # Remove the broken connection from every code it subscribed to.
            await self.unsubscribe(websocket)

    def _save_snapshot(self, db: Session, code: str, data: dict):
        """
        Best-effort save of a snapshot row; errors are logged, not raised.

        Args:
            db: SQLAlchemy session used to add and commit the row.
            code: security code.
            data: snapshot fields. ``trade_time`` may be an ISO-8601 string,
                a ``datetime``, or missing/``None`` (then the current UTC
                time is used).
        """
        try:
            expires_at = datetime.utcnow() + timedelta(days=settings.CACHE_AUTO_CLEANUP_DAYS)

            trade_time = data.get("trade_time")
            if trade_time is None:
                trade_time = datetime.utcnow()
            elif not isinstance(trade_time, datetime):
                trade_time = datetime.fromisoformat(trade_time)

            snapshot = RealtimeSnapshot(
                code=code,
                security_type=data.get("security_type", "stock"),
                trade_time=trade_time,
                pre_close=data.get("pre_close"),
                last=data.get("last"),
                open=data.get("open"),
                high=data.get("high"),
                low=data.get("low"),
                close=data.get("close"),
                volume=data.get("volume"),
                amount=data.get("amount"),
                expires_at=expires_at
            )

            db.add(snapshot)
            db.commit()
        except Exception as e:
            logger.error(f"保存快照失败: {str(e)}")
            # Without a rollback, a failed flush/commit leaves the session
            # unusable for subsequent operations.
            db.rollback()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide real-time data manager instance. RealtimeManager.__new__
# returns this same object on every later RealtimeManager() call.
realtime_manager = RealtimeManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RealtimeService:
    """Real-time data service.

    Reads the latest stored snapshots from the database and delegates
    WebSocket (un)subscription to the shared ``realtime_manager``.
    """

    def __init__(self, db: Session):
        self.db = db
        self.base_service = BaseDataService(db)
        self.manager = realtime_manager

    @staticmethod
    def _as_float(value) -> Optional[float]:
        """Convert a column value to ``float``, keeping ``None`` (but not 0) as ``None``."""
        return None if value is None else float(value)

    @staticmethod
    def _as_int(value) -> Optional[int]:
        """Convert a column value to ``int``, keeping ``None`` (but not 0) as ``None``."""
        return None if value is None else int(value)

    def get_latest_snapshot(self, codes: List[str]) -> Dict[str, Optional[dict]]:
        """
        Get the most recent stored snapshot for each code.

        Args:
            codes: code list.

        Returns:
            Mapping of each requested code to its latest snapshot (ordered by
            ``trade_time`` descending) as a plain dict, or ``None`` when no
            snapshot exists. Numeric fields that are zero are reported as 0,
            not ``None``.
        """
        result: Dict[str, Optional[dict]] = {}

        for code in codes:
            snapshot = (
                self.db.query(RealtimeSnapshot)
                .filter(RealtimeSnapshot.code == code)
                .order_by(RealtimeSnapshot.trade_time.desc())
                .first()
            )

            if snapshot is None:
                result[code] = None
                continue

            result[code] = {
                "code": snapshot.code,
                "trade_time": (
                    snapshot.trade_time.isoformat()
                    if snapshot.trade_time is not None
                    else None
                ),
                "pre_close": self._as_float(snapshot.pre_close),
                "last": self._as_float(snapshot.last),
                "open": self._as_float(snapshot.open),
                "high": self._as_float(snapshot.high),
                "low": self._as_float(snapshot.low),
                "close": self._as_float(snapshot.close),
                "volume": self._as_int(snapshot.volume),
                "amount": self._as_float(snapshot.amount)
            }

        return result

    async def subscribe_websocket(self, websocket: WebSocket, codes: List[str]):
        """Subscribe a WebSocket connection to ``codes`` via the shared manager."""
        await self.manager.subscribe(websocket, codes)

    async def unsubscribe_websocket(self, websocket: WebSocket, codes: Optional[List[str]] = None):
        """Unsubscribe a WebSocket connection via the shared manager (``None`` = all codes)."""
        await self.manager.unsubscribe(websocket, codes)
|