"""WebSocket connection management and periodic topic broadcasting."""

import asyncio
import json
import time
from typing import Any, Dict, List, Set, Union

from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger

from models import SubscribeTopic, WSMessage
from adapters import DataAdapterFactory
from config import settings


class ConnectionManager:
    """Track open WebSockets, their topic subscriptions, and push periodic updates.

    Each accepted WebSocket maps to the set of ``SubscribeTopic`` values it
    subscribed to. A background task (see ``start_broadcast_loop``) fetches
    data for every subscribed topic and sends it to that topic's subscribers.
    """

    def __init__(self, broadcast_interval: float = 5.0):
        """Create an empty manager.

        Args:
            broadcast_interval: Seconds to sleep between broadcast cycles
                (defaults to the original hard-coded 5 seconds).
        """
        self.active_connections: Dict[WebSocket, Set[SubscribeTopic]] = {}
        self.broadcast_interval = broadcast_interval
        self._broadcast_task = None

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register it with no subscriptions."""
        await websocket.accept()
        self.active_connections[websocket] = set()
        logger.info(f"WebSocket connected, total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Forget a WebSocket; safe to call more than once for the same socket."""
        if self.active_connections.pop(websocket, None) is not None:
            logger.info(f"WebSocket disconnected, total connections: {len(self.active_connections)}")

    def subscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Add ``topic`` to a registered WebSocket's subscriptions (no-op if unknown)."""
        if websocket in self.active_connections:
            self.active_connections[websocket].add(topic)
            logger.info(f"Subscribed to topic: {topic}")

    def unsubscribe(self, websocket: WebSocket, topic: SubscribeTopic):
        """Remove ``topic`` from a registered WebSocket's subscriptions (no-op if absent)."""
        if websocket in self.active_connections:
            self.active_connections[websocket].discard(topic)
            logger.info(f"Unsubscribed from topic: {topic}")

    async def send_personal_message(self, message: dict, websocket: WebSocket):
        """Send ``message`` as JSON to one WebSocket, logging (not raising) on failure."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message: {e}")

    async def broadcast_to_topic(self, topic: SubscribeTopic, message: dict):
        """Send ``message`` to every WebSocket subscribed to ``topic``.

        Sockets whose send raises are unregistered via ``disconnect``.
        """
        # Build the recipient list before the first await: other coroutines may
        # connect/disconnect/subscribe while we are suspended in send_json, and
        # iterating the live dict across an await can raise
        # "dictionary changed size during iteration".
        subscribers = [
            ws for ws, topics in self.active_connections.items() if topic in topics
        ]
        disconnected = []
        for websocket in subscribers:
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.append(websocket)
        for ws in disconnected:
            self.disconnect(ws)

    async def start_broadcast_loop(self):
        """Start the background broadcast task unless one is already running."""
        # A task that finished (e.g. crashed) is restarted rather than left dead.
        if self._broadcast_task is None or self._broadcast_task.done():
            self._broadcast_task = asyncio.create_task(self._broadcast_data())

    async def stop_broadcast_loop(self):
        """Cancel the background broadcast task and wait for it to finish."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await self._broadcast_task
            except asyncio.CancelledError:
                pass
            self._broadcast_task = None

    async def _broadcast_data(self):
        """Every ``broadcast_interval`` seconds, fetch each subscribed topic once and fan it out."""
        adapter = DataAdapterFactory.get_default_adapter()
        while True:
            try:
                await asyncio.sleep(self.broadcast_interval)
                if not self.active_connections:
                    continue

                # Union of all subscriptions, collected synchronously (no await)
                # so the dict/sets cannot change underneath us.
                wanted_topics: Set[SubscribeTopic] = set()
                for topics in self.active_connections.values():
                    wanted_topics.update(topics)

                for topic in wanted_topics:
                    # Fetch once per topic per cycle, not once per subscriber.
                    try:
                        data = await self._fetch_topic_data(adapter, topic)
                        message = WSMessage(
                            topic=topic,
                            data=data,
                            timestamp=int(time.time() * 1000),
                        )
                    except Exception as e:
                        logger.error(f"Failed to fetch data for topic {topic}: {e}")
                        continue
                    # broadcast_to_topic also drops sockets whose send fails.
                    await self.broadcast_to_topic(topic, message.model_dump())
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Broadcast loop error: {e}")
                await asyncio.sleep(1)

    async def _fetch_topic_data(
        self, adapter, topic: SubscribeTopic
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Fetch the payload for ``topic`` from ``adapter`` as plain data.

        Returns a dict for MARKET_OVERVIEW / SENTIMENT, a list of dicts for
        MOMENTUM / NEWS (5 items requested) / STOCK_PRICE (10 items requested),
        and ``{}`` for any other topic.
        """
        if topic == SubscribeTopic.MARKET_OVERVIEW:
            data = await adapter.fetch_market_overview()
            return data.model_dump()
        elif topic == SubscribeTopic.SENTIMENT:
            data = await adapter.fetch_sentiment()
            return data.model_dump()
        elif topic == SubscribeTopic.MOMENTUM:
            data = await adapter.fetch_momentum_data()
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.NEWS:
            data = await adapter.fetch_hot_news(5)
            return [item.model_dump() for item in data]
        elif topic == SubscribeTopic.STOCK_PRICE:
            data = await adapter.fetch_hot_stocks(10)
            return [item.model_dump() for item in data]
        return {}


manager = ConnectionManager()


async def websocket_handler(websocket: WebSocket):
    """Serve one client: accept it, then process subscribe/unsubscribe requests.

    Each incoming text frame is parsed as JSON of the form
    ``{"action": "subscribe" | "unsubscribe", "topic": "<SubscribeTopic value>"}``.
    Invalid JSON, a non-object payload, or an unknown topic get an ``{"error": ...}``
    reply; frames without a topic or with another action are ignored.
    The socket is always unregistered when the handler exits.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({
                    "error": "Invalid JSON message"
                })
                continue

            # Valid JSON that is not an object (e.g. a list or number) has no
            # .get(); reject it instead of letting it tear down the connection.
            if not isinstance(message, dict):
                await websocket.send_json({
                    "error": "Invalid message format"
                })
                continue

            action = message.get("action")
            topic_str = message.get("topic")
            if not topic_str:
                continue

            try:
                topic = SubscribeTopic(topic_str)
            except ValueError:
                await websocket.send_json({
                    "error": f"Invalid topic: {topic_str}"
                })
                continue

            if action == "subscribe":
                manager.subscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Subscribed to {topic}"
                })
            elif action == "unsubscribe":
                manager.unsubscribe(websocket, topic)
                await websocket.send_json({
                    "success": True,
                    "message": f"Unsubscribed from {topic}"
                })
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Runs on normal disconnect, errors, and task cancellation alike.
        manager.disconnect(websocket)