"""Alert service."""
import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models import Alert
from app.db.init_db import SQLiteSessionLocal

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()`` (naive,
    UTC wall clock), so timestamps written to the database keep the same
    format as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AlertService:
    """Create, query, update, delete and trigger user alerts.

    Every method opens its own session with ``SQLiteSessionLocal()`` inside a
    ``with`` block, so any ``Alert`` returned is used by the caller after that
    block has exited. NOTE(review): presumably the context manager closes the
    session, leaving returned objects detached — confirm before relying on
    lazy-loaded relationships.
    """

    # Absolute tolerance used by the "equals" condition type.
    EQUALS_TOLERANCE = 0.001

    @staticmethod
    def create_alert(
        user_id: int,
        symbol: str,
        condition_type: str,
        condition_value: float,
        alert_type: str = "price",
    ) -> Alert:
        """Create and persist a new alert in the ``"active"`` state.

        Args:
            user_id: Owner of the alert.
            symbol: Instrument symbol the alert watches.
            condition_type: Condition name; ``check_price_alerts`` understands
                ``"greater_than"``, ``"less_than"`` and ``"equals"``.
            condition_value: Threshold the condition compares against.
            alert_type: Alert category (default ``"price"``).

        Returns:
            The newly created alert, refreshed from the database.
        """
        with SQLiteSessionLocal() as db:
            alert = Alert(
                user_id=user_id,
                symbol=symbol,
                condition_type=condition_type,
                condition_value=condition_value,
                alert_type=alert_type,
                status="active",
            )
            db.add(alert)
            db.commit()
            # Reload DB-generated fields (e.g. the primary key) after commit.
            db.refresh(alert)
            return alert

    @staticmethod
    def get_user_alerts(user_id: int, status: Optional[str] = None) -> List[Alert]:
        """Return a user's alerts, newest first by ``created_at``.

        Args:
            user_id: Owner whose alerts are returned.
            status: If truthy, only alerts with exactly this status are returned.
        """
        with SQLiteSessionLocal() as db:
            query = db.query(Alert).filter(Alert.user_id == user_id)
            if status:
                query = query.filter(Alert.status == status)
            return query.order_by(Alert.created_at.desc()).all()

    @staticmethod
    def get_alert_by_id(alert_id: int, user_id: int) -> Optional[Alert]:
        """Return the alert with ``alert_id`` owned by ``user_id``.

        Returns ``None`` if the alert does not exist or belongs to another user.
        """
        with SQLiteSessionLocal() as db:
            return db.query(Alert).filter(
                Alert.id == alert_id,
                Alert.user_id == user_id,
            ).first()

    @staticmethod
    def update_alert(
        alert_id: int,
        user_id: int,
        condition_value: Optional[float] = None,
        status: Optional[str] = None,
    ) -> Optional[Alert]:
        """Update a user's alert.

        Only arguments that are not ``None`` are applied. Setting ``status`` to
        ``"active"`` also clears ``triggered_at``, which re-arms the alert.
        ``updated_at`` is always bumped.

        Returns:
            The updated alert, or ``None`` if it does not exist or belongs to
            another user.
        """
        with SQLiteSessionLocal() as db:
            alert = db.query(Alert).filter(
                Alert.id == alert_id,
                Alert.user_id == user_id,
            ).first()
            if not alert:
                return None

            if condition_value is not None:
                alert.condition_value = condition_value
            if status is not None:
                alert.status = status
                # Re-arming a previously triggered alert: forget the old trigger time.
                if status == "active" and alert.triggered_at:
                    alert.triggered_at = None

            alert.updated_at = _utcnow()
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def delete_alert(alert_id: int, user_id: int) -> bool:
        """Delete a user's alert.

        Returns:
            ``True`` if the alert was deleted, ``False`` if it does not exist
            or belongs to another user.
        """
        with SQLiteSessionLocal() as db:
            alert = db.query(Alert).filter(
                Alert.id == alert_id,
                Alert.user_id == user_id,
            ).first()
            if not alert:
                return False
            db.delete(alert)
            db.commit()
            return True

    @staticmethod
    def trigger_alert(alert_id: int) -> Optional[Alert]:
        """Mark an alert as ``"triggered"`` and stamp the trigger time.

        Not scoped to a user, and does not check the alert's current status.

        Returns:
            The updated alert, or ``None`` if no alert has ``alert_id``.
        """
        with SQLiteSessionLocal() as db:
            alert = db.query(Alert).filter(Alert.id == alert_id).first()
            if not alert:
                return None

            # Single timestamp so triggered_at and updated_at agree exactly.
            now = _utcnow()
            alert.status = "triggered"
            alert.triggered_at = now
            alert.updated_at = now
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def _condition_met(alert: Alert, price: float) -> bool:
        """Return ``True`` if ``price`` satisfies the alert's price condition.

        Logs a warning and returns ``False`` for an unrecognised
        ``condition_type``.
        """
        condition_type = alert.condition_type
        target = alert.condition_value
        if condition_type == "greater_than":
            # Inclusive: fires once the price reaches the target.
            return price >= target
        if condition_type == "less_than":
            # Inclusive: fires once the price falls to the target.
            return price <= target
        if condition_type == "equals":
            # float() presumably guards against condition_value being a
            # Decimal (float - Decimal raises TypeError) — confirm column type.
            return abs(price - float(target)) < AlertService.EQUALS_TOLERANCE
        logger.warning(
            "Alert %s has unknown condition_type %r; skipping",
            alert.id,
            condition_type,
        )
        return False

    @staticmethod
    def check_price_alerts(symbol: str, price: float) -> List[Alert]:
        """Evaluate active price alerts for ``symbol`` against ``price``.

        Every matching alert is moved to ``"triggered"`` and committed in a
        single transaction.

        Args:
            symbol: Instrument symbol.
            price: Current price.

        Returns:
            The alerts that were triggered by this call (possibly empty).
        """
        with SQLiteSessionLocal() as db:
            alerts = db.query(Alert).filter(
                Alert.symbol == symbol,
                Alert.status == "active",
                Alert.alert_type == "price",
            ).all()

            # One timestamp for the whole batch.
            now = _utcnow()
            triggered = []
            for alert in alerts:
                if not AlertService._condition_met(alert, price):
                    continue
                alert.status = "triggered"
                alert.triggered_at = now
                alert.updated_at = now
                triggered.append(alert)

            if triggered:
                db.commit()
                # With SQLAlchemy's default expire_on_commit=True, commit()
                # expires loaded attributes; reload them while the session is
                # still open so callers can read them afterwards.
                for alert in triggered:
                    db.refresh(alert)
            return triggered