# backend/app/services/push_service.py
"""
Push service.

Receives messages from Redis Pub/Sub and pushes them to WebSocket clients.
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, List, Optional

import redis.asyncio as redis

from app.websocket.connection_manager import connection_manager
from app.config import settings

logger = logging.getLogger(__name__)


class PushService:
    """
    Bridge between Redis Pub/Sub and WebSocket clients.

    Responsibilities:
    - Pattern-subscribe to the Redis channels listed in ``CHANNEL_PATTERNS``.
    - Decode each JSON payload and route it through ``connection_manager``:
        * ``quote:{symbol}``          -> ``broadcast_to_symbol(symbol, ...)``
        * ``kline:{symbol}:{period}`` -> ``broadcast_to_symbol(symbol, ...)``
        * ``system:*``                -> ``broadcast(...)``
        * ``alert:*``                 -> ``send_to_user(payload["user_id"], ...)``
    - Provide ``publish_*`` helpers that produce messages on those channels.

    Messages are processed one at a time by a single background listener task.
    """

    # Redis glob patterns subscribed to by start().
    CHANNEL_PATTERNS = ("quote:*", "kline:*", "system:*", "alert:*")
    # Pause after an unexpected listener error, so that a broken connection
    # (where every get_message() call fails immediately) does not turn the
    # listen loop into a CPU-bound error-logging spin.
    ERROR_BACKOFF_SECONDS = 1.0
    # Interval between statistics log lines.
    STATS_INTERVAL_SECONDS = 60

    def __init__(self):
        self.redis: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.running = False
        # Background tasks spawned by start(). The asyncio event loop keeps
        # only weak references to tasks, so they are held here to prevent
        # garbage collection mid-run; stop() also uses them to cancel.
        self._tasks: List[asyncio.Task] = []

        # Push statistics
        self.total_pushes = 0  # messages actually handed to connection_manager
        self.push_errors = 0  # parse / dispatch / listener failures

    async def connect(self):
        """
        Create the Redis client and Pub/Sub handle and verify connectivity.

        ``redis.Redis(...)`` does not open a socket by itself, so a PING is
        sent to make sure the success log line is truthful and connection
        problems surface here instead of at the first subscribe/publish.

        Raises:
            Exception: whatever the Redis client raises. The error is logged
                and re-raised; ``self.redis`` / ``self.pubsub`` are not
                assigned on failure, so a later ``start()`` retries.
        """
        client = None
        try:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                decode_responses=True,
            )
            await client.ping()
        except Exception as e:
            logger.error(f"❌ PushService 连接 Redis 失败: {e}")
            if client is not None:
                await self._close_quietly(client)
            raise

        self.redis = client
        self.pubsub = client.pubsub()
        logger.info(f"✅ PushService 连接 Redis 成功: {settings.REDIS_HOST}:{settings.REDIS_PORT}")

    async def start(self):
        """
        Start the push service.

        Connects if needed, pattern-subscribes to ``CHANNEL_PATTERNS`` and
        spawns the listener and statistics background tasks. Calling it while
        already running is a no-op; otherwise a second listener task would be
        created and every message would be delivered twice.
        """
        if self.running:
            logger.warning("PushService is already running; start() ignored")
            return

        if self.redis is None:
            await self.connect()

        # Subscribe to the Redis channel patterns (quote / kline / system / alert).
        await self.pubsub.psubscribe(*self.CHANNEL_PATTERNS)

        self.running = True
        logger.info("✅ PushService 启动成功")

        self._tasks = [
            asyncio.create_task(self._listen_and_push()),
            asyncio.create_task(self._print_statistics()),
        ]

    async def stop(self):
        """
        Stop the push service.

        Cancels and awaits the background tasks, removes the pattern
        subscriptions, then closes the Pub/Sub handle and the Redis client.
        Cleanup is best-effort: a failure in one step is logged and does not
        prevent the remaining resources from being closed. Both handles are
        reset to None so a later ``start()`` reconnects from scratch.
        """
        self.running = False

        tasks, self._tasks = self._tasks, []
        for task in tasks:
            task.cancel()
        # Wait for the tasks to finish so none of them is still using the
        # Pub/Sub handle while it is being closed.
        await asyncio.gather(*tasks, return_exceptions=True)

        pubsub, self.pubsub = self.pubsub, None
        client, self.redis = self.redis, None

        if pubsub is not None:
            try:
                # Subscriptions were made with psubscribe(), so they must be
                # removed with punsubscribe(); unsubscribe() only affects
                # plain channel subscriptions.
                await pubsub.punsubscribe()
            except Exception as e:
                logger.warning(f"PushService punsubscribe failed: {e}")
            await self._close_quietly(pubsub)

        if client is not None:
            await self._close_quietly(client)

        logger.info("✅ PushService 已停止")

    @staticmethod
    async def _close_quietly(resource) -> None:
        """Best-effort ``close()`` of a Redis client or Pub/Sub handle; errors are logged, not raised."""
        try:
            await resource.close()
        except Exception as e:
            logger.warning(f"PushService failed to close Redis resource: {e}")

    async def _listen_and_push(self):
        """
        Poll Redis Pub/Sub and forward every data message until stopped.

        Per iteration:
        1. Wait up to 1 second for the next Pub/Sub message.
        2. Skip timeouts (None) and messages whose type is not
           ``pmessage``/``message`` (e.g. subscription confirmations).
        3. Normalize channel and payload to ``str`` and pass them to
           ``_handle_message``.

        Unexpected errors are logged, counted in ``push_errors`` and followed
        by a short back-off; the loop keeps running while ``self.running``.
        """
        logger.info("🔄 PushService 开始监听...")

        while self.running:
            try:
                message = await self.pubsub.get_message(timeout=1)
                if message is None or message["type"] not in ("pmessage", "message"):
                    continue

                channel = message["channel"]
                if isinstance(channel, bytes):
                    channel = channel.decode()

                data = message["data"]
                if isinstance(data, bytes):
                    data = data.decode()

                await self._handle_message(channel, data)

            except Exception as e:
                logger.error(f"❌ PushService 处理消息失败: {e}")
                self.push_errors += 1
                await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)

    async def _handle_message(self, channel: str, data: str):
        """
        Decode one Pub/Sub message and dispatch it to WebSocket clients.

        The first ``:``-separated segment of ``channel`` selects the route:
        - ``quote``:  symbol from the 2nd channel segment (fallback
          ``payload["symbol"]``), sent via ``broadcast_to_symbol``.
        - ``kline``:  symbol / period from the 2nd / 3rd channel segments
          (fallbacks ``payload["symbol"]`` / ``payload["period"]``), sent via
          ``broadcast_to_symbol``.
        - ``system``: sent to every connection via ``broadcast``.
        - ``alert``:  sent to ``payload["user_id"]`` via ``send_to_user``;
          dropped when that field is missing or falsy.

        The pushed message is ``{"type", "time", "data"}`` plus ``symbol`` /
        ``period`` where relevant. ``time`` is the local (naive) time at
        which the message was handled.

        ``total_pushes`` is incremented only when a message was actually
        handed to ``connection_manager``; JSON and dispatch failures are both
        counted in ``push_errors``.

        Args:
            channel: Redis channel the message arrived on.
            data: raw message payload, expected to be a JSON document.
        """
        try:
            message_data = json.loads(data)
        except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
            logger.error(f"❌ PushService JSON 解析失败: {data}")
            self.push_errors += 1
            return

        try:
            parts = channel.split(":")
            message_type = parts[0]  # quote, kline, system, alert

            push_message = {
                "type": message_type,
                "time": datetime.now().isoformat(),
                "data": message_data,
            }

            if message_type == "quote":
                symbol = parts[1] if len(parts) > 1 else message_data.get("symbol")
                push_message["symbol"] = symbol
                await connection_manager.broadcast_to_symbol(symbol, push_message)

            elif message_type == "kline":
                symbol = parts[1] if len(parts) > 1 else message_data.get("symbol")
                period = parts[2] if len(parts) > 2 else message_data.get("period")
                push_message["symbol"] = symbol
                push_message["period"] = period
                await connection_manager.broadcast_to_symbol(symbol, push_message)

            elif message_type == "system":
                await connection_manager.broadcast(push_message)

            elif message_type == "alert":
                user_id = message_data.get("user_id")
                if not user_id:
                    logger.warning(f"PushService dropped alert without user_id on {channel}")
                    return
                await connection_manager.send_to_user(user_id, push_message)

            else:
                # Not one of the subscribed prefixes: nothing was delivered,
                # so it must not be counted as a push.
                return

            self.total_pushes += 1

        except Exception as e:
            logger.error(f"❌ PushService 处理消息失败: {e}")
            self.push_errors += 1

    async def _publish(self, channel: str, payload: dict):
        """
        JSON-encode ``payload`` and publish it on ``channel``.

        Connects lazily, so the publish helpers also work in processes that
        only produce messages and never call ``start()``.
        """
        if self.redis is None:
            await self.connect()
        await self.redis.publish(channel, json.dumps(payload))

    async def publish_quote(self, symbol: str, quote_data: dict):
        """
        Publish a quote update on ``quote:{symbol}``.

        Args:
            symbol: instrument code.
            quote_data: quote payload (must be JSON-serializable).
        """
        await self._publish(f"quote:{symbol}", quote_data)

    async def publish_kline(self, symbol: str, period: str, kline_data: dict):
        """
        Publish a K-line update on ``kline:{symbol}:{period}``.

        Args:
            symbol: instrument code.
            period: K-line period.
            kline_data: K-line payload (must be JSON-serializable).
        """
        await self._publish(f"kline:{symbol}:{period}", kline_data)

    async def publish_system(self, message: dict):
        """
        Publish a system message on ``system:broadcast``.

        Args:
            message: system message payload (must be JSON-serializable).
        """
        await self._publish("system:broadcast", message)

    async def publish_alert(self, user_id: int, alert_data: dict):
        """
        Publish an alert for one user on ``alert:trigger``.

        ``user_id`` is embedded in the published payload because the listener
        routes alerts by that field. A copy is published, so the caller's
        ``alert_data`` dict is not modified.

        Args:
            user_id: target user ID.
            alert_data: alert payload (must be JSON-serializable).
        """
        await self._publish("alert:trigger", {**alert_data, "user_id": user_id})

    def get_statistics(self) -> dict:
        """Return push counters, running state and connection_manager statistics."""
        return {
            "total_pushes": self.total_pushes,
            "push_errors": self.push_errors,
            "running": self.running,
            "connection_stats": connection_manager.get_statistics(),
        }

    async def _print_statistics(self):
        """Log push counters every ``STATS_INTERVAL_SECONDS`` while running."""
        while self.running:
            await asyncio.sleep(self.STATS_INTERVAL_SECONDS)
            # Read the counters directly: only these two values are logged, so
            # there is no need to query connection_manager here.
            logger.info(f"📊 PushService 统计: 推送 {self.total_pushes} 次, 错误 {self.push_errors} 次")


# Global push service instance
push_service = PushService()


# ============== Lifecycle hooks ==============

async def start_push_service():
    """Start the push service (call on application startup)."""
    await push_service.start()


async def stop_push_service():
    """Stop the push service (call on application shutdown)."""
    await push_service.stop()