You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

164 lines
5.0 KiB

"""
告警服务
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models import Alert
from app.db.init_db import SQLiteSessionLocal
logger = logging.getLogger(__name__)
class AlertService:
    """Alert service: create, query, update, delete and evaluate user alerts.

    Every method opens its own session from ``SQLiteSessionLocal`` inside a
    ``with`` block, so the session is closed before the method returns and any
    ``Alert`` handed back is detached. Methods that commit call
    ``db.refresh`` on the instances they return so their column attributes
    are loaded before the session closes.
    """

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as ``datetime.utcnow()`` (deprecated since
        Python 3.12). The result is kept naive to match what this service
        stored before. NOTE(review): confirm Alert's DateTime columns are
        naive.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _find_user_alert(db, alert_id: int, user_id: int) -> Optional[Alert]:
        """Return the alert ``alert_id`` if it belongs to ``user_id``, else None.

        ``db`` is an open session obtained from ``SQLiteSessionLocal``.
        """
        return db.query(Alert).filter(
            Alert.id == alert_id,
            Alert.user_id == user_id,
        ).first()

    @staticmethod
    def _price_condition_met(alert: Alert, price: float, tolerance: float) -> bool:
        """Return True if ``price`` satisfies ``alert``'s price condition.

        Both bounds are inclusive ("greater_than" means ``>=``, "less_than"
        means ``<=``). "equals" uses an absolute tolerance. An unknown
        ``condition_type`` or a NULL ``condition_value`` is logged and
        treated as "not met", so one bad row cannot abort the whole check.
        """
        condition = alert.condition_type
        if condition not in ("greater_than", "less_than", "equals"):
            logger.warning(
                "Alert %s has unknown condition_type %r; skipped", alert.id, condition
            )
            return False
        if alert.condition_value is None:
            logger.warning("Alert %s has no condition_value; skipped", alert.id)
            return False

        # Convert once so all three comparisons use floats. The original code
        # only converted for "equals", presumably because the column can come
        # back as Decimal (float - Decimal raises TypeError) -- confirm schema.
        target = float(alert.condition_value)
        if condition == "greater_than":
            return price >= target
        if condition == "less_than":
            return price <= target
        return abs(price - target) < tolerance

    @staticmethod
    def create_alert(
        user_id: int,
        symbol: str,
        condition_type: str,
        condition_value: float,
        alert_type: str = "price",
    ) -> Alert:
        """Create and persist a new alert in the "active" state.

        Args:
            user_id: Owner of the alert.
            symbol: Instrument code the alert watches.
            condition_type: Condition name. ``check_price_alerts`` understands
                "greater_than", "less_than" and "equals".
            condition_value: Threshold the condition compares against.
            alert_type: Alert category; defaults to "price".

        Returns:
            Alert: The stored alert, refreshed after commit so
            database-populated columns such as the primary key are loaded.
        """
        with SQLiteSessionLocal() as db:
            alert = Alert(
                user_id=user_id,
                symbol=symbol,
                condition_type=condition_type,
                condition_value=condition_value,
                alert_type=alert_type,
                status="active",
            )
            db.add(alert)
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def get_user_alerts(user_id: int, status: Optional[str] = None) -> List[Alert]:
        """Return a user's alerts, newest first (by ``created_at``).

        Args:
            user_id: Owner whose alerts are listed.
            status: If truthy, only alerts with this status are returned.
                ``None`` or an empty string returns alerts of every status.
        """
        with SQLiteSessionLocal() as db:
            query = db.query(Alert).filter(Alert.user_id == user_id)
            if status:
                query = query.filter(Alert.status == status)
            return query.order_by(Alert.created_at.desc()).all()

    @staticmethod
    def get_alert_by_id(alert_id: int, user_id: int) -> Optional[Alert]:
        """Return alert ``alert_id`` if it belongs to ``user_id``, else None."""
        with SQLiteSessionLocal() as db:
            return AlertService._find_user_alert(db, alert_id, user_id)

    @staticmethod
    def update_alert(
        alert_id: int,
        user_id: int,
        condition_value: Optional[float] = None,
        status: Optional[str] = None,
    ) -> Optional[Alert]:
        """Update the threshold and/or status of an alert owned by ``user_id``.

        Arguments left as ``None`` are not changed. Setting ``status`` to
        "active" also clears any previous ``triggered_at``. ``updated_at`` is
        always bumped, even when nothing else changed.

        Returns:
            The updated alert, or None if it does not exist or belongs to
            another user.
        """
        with SQLiteSessionLocal() as db:
            alert = AlertService._find_user_alert(db, alert_id, user_id)
            if alert is None:
                return None

            if condition_value is not None:
                alert.condition_value = condition_value
            if status is not None:
                alert.status = status
                # Re-activating an alert resets its trigger timestamp.
                if status == "active" and alert.triggered_at:
                    alert.triggered_at = None

            alert.updated_at = AlertService._utcnow()
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def delete_alert(alert_id: int, user_id: int) -> bool:
        """Delete an alert owned by ``user_id``.

        Returns:
            True if the alert was deleted, False if it does not exist or
            belongs to another user.
        """
        with SQLiteSessionLocal() as db:
            alert = AlertService._find_user_alert(db, alert_id, user_id)
            if alert is None:
                return False
            db.delete(alert)
            db.commit()
            return True

    @staticmethod
    def trigger_alert(alert_id: int) -> Optional[Alert]:
        """Mark an alert as "triggered", regardless of owner or current status.

        ``triggered_at`` and ``updated_at`` are set to the same timestamp.

        Returns:
            The updated alert, or None if no alert has ``alert_id``.
        """
        with SQLiteSessionLocal() as db:
            alert = db.query(Alert).filter(Alert.id == alert_id).first()
            if alert is None:
                return None
            now = AlertService._utcnow()
            alert.status = "triggered"
            alert.triggered_at = now
            alert.updated_at = now
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def check_price_alerts(
        symbol: str,
        price: float,
        tolerance: float = 0.001,
    ) -> List[Alert]:
        """Evaluate active price alerts for ``symbol`` and trigger those met.

        Conditions (both bounds are inclusive):
            - "greater_than": ``price >= condition_value``
            - "less_than":    ``price <= condition_value``
            - "equals":       ``abs(price - condition_value) < tolerance``

        Alerts with any other ``condition_type``, or with no
        ``condition_value``, are logged and skipped.

        Args:
            symbol: Instrument code.
            price: Current price.
            tolerance: Absolute tolerance for "equals". Defaults to 0.001,
                the previously hard-coded value.

        Returns:
            List[Alert]: Alerts triggered by this call (status set to
            "triggered"), refreshed after commit.
        """
        with SQLiteSessionLocal() as db:
            candidates = db.query(Alert).filter(
                Alert.symbol == symbol,
                Alert.status == "active",
                Alert.alert_type == "price",
            ).all()

            # One timestamp for the whole batch: every alert fired by this
            # price tick gets the same triggered_at/updated_at.
            now = AlertService._utcnow()
            triggered = []
            for alert in candidates:
                if not AlertService._price_condition_met(alert, price, tolerance):
                    continue
                alert.status = "triggered"
                alert.triggered_at = now
                alert.updated_at = now
                triggered.append(alert)

            if triggered:
                db.commit()
                # With SQLAlchemy's default expire_on_commit=True, commit()
                # expires loaded attributes. Reload them so callers can read
                # the objects after the session closes.
                for alert in triggered:
                    db.refresh(alert)
                logger.info(
                    "Triggered %d price alert(s) for %s at %s",
                    len(triggered), symbol, price,
                )
            return triggered