You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

259 lines
7.9 KiB

# backend/app/services/push_service.py
"""
推送服务
从 Redis Pub/Sub 接收数据,推送给 WebSocket 客户端
"""
import asyncio
import json
import redis.asyncio as redis
from typing import Dict, Optional
from datetime import datetime
from app.websocket.connection_manager import connection_manager
from app.config import settings
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class PushService:
    """
    Push service bridging Redis Pub/Sub and WebSocket clients.

    Responsibilities:
    - Pattern-subscribe to the ``quote:*``, ``kline:*``, ``system:*`` and
      ``alert:*`` Redis channels.
    - Forward every received message to WebSocket clients through
      ``connection_manager``.
    - Offer ``publish_*`` helpers that publish JSON payloads on those channels.
    - Keep simple counters (``total_pushes`` / ``push_errors``).

    Lifecycle: ``await start()`` spawns two background tasks (listener and
    periodic statistics logger); ``await stop()`` cancels them and releases
    the Redis resources.
    """

    # Seconds to wait after a failed poll so that a broken Redis connection
    # does not turn the listener into a busy loop of error logs.
    _ERROR_BACKOFF_SECONDS = 1.0
    # Seconds between two statistics log lines.
    _STATS_INTERVAL_SECONDS = 60

    def __init__(self):
        """Initialise an unconnected, stopped service with zeroed counters."""
        self.redis: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.running = False
        # Push statistics
        self.total_pushes = 0
        self.push_errors = 0
        # Strong references to the background tasks: asyncio only keeps weak
        # references, so an unreferenced task may be garbage-collected while
        # it is still running. They are also needed to cancel on stop().
        self._tasks = []

    @staticmethod
    async def _close_resource(resource):
        """Close a Redis client or Pub/Sub handle.

        Prefers ``aclose()`` (newer redis-py releases deprecate ``close()``)
        and falls back to ``close()`` when ``aclose`` is not available.
        """
        closer = getattr(resource, "aclose", None) or resource.close
        await closer()

    async def connect(self):
        """Create the Redis client and its Pub/Sub handle.

        The client connects lazily, so a ``PING`` is issued to make sure the
        server is actually reachable before success is logged. ``self.redis``
        and ``self.pubsub`` are only assigned once the ping succeeded, so a
        failed attempt can simply be retried.

        Raises:
            Exception: any error raised while creating or pinging the client;
                it is logged and re-raised unchanged.
        """
        try:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                decode_responses=True,
            )
            await client.ping()
        except Exception as e:
            logger.error(f"❌ PushService 连接 Redis 失败: {e}")
            raise
        self.redis = client
        self.pubsub = client.pubsub()
        logger.info(f"✅ PushService 连接 Redis 成功: {settings.REDIS_HOST}:{settings.REDIS_PORT}")

    async def start(self):
        """Subscribe to the Redis channels and spawn the background tasks.

        Connects first when no connection exists yet. Calling ``start()``
        while the service is already running is a no-op, so messages are
        never delivered twice by duplicate listeners.
        """
        if self.running:
            logger.warning("⚠️ PushService 已在运行, 忽略重复启动")
            return
        if self.redis is None or self.pubsub is None:
            await self.connect()
        # Pattern subscriptions: quotes, K-lines, system messages, alerts.
        await self.pubsub.psubscribe(
            "quote:*",
            "kline:*",
            "system:*",
            "alert:*",
        )
        self.running = True
        logger.info("✅ PushService 启动成功")
        self._tasks = [
            asyncio.create_task(self._listen_and_push()),
            asyncio.create_task(self._print_statistics()),
        ]

    async def stop(self):
        """Stop the background tasks and release the Redis resources.

        The tasks are cancelled and awaited first so that nothing is still
        reading from the Pub/Sub handle while it is being closed. Afterwards
        ``self.redis`` / ``self.pubsub`` are reset to ``None`` so that a later
        ``start()`` reconnects from scratch.
        """
        self.running = False
        tasks, self._tasks = self._tasks, []
        for task in tasks:
            task.cancel()
        if tasks:
            # return_exceptions=True: the expected CancelledError of each
            # task is collected instead of being raised here.
            await asyncio.gather(*tasks, return_exceptions=True)
        pubsub, self.pubsub = self.pubsub, None
        client, self.redis = self.redis, None
        try:
            if pubsub is not None:
                try:
                    # The subscriptions were created with psubscribe(), so
                    # they must be removed with punsubscribe().
                    await pubsub.punsubscribe()
                finally:
                    await self._close_resource(pubsub)
        finally:
            if client is not None:
                await self._close_resource(client)
        logger.info("✅ PushService 已停止")

    async def _listen_and_push(self):
        """
        Poll Redis Pub/Sub and dispatch each message while ``running`` is set.

        Flow per iteration:
        1. Wait up to one second for a message.
        2. Skip anything that is not a ``pmessage`` / ``message``
           (for example subscription confirmations).
        3. Decode channel and payload if they arrive as bytes and hand them
           to ``_handle_message``.

        Errors are logged and counted in ``push_errors``; the loop then
        pauses briefly before polling again.
        """
        logger.info("🔄 PushService 开始监听...")
        while self.running:
            try:
                message = await self.pubsub.get_message(timeout=1)
                if message is None:
                    continue
                if message["type"] not in ("pmessage", "message"):
                    continue
                channel = message["channel"]
                if isinstance(channel, bytes):
                    channel = channel.decode()
                data = message["data"]
                if isinstance(data, bytes):
                    data = data.decode()
                await self._handle_message(channel, data)
            except Exception as e:
                logger.error(f"❌ PushService 处理消息失败: {e}")
                self.push_errors += 1
                await asyncio.sleep(self._ERROR_BACKOFF_SECONDS)

    async def _handle_message(self, channel: str, data: str):
        """
        Parse one Pub/Sub message and route it through ``connection_manager``.

        Routing by the first ``:``-separated segment of ``channel``:
        - ``quote``  -> ``broadcast_to_symbol`` (symbol from the channel, else
          from the payload's ``symbol`` key)
        - ``kline``  -> ``broadcast_to_symbol`` (symbol / period from the
          channel, else from the payload)
        - ``system`` -> ``broadcast``
        - ``alert``  -> ``send_to_user`` with the payload's ``user_id``

        Failures never propagate: they are logged and counted in
        ``push_errors``.

        Args:
            channel: Redis channel the message arrived on.
            data: JSON-encoded message payload.
        """
        try:
            message_data = json.loads(data)
            parts = channel.split(":")
            message_type = parts[0]  # quote, kline, system, alert
            push_message = {
                "type": message_type,
                # NOTE: naive local time, kept as-is because clients may
                # already depend on this exact format.
                "time": datetime.now().isoformat(),
                "data": message_data,
            }
            if message_type == "quote":
                symbol = parts[1] if len(parts) > 1 else message_data.get("symbol")
                push_message["symbol"] = symbol
                await connection_manager.broadcast_to_symbol(symbol, push_message)
            elif message_type == "kline":
                symbol = parts[1] if len(parts) > 1 else message_data.get("symbol")
                period = parts[2] if len(parts) > 2 else message_data.get("period")
                push_message["symbol"] = symbol
                push_message["period"] = period
                await connection_manager.broadcast_to_symbol(symbol, push_message)
            elif message_type == "system":
                await connection_manager.broadcast(push_message)
            elif message_type == "alert":
                user_id = message_data.get("user_id")
                # Compare against None: a plain truthiness test would
                # silently drop alerts addressed to user id 0.
                if user_id is None:
                    logger.warning(f"⚠️ PushService 告警消息缺少 user_id, 已丢弃: {channel}")
                    return
                await connection_manager.send_to_user(user_id, push_message)
            self.total_pushes += 1
        except json.JSONDecodeError:
            logger.error(f"❌ PushService JSON 解析失败: {data}")
            self.push_errors += 1
        except Exception as e:
            logger.error(f"❌ PushService 处理消息失败: {e}")
            self.push_errors += 1

    async def _publish(self, channel: str, payload: dict):
        """Publish ``payload`` as JSON on ``channel``, connecting first if needed."""
        if self.redis is None:
            await self.connect()
        await self.redis.publish(channel, json.dumps(payload))

    async def publish_quote(self, symbol: str, quote_data: dict):
        """
        Publish a quote update on ``quote:<symbol>``.

        Args:
            symbol: Instrument code.
            quote_data: JSON-serialisable quote payload.
        """
        await self._publish(f"quote:{symbol}", quote_data)

    async def publish_kline(self, symbol: str, period: str, kline_data: dict):
        """
        Publish a K-line update on ``kline:<symbol>:<period>``.

        Args:
            symbol: Instrument code.
            period: K-line period.
            kline_data: JSON-serialisable K-line payload.
        """
        await self._publish(f"kline:{symbol}:{period}", kline_data)

    async def publish_system(self, message: dict):
        """
        Publish a system message on ``system:broadcast``.

        Args:
            message: JSON-serialisable system message.
        """
        await self._publish("system:broadcast", message)

    async def publish_alert(self, user_id: int, alert_data: dict):
        """
        Publish an alert for one user on ``alert:trigger``.

        The published payload is a copy of ``alert_data`` with ``user_id``
        added; the caller's dict is left untouched.

        Args:
            user_id: Id of the user the alert is addressed to.
            alert_data: JSON-serialisable alert payload.
        """
        payload = {**alert_data, "user_id": user_id}
        await self._publish("alert:trigger", payload)

    def get_statistics(self) -> dict:
        """Return the push counters, the running flag and connection statistics."""
        return {
            "total_pushes": self.total_pushes,
            "push_errors": self.push_errors,
            "running": self.running,
            "connection_stats": connection_manager.get_statistics(),
        }

    async def _print_statistics(self):
        """Log the push counters periodically while the service is running."""
        while self.running:
            await asyncio.sleep(self._STATS_INTERVAL_SECONDS)
            stats = self.get_statistics()
            logger.info(f"📊 PushService 统计: 推送 {stats['total_pushes']} 次, 错误 {stats['push_errors']}")
# Module-level PushService instance, created at import time and used by
# start_push_service() / stop_push_service() below.
push_service = PushService()
# ============== 启动函数 ==============
async def start_push_service():
    """Start the module-level push service.

    Meant to be awaited during application startup.
    """
    service = push_service
    await service.start()
async def stop_push_service():
    """Stop the module-level push service.

    Meant to be awaited during application shutdown.
    """
    service = push_service
    await service.stop()