You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
4.4 KiB
132 lines
4.4 KiB
"""
Data subscription service.

Provides ``SubscriptionService``, which manages per-user subscriptions
stored in the SQLite database.
"""
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import List, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models import Subscription
|
|
from app.db.init_db import SQLiteSessionLocal
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SubscriptionService:
    """Data subscription service.

    Every method opens its own short-lived ``SQLiteSessionLocal`` session in a
    ``with`` block, so any ``Subscription`` objects returned are no longer
    attached to a session when the method returns.
    """

    @staticmethod
    def _key_criteria(
        user_id: int,
        symbol: str,
        period: Optional[str],
        subscription_type: str,
    ) -> tuple:
        """Return the filter criteria that identify one logical subscription.

        SQLAlchemy renders ``column == None`` as ``column IS NULL``, so a
        ``None`` period matches rows whose period is NULL.
        """
        return (
            Subscription.user_id == user_id,
            Subscription.symbol == symbol,
            Subscription.period == period,
            Subscription.subscription_type == subscription_type,
        )

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated (Python 3.12+) ``datetime.utcnow()``,
        keeping the naive-UTC representation this service has always stored.
        """
        from datetime import timezone  # local import: the module only imports `datetime`

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def create_subscription(
        user_id: int,
        symbol: str,
        period: Optional[str] = None,
        subscription_type: str = "kline",
    ) -> tuple[Subscription, bool]:
        """Create a subscription, or reactivate a previously cancelled one.

        Args:
            user_id: User ID.
            symbol: Instrument symbol.
            period: Period (may be ``None``).
            subscription_type: Subscription type.

        Returns:
            tuple[Subscription, bool]: ``(subscription, created)``. ``created``
            is ``True`` only when a new row was inserted; it is ``False`` when
            an existing inactive subscription was reactivated instead.

        Raises:
            ValueError: If an active subscription with the same user, symbol,
                period and type already exists.
        """
        criteria = SubscriptionService._key_criteria(
            user_id, symbol, period, subscription_type
        )

        with SQLiteSessionLocal() as db:
            # Reject duplicates: an active subscription for this key exists.
            existing = db.query(Subscription).filter(
                *criteria,
                Subscription.is_active == True,  # noqa: E712 - SQLAlchemy expression
            ).first()
            if existing:
                raise ValueError(f"您已订阅 {symbol} ({period}),请勿重复订阅")

            # Reuse a cancelled row for the same key instead of inserting anew.
            inactive = db.query(Subscription).filter(
                *criteria,
                Subscription.is_active == False,  # noqa: E712 - SQLAlchemy expression
            ).first()
            if inactive:
                inactive.is_active = True
                # Resetting created_at makes the reactivated row sort as newest
                # in get_user_subscriptions, which orders by created_at desc.
                inactive.created_at = SubscriptionService._utcnow()
                db.commit()
                db.refresh(inactive)
                logger.info(
                    "Reactivated subscription id=%s user_id=%s symbol=%s period=%s type=%s",
                    inactive.id, user_id, symbol, period, subscription_type,
                )
                return inactive, False

            subscription = Subscription(
                user_id=user_id,
                symbol=symbol,
                period=period,
                subscription_type=subscription_type,
            )
            db.add(subscription)
            db.commit()
            # Reload server-populated columns (e.g. id) before the session closes.
            db.refresh(subscription)
            logger.info(
                "Created subscription id=%s user_id=%s symbol=%s period=%s type=%s",
                subscription.id, user_id, symbol, period, subscription_type,
            )
            return subscription, True

    @staticmethod
    def get_user_subscriptions(
        user_id: int,
        subscription_type: Optional[str] = None,
    ) -> List[Subscription]:
        """Return a user's active subscriptions, newest first.

        Args:
            user_id: User ID.
            subscription_type: Optional type filter; a falsy value (``None``
                or ``""``) disables the filter.

        Returns:
            Active subscriptions ordered by ``created_at`` descending.
        """
        with SQLiteSessionLocal() as db:
            query = db.query(Subscription).filter(
                Subscription.user_id == user_id,
                Subscription.is_active == True,  # noqa: E712 - SQLAlchemy expression
            )
            if subscription_type:
                query = query.filter(Subscription.subscription_type == subscription_type)
            return query.order_by(Subscription.created_at.desc()).all()

    @staticmethod
    def get_subscription_by_id(subscription_id: int, user_id: int) -> Optional[Subscription]:
        """Return the user's subscription with the given ID, or ``None``.

        The lookup is scoped to ``user_id``, so another user's subscription is
        never returned. Inactive subscriptions are included.
        """
        with SQLiteSessionLocal() as db:
            return db.query(Subscription).filter(
                Subscription.id == subscription_id,
                Subscription.user_id == user_id,
            ).first()

    @staticmethod
    def cancel_subscription(subscription_id: int, user_id: int) -> bool:
        """Cancel (soft-delete) a subscription by marking it inactive.

        Returns:
            ``True`` if a subscription with this ID owned by ``user_id`` was
            found (including one that was already inactive); ``False``
            otherwise.
        """
        with SQLiteSessionLocal() as db:
            subscription = db.query(Subscription).filter(
                Subscription.id == subscription_id,
                Subscription.user_id == user_id,
            ).first()

            if not subscription:
                return False

            subscription.is_active = False
            db.commit()
            logger.info(
                "Cancelled subscription id=%s user_id=%s", subscription_id, user_id
            )
            return True

    @staticmethod
    def get_subscribers_for_symbol(symbol: str, subscription_type: str = "kline") -> List[int]:
        """Return the IDs of users with an active subscription to ``symbol``.

        Each user ID appears at most once, even when the user holds several
        active subscriptions for the symbol (e.g. different periods). IDs are
        returned in ascending order.
        """
        with SQLiteSessionLocal() as db:
            # Select only the user_id column and de-duplicate in SQL instead of
            # loading full ORM rows.
            rows = (
                db.query(Subscription.user_id)
                .filter(
                    Subscription.symbol == symbol,
                    Subscription.subscription_type == subscription_type,
                    Subscription.is_active == True,  # noqa: E712 - SQLAlchemy expression
                )
                .distinct()
                .order_by(Subscription.user_id)
                .all()
            )
        return [user_id for (user_id,) in rows]
|