""" 数据订阅服务 """ import logging from datetime import datetime from typing import List, Optional from sqlalchemy.orm import Session from app.models import Subscription from app.db.init_db import SQLiteSessionLocal logger = logging.getLogger(__name__) class SubscriptionService: """数据订阅服务""" @staticmethod def create_subscription( user_id: int, symbol: str, period: Optional[str] = None, subscription_type: str = "kline" ) -> tuple[Subscription, bool]: """ 创建订阅 Args: user_id: 用户 ID symbol: 品种代码 period: 周期 subscription_type: 订阅类型 Returns: tuple[Subscription, bool]: (订阅对象,是否为新创建) Raises: ValueError: 当重复订阅时抛出异常 """ with SQLiteSessionLocal() as db: # 检查是否已存在活跃订阅 existing = db.query(Subscription).filter( Subscription.user_id == user_id, Subscription.symbol == symbol, Subscription.period == period, Subscription.subscription_type == subscription_type, Subscription.is_active == True ).first() if existing: # 重复订阅检查:如果已存在活跃订阅,抛出异常 raise ValueError(f"您已订阅 {symbol} ({period}),请勿重复订阅") # 检查是否存在非活跃订阅,可以重新激活 inactive = db.query(Subscription).filter( Subscription.user_id == user_id, Subscription.symbol == symbol, Subscription.period == period, Subscription.subscription_type == subscription_type, Subscription.is_active == False ).first() if inactive: inactive.is_active = True inactive.created_at = datetime.utcnow() db.commit() db.refresh(inactive) return inactive, False subscription = Subscription( user_id=user_id, symbol=symbol, period=period, subscription_type=subscription_type ) db.add(subscription) db.commit() db.refresh(subscription) return subscription, True @staticmethod def get_user_subscriptions( user_id: int, subscription_type: Optional[str] = None ) -> List[Subscription]: """获取用户订阅列表""" with SQLiteSessionLocal() as db: query = db.query(Subscription).filter( Subscription.user_id == user_id, Subscription.is_active == True ) if subscription_type: query = query.filter(Subscription.subscription_type == subscription_type) return query.order_by(Subscription.created_at.desc()).all() @staticmethod def get_subscription_by_id(subscription_id: int, user_id: int) -> Optional[Subscription]: """根据 ID 获取订阅""" with SQLiteSessionLocal() as db: return db.query(Subscription).filter( Subscription.id == subscription_id, Subscription.user_id == user_id ).first() @staticmethod def cancel_subscription(subscription_id: int, user_id: int) -> bool: """取消订阅""" with SQLiteSessionLocal() as db: subscription = db.query(Subscription).filter( Subscription.id == subscription_id, Subscription.user_id == user_id ).first() if not subscription: return False subscription.is_active = False db.commit() return True @staticmethod def get_subscribers_for_symbol(symbol: str, subscription_type: str = "kline") -> List[int]: """获取订阅某品种的用户 ID 列表""" with SQLiteSessionLocal() as db: subscriptions = db.query(Subscription).filter( Subscription.symbol == symbol, Subscription.subscription_type == subscription_type, Subscription.is_active == True ).all() return [s.user_id for s in subscriptions]