You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

132 lines
4.4 KiB

"""
数据订阅服务
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import Session

from app.db.init_db import SQLiteSessionLocal
from app.models import Subscription
logger = logging.getLogger(__name__)


class SubscriptionService:
    """Data subscription service.

    Every method opens its own ``SQLiteSessionLocal()`` session inside a
    ``with`` block, so the session is closed before the method returns.
    Returned ``Subscription`` objects are therefore no longer attached to a
    session: columns that were loaded (or refreshed after commit) remain
    readable. NOTE(review): lazy-loaded relationships, if the model defines
    any, will not be loadable on the returned objects.
    """

    @staticmethod
    def _find_subscription(
        db: Session,
        user_id: int,
        symbol: str,
        period: Optional[str],
        subscription_type: str,
        is_active: bool,
    ) -> Optional[Subscription]:
        """Return the first subscription matching this exact key and active flag, or None."""
        # SQLAlchemy renders ``column == None`` as ``IS NULL``, so a missing
        # period matches rows whose period column is NULL.
        return db.query(Subscription).filter(
            Subscription.user_id == user_id,
            Subscription.symbol == symbol,
            Subscription.period == period,
            Subscription.subscription_type == subscription_type,
            Subscription.is_active == is_active,
        ).first()

    @staticmethod
    def create_subscription(
        user_id: int,
        symbol: str,
        period: Optional[str] = None,
        subscription_type: str = "kline"
    ) -> tuple[Subscription, bool]:
        """Create a subscription, or reactivate a previously cancelled one.

        Args:
            user_id: User ID.
            symbol: Instrument symbol.
            period: Period; ``None`` matches only rows whose period is NULL.
            subscription_type: Subscription type.

        Returns:
            tuple[Subscription, bool]: ``(subscription, created)``. ``created``
            is True when a new row was inserted, and False when an existing
            inactive row with the same key was reactivated.

        Raises:
            ValueError: If the user already has an active subscription with
                the same symbol, period and type.
        """
        with SQLiteSessionLocal() as db:
            existing = SubscriptionService._find_subscription(
                db, user_id, symbol, period, subscription_type, is_active=True
            )
            if existing is not None:
                # Only mention the period when one was given; otherwise the
                # user-facing message would literally read "(None)".
                label = f"{symbol} ({period})" if period else symbol
                raise ValueError(f"您已订阅 {label},请勿重复订阅")

            # Prefer reviving a cancelled subscription over inserting a new row.
            inactive = SubscriptionService._find_subscription(
                db, user_id, symbol, period, subscription_type, is_active=False
            )
            if inactive is not None:
                inactive.is_active = True
                # Reset the timestamp so the revived subscription sorts as the
                # newest in get_user_subscriptions (ordered by created_at desc).
                # Naive UTC: the same value the deprecated datetime.utcnow() gave.
                inactive.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
                db.commit()
                # Reload column values expired by the commit while the session
                # is still open.
                db.refresh(inactive)
                logger.info(
                    "Reactivated subscription %s for user %s (%s, %s, %s)",
                    inactive.id, user_id, symbol, period, subscription_type,
                )
                return inactive, False

            subscription = Subscription(
                user_id=user_id,
                symbol=symbol,
                period=period,
                subscription_type=subscription_type
            )
            db.add(subscription)
            db.commit()
            db.refresh(subscription)
            logger.info(
                "Created subscription %s for user %s (%s, %s, %s)",
                subscription.id, user_id, symbol, period, subscription_type,
            )
            return subscription, True

    @staticmethod
    def get_user_subscriptions(
        user_id: int,
        subscription_type: Optional[str] = None
    ) -> List[Subscription]:
        """Return the user's active subscriptions, newest ``created_at`` first.

        Args:
            user_id: User ID.
            subscription_type: If truthy, only subscriptions of this type
                are returned.
        """
        with SQLiteSessionLocal() as db:
            query = db.query(Subscription).filter(
                Subscription.user_id == user_id,
                Subscription.is_active == True  # noqa: E712 (SQL expression)
            )
            if subscription_type:
                query = query.filter(Subscription.subscription_type == subscription_type)
            return query.order_by(Subscription.created_at.desc()).all()

    @staticmethod
    def get_subscription_by_id(subscription_id: int, user_id: int) -> Optional[Subscription]:
        """Return the subscription with this ID if it belongs to ``user_id``, else None.

        Both active and cancelled subscriptions are returned.
        """
        with SQLiteSessionLocal() as db:
            return db.query(Subscription).filter(
                Subscription.id == subscription_id,
                Subscription.user_id == user_id
            ).first()

    @staticmethod
    def cancel_subscription(subscription_id: int, user_id: int) -> bool:
        """Cancel a subscription by clearing its ``is_active`` flag (the row is kept).

        Returns:
            bool: False if no subscription with this ID belongs to the user;
            True otherwise, including when it was already inactive.
        """
        with SQLiteSessionLocal() as db:
            subscription = db.query(Subscription).filter(
                Subscription.id == subscription_id,
                Subscription.user_id == user_id
            ).first()
            if subscription is None:
                return False
            subscription.is_active = False
            db.commit()
            logger.info("Cancelled subscription %s for user %s", subscription_id, user_id)
            return True

    @staticmethod
    def get_subscribers_for_symbol(symbol: str, subscription_type: str = "kline") -> List[int]:
        """Return the IDs of users with an active subscription to ``symbol``.

        Each user ID appears once, even when the user is subscribed to the
        symbol on several periods. IDs are returned in ascending order.
        """
        with SQLiteSessionLocal() as db:
            rows = (
                db.query(Subscription.user_id)
                .filter(
                    Subscription.symbol == symbol,
                    Subscription.subscription_type == subscription_type,
                    Subscription.is_active == True,  # noqa: E712 (SQL expression)
                )
                .distinct()
                .order_by(Subscription.user_id)
                .all()
            )
        return [user_id for (user_id,) in rows]