You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
6.1 KiB
162 lines
6.1 KiB
|
2 months ago
|
import asyncio
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from typing import Dict, Set
|
||
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from models import SubscribeTopic, WSMessage
|
||
|
|
from adapters import DataAdapterFactory
|
||
|
|
from config import settings
|
||
|
|
|
||
|
|
|
||
|
|
class ConnectionManager:
    """Track WebSocket clients and their topic subscriptions, and push data to them.

    Every accepted socket maps to the set of ``SubscribeTopic`` values it is
    subscribed to. An optional background task (``start_broadcast_loop``)
    periodically fetches fresh data for each subscribed topic and sends it to
    the subscribers.
    """

    def __init__(self, broadcast_interval: float = 5.0):
        """Create an empty manager.

        Args:
            broadcast_interval: Seconds the background loop waits between
                broadcast cycles. Defaults to 5, the previously hard-coded value.
        """
        self.active_connections: Dict[WebSocket, Set[SubscribeTopic]] = {}
        self.broadcast_interval = broadcast_interval
        # asyncio.Task running _broadcast_data while the loop is active, else None.
        self._broadcast_task = None

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register it with no subscriptions."""
        await websocket.accept()
        self.active_connections[websocket] = set()
        logger.info(f"WebSocket connected, total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Forget *websocket* and its subscriptions; a no-op if it is unknown.

        Both the broadcast paths and the request handler may try to remove the
        same dead socket, so this must be safe to call repeatedly. It only logs
        when a connection was actually removed.
        """
        if self.active_connections.pop(websocket, None) is not None:
            logger.info(f"WebSocket disconnected, total connections: {len(self.active_connections)}")

    def subscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Add *topic* to the subscriptions of a registered *websocket*."""
        if websocket in self.active_connections:
            self.active_connections[websocket].add(topic)
            logger.info(f"Subscribed to topic: {topic}")

    def unsubscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Remove *topic* from a registered *websocket*'s subscriptions, if present."""
        if websocket in self.active_connections:
            self.active_connections[websocket].discard(topic)
            logger.info(f"Unsubscribed from topic: {topic}")

    async def send_personal_message(self, message: dict, websocket: WebSocket):
        """Send *message* as JSON to one socket; failures are logged, not raised."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message: {e}")

    async def broadcast_to_topic(self, topic: SubscribeTopic, message: dict):
        """Send *message* to every connection subscribed to *topic*.

        Connections whose send fails are removed from the manager.
        """
        # Take the recipient list up front. Each awaited send yields to the
        # event loop, where other coroutines may connect, disconnect or
        # (un)subscribe. Iterating the live dict would then raise
        # "dictionary changed size during iteration".
        recipients = [ws for ws, topics in self.active_connections.items() if topic in topics]

        disconnected = []
        for websocket in recipients:
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.append(websocket)

        for ws in disconnected:
            self.disconnect(ws)

    async def start_broadcast_loop(self):
        """Start the background broadcast task unless one is already running."""
        # Also restart a task that has already finished (e.g. it crashed), so a
        # stale completed task cannot block the loop from ever running again.
        if self._broadcast_task is None or self._broadcast_task.done():
            self._broadcast_task = asyncio.create_task(self._broadcast_data())

    async def stop_broadcast_loop(self):
        """Cancel the background broadcast task and wait for it to finish."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await self._broadcast_task
            except asyncio.CancelledError:
                pass
            self._broadcast_task = None

    async def _broadcast_data(self):
        """Background loop: every ``broadcast_interval`` seconds, push fresh data.

        Runs until cancelled. An unexpected error in one cycle is logged, and
        the loop continues after a 1-second back-off.
        """
        adapter = DataAdapterFactory.get_default_adapter()

        while True:
            try:
                await asyncio.sleep(self.broadcast_interval)
                if self.active_connections:
                    await self._broadcast_cycle(adapter)
            except asyncio.CancelledError:
                # Propagate cancellation rather than swallowing it, so the task
                # ends in the cancelled state; stop_broadcast_loop expects this.
                raise
            except Exception as e:
                logger.error(f"Broadcast loop error: {e}")
                await asyncio.sleep(1)

    async def _broadcast_cycle(self, adapter):
        """Run one broadcast pass: fetch each subscribed topic once, send it to its subscribers."""
        # Snapshot the connections and copy their topic sets. The awaits below
        # yield to handlers that may mutate these containers mid-iteration.
        subscriptions = {ws: set(topics) for ws, topics in self.active_connections.items()}
        # topic -> serialized message, or None if building it failed. This lets
        # a topic shared by many clients be fetched once per cycle, not once
        # per client.
        payloads = {}
        disconnected = []

        for websocket, topics in subscriptions.items():
            for topic in topics:
                if topic not in payloads:
                    payloads[topic] = await self._build_topic_message(adapter, topic)
                payload = payloads[topic]
                if payload is None:
                    continue
                try:
                    await websocket.send_json(payload)
                except Exception as e:
                    logger.warning(f"Failed to send topic {topic} to client, dropping connection: {e}")
                    disconnected.append(websocket)
                    break  # the socket is dead; skip its remaining topics

        for ws in disconnected:
            self.disconnect(ws)

    async def _build_topic_message(self, adapter, topic: SubscribeTopic):
        """Fetch *topic* data and return it as a dumped ``WSMessage``.

        Returns None, after logging the error, if fetching or building the
        message fails.
        """
        try:
            data = await self._fetch_topic_data(adapter, topic)
            message = WSMessage(
                topic=topic,
                data=data,
                timestamp=int(time.time() * 1000),  # milliseconds since the epoch
            )
            return message.model_dump()
        except Exception as e:
            logger.error(f"Failed to fetch data for topic {topic}: {e}")
            return None

    async def _fetch_topic_data(self, adapter, topic: SubscribeTopic) -> "dict | list":
        """Fetch *topic*'s payload from *adapter* as plain (``model_dump``-ed) data.

        MARKET_OVERVIEW and SENTIMENT return a dict. MOMENTUM, NEWS and
        STOCK_PRICE return a list of dicts. Any other topic yields ``{}``.
        The 5 / 10 arguments are passed straight to the adapter; presumably
        they are result-count limits (confirm against the adapter).
        """
        if topic == SubscribeTopic.MARKET_OVERVIEW:
            data = await adapter.fetch_market_overview()
            return data.model_dump()
        elif topic == SubscribeTopic.SENTIMENT:
            data = await adapter.fetch_sentiment()
            return data.model_dump()
        elif topic == SubscribeTopic.MOMENTUM:
            data = await adapter.fetch_momentum_data()
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.NEWS:
            data = await adapter.fetch_hot_news(5)
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.STOCK_PRICE:
            data = await adapter.fetch_hot_stocks(10)
            return [item.model_dump() for item in data]
        return {}
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level ConnectionManager instance; websocket_handler registers every
# client connection with it and routes subscribe/unsubscribe requests to it.
manager: ConnectionManager = ConnectionManager()
|
||
|
|
|
||
|
|
|
||
|
|
async def websocket_handler(websocket: WebSocket):
    """Serve one client connection: register it, then process control messages.

    The client sends JSON text frames such as
    ``{"action": "subscribe" | "unsubscribe", "topic": "<SubscribeTopic value>"}``.
    Successful requests are acknowledged with ``{"success": True, ...}``.
    Malformed ones get an ``{"error": ...}`` reply and the connection stays
    open. Messages with any other action are ignored. The connection is
    always removed from ``manager`` when this coroutine exits, whatever the
    reason (disconnect, error or cancellation).
    """
    await manager.connect(websocket)

    try:
        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON message"})
                continue

            # Valid JSON that is not an object (e.g. a list or number) has no
            # `.get`. Reject it here instead of letting the AttributeError
            # tear down the connection.
            if not isinstance(message, dict):
                await websocket.send_json({"error": "Message must be a JSON object"})
                continue

            action = message.get("action")
            topic_str = message.get("topic")

            topic = None
            if topic_str:
                try:
                    topic = SubscribeTopic(topic_str)
                except ValueError:
                    await websocket.send_json({"error": f"Invalid topic: {topic_str}"})
                    continue

            if action not in ("subscribe", "unsubscribe"):
                continue

            # Without this guard, `topic` would be unbound on a first message
            # lacking "topic", or would silently reuse the previous message's
            # topic.
            if topic is None:
                await websocket.send_json({"error": f"Missing topic for action: {action}"})
                continue

            if action == "subscribe":
                manager.subscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Subscribed to {topic}",
                })
            else:
                manager.unsubscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Unsubscribed from {topic}",
                })

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Runs on every exit path, including task cancellation, which the
        # except clauses above do not catch.
        manager.disconnect(websocket)
|