You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

162 lines
6.1 KiB

import asyncio
import json
import time
from typing import Dict, Optional, Set, Union

from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger

from models import SubscribeTopic, WSMessage
from adapters import DataAdapterFactory
from config import settings
class ConnectionManager:
    """Track WebSocket connections, their topic subscriptions, and a push loop.

    Every accepted WebSocket maps to the set of ``SubscribeTopic`` values it
    has subscribed to. A background task (see :meth:`start_broadcast_loop`)
    periodically fetches data for every subscribed topic and pushes it to the
    subscribers of that topic.
    """

    def __init__(self):
        # websocket -> topics that socket currently subscribes to
        self.active_connections: Dict[WebSocket, Set[SubscribeTopic]] = {}
        # Handle of the background push task; None when it has never started
        # or has been stopped.
        self._broadcast_task: Optional[asyncio.Task] = None

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and register it with no subscriptions."""
        await websocket.accept()
        self.active_connections[websocket] = set()
        logger.info(f"WebSocket connected, total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Forget *websocket* and its subscriptions; a no-op if it is unknown.

        Safe to call more than once for the same socket (the broadcast path and
        the connection handler may both call it), and only logs when a
        connection was actually removed.
        """
        if websocket in self.active_connections:
            del self.active_connections[websocket]
            logger.info(f"WebSocket disconnected, total connections: {len(self.active_connections)}")

    def subscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Add *topic* to *websocket*'s subscriptions (ignored if not connected)."""
        if websocket in self.active_connections:
            self.active_connections[websocket].add(topic)
            logger.info(f"Subscribed to topic: {topic}")

    def unsubscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Remove *topic* from *websocket*'s subscriptions (ignored if absent)."""
        if websocket in self.active_connections:
            self.active_connections[websocket].discard(topic)
            logger.info(f"Unsubscribed from topic: {topic}")

    async def send_personal_message(self, message: dict, websocket: WebSocket):
        """Best-effort send of *message* to one socket; failures are only logged."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message: {e}")

    async def broadcast_to_topic(self, topic: SubscribeTopic, message: dict):
        """Send *message* to every connection subscribed to *topic*.

        Connections whose send raises are treated as dead and disconnected.
        """
        disconnected = []
        # Iterate over a snapshot: connect()/disconnect() can run while we are
        # suspended in send_json, and mutating a dict during iteration raises
        # RuntimeError.
        for websocket, topics in list(self.active_connections.items()):
            if topic not in topics:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send topic {topic} to a subscriber: {e}")
                disconnected.append(websocket)
        for ws in disconnected:
            self.disconnect(ws)

    async def start_broadcast_loop(self, interval: float = 5.0):
        """Start the background push task unless one is already running.

        Args:
            interval: Seconds to wait between push rounds (default 5).
        """
        # Also restart if a previous task has finished (e.g. crashed).
        if self._broadcast_task is None or self._broadcast_task.done():
            self._broadcast_task = asyncio.create_task(self._broadcast_data(interval))

    async def stop_broadcast_loop(self):
        """Cancel the background push task and wait for it to finish."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await self._broadcast_task
            except asyncio.CancelledError:
                pass
            self._broadcast_task = None

    async def _broadcast_data(self, interval: float = 5.0):
        """Every *interval* seconds, push fresh data for each subscribed topic.

        Each topic is fetched once per round and fanned out to all of its
        subscribers, instead of once per (connection, topic) pair.
        """
        adapter = DataAdapterFactory.get_default_adapter()
        while True:
            try:
                await asyncio.sleep(interval)
                # Snapshot the union of all subscriptions synchronously; the
                # per-socket sets may change while we await below.
                topics = set().union(*self.active_connections.values())
                for topic in topics:
                    await self._push_topic(adapter, topic)
            except asyncio.CancelledError:
                # Propagate so the task is reported as cancelled to
                # stop_broadcast_loop().
                raise
            except Exception as e:
                logger.error(f"Broadcast loop error: {e}")
                await asyncio.sleep(1)

    async def _push_topic(self, adapter, topic: SubscribeTopic):
        """Fetch *topic* once and broadcast it; fetch/build errors are logged."""
        try:
            data = await self._fetch_topic_data(adapter, topic)
            message = WSMessage(
                topic=topic,
                data=data,
                timestamp=int(time.time() * 1000),
            )
        except Exception as e:
            logger.error(f"Failed to fetch data for topic {topic}: {e}")
            return
        await self.broadcast_to_topic(topic, message.model_dump())

    async def _fetch_topic_data(self, adapter, topic: SubscribeTopic) -> Union[dict, list]:
        """Fetch the payload for *topic* from *adapter* as plain Python data.

        Returns a dict for MARKET_OVERVIEW / SENTIMENT, a list of dicts for
        MOMENTUM / NEWS / STOCK_PRICE, and ``{}`` for any other topic.
        """
        if topic == SubscribeTopic.MARKET_OVERVIEW:
            data = await adapter.fetch_market_overview()
            return data.model_dump()
        elif topic == SubscribeTopic.SENTIMENT:
            data = await adapter.fetch_sentiment()
            return data.model_dump()
        elif topic == SubscribeTopic.MOMENTUM:
            data = await adapter.fetch_momentum_data()
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.NEWS:
            data = await adapter.fetch_hot_news(5)
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.STOCK_PRICE:
            data = await adapter.fetch_hot_stocks(10)
            return [item.model_dump() for item in data]
        return {}
# Module-level ConnectionManager instance shared by websocket_handler below;
# all connections in this process register with this single object.
manager: ConnectionManager = ConnectionManager()
async def websocket_handler(websocket: WebSocket):
    """Serve one WebSocket client until it disconnects.

    Clients send JSON text frames shaped like
    ``{"action": "subscribe" | "unsubscribe", "topic": "<SubscribeTopic value>"}``.
    A valid request is acknowledged with ``{"success": True, "message": ...}``;
    a malformed one gets ``{"error": ...}`` and the connection stays open.
    The socket is always removed from ``manager`` when this coroutine exits.
    """
    await manager.connect(websocket)
    try:
        while True:
            raw = await websocket.receive_text()

            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({
                    "error": "Invalid JSON message"
                })
                continue
            # Valid JSON that is not an object (e.g. a list) has no .get().
            if not isinstance(message, dict):
                await websocket.send_json({
                    "error": "Invalid message: expected a JSON object"
                })
                continue

            action = message.get("action")
            if action not in ("subscribe", "unsubscribe"):
                await websocket.send_json({
                    "error": f"Invalid action: {action}"
                })
                continue

            topic_str = message.get("topic")
            # Without this guard, `topic` would be unbound (or left over from
            # a previous message) below.
            if not topic_str:
                await websocket.send_json({
                    "error": "Missing topic"
                })
                continue
            try:
                topic = SubscribeTopic(topic_str)
            except ValueError:
                await websocket.send_json({
                    "error": f"Invalid topic: {topic_str}"
                })
                continue

            if action == "subscribe":
                manager.subscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Subscribed to {topic}"
                })
            else:
                manager.unsubscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Unsubscribed from {topic}"
                })
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Runs on normal disconnect, unexpected errors, and task cancellation.
        manager.disconnect(websocket)