You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
164 lines
5.0 KiB
164 lines
5.0 KiB
|
2 months ago
|
"""
|
||
|
|
告警服务
|
||
|
|
"""
|
||
|
|
import logging
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.models import Alert
|
||
|
|
from app.db.init_db import SQLiteSessionLocal
|
||
|
|
|
||
|
|
# Module-scoped logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class AlertService:
    """Alert service: CRUD operations and price-condition evaluation for ``Alert`` rows.

    Every public method opens its own ``SQLiteSessionLocal()`` session in a
    ``with`` block and leaves it before returning.

    NOTE(review): ``SQLiteSessionLocal`` is presumably a SQLAlchemy
    ``sessionmaker``. If so, leaving the ``with`` block closes the session and
    the returned ``Alert`` objects are detached. Column attributes loaded
    before that point stay readable. Lazy-loaded relationships (if ``Alert``
    has any) would raise ``DetachedInstanceError``. Confirm before relying on
    them.
    """

    # Absolute tolerance for the "equals" condition when comparing the current
    # price with ``condition_value``.
    PRICE_EQUALS_TOLERANCE = 0.001

    @staticmethod
    def _find_user_alert(db: Session, alert_id: int, user_id: int) -> Optional[Alert]:
        """Return the alert with ``alert_id`` owned by ``user_id`` in ``db``, or None."""
        return db.query(Alert).filter(
            Alert.id == alert_id,
            Alert.user_id == user_id,
        ).first()

    @staticmethod
    def _mark_triggered(alert: Alert, now: datetime) -> None:
        """Set ``alert`` to ``"triggered"`` and stamp both timestamps with the same ``now``."""
        alert.status = "triggered"
        alert.triggered_at = now
        alert.updated_at = now

    @staticmethod
    def _condition_met(alert: Alert, price: float) -> bool:
        """Return True if ``price`` satisfies the condition stored on ``alert``.

        Supported ``condition_type`` values are ``"greater_than"`` (inclusive),
        ``"less_than"`` (inclusive) and ``"equals"`` (within
        ``PRICE_EQUALS_TOLERANCE``). Any other value is logged and never
        matches.
        """
        condition = alert.condition_type
        if condition == "greater_than":
            # Inclusive: reaching the threshold counts as meeting it.
            return price >= alert.condition_value
        if condition == "less_than":
            return price <= alert.condition_value
        if condition == "equals":
            # float() guards against condition_value being a non-float numeric
            # type (presumably Decimal from a Numeric column), since
            # float - Decimal raises TypeError.
            return abs(price - float(alert.condition_value)) < AlertService.PRICE_EQUALS_TOLERANCE
        logger.warning(
            "Alert %s has unsupported condition_type %r; it will never trigger",
            alert.id,
            condition,
        )
        return False

    @staticmethod
    def create_alert(
        user_id: int,
        symbol: str,
        condition_type: str,
        condition_value: float,
        alert_type: str = "price",
    ) -> Alert:
        """Create an alert in the ``"active"`` state and return it.

        Args:
            user_id: Owner of the alert.
            symbol: Instrument symbol the alert watches.
            condition_type: Condition evaluated by ``check_price_alerts``
                (``"greater_than"``, ``"less_than"`` or ``"equals"``).
                Other values are stored but never trigger.
            condition_value: Threshold the price is compared against.
            alert_type: Alert category. Only ``"price"`` alerts are evaluated by
                ``check_price_alerts``.

        Returns:
            The persisted alert, refreshed after commit so database-assigned
            columns are loaded.
        """
        with SQLiteSessionLocal() as db:
            alert = Alert(
                user_id=user_id,
                symbol=symbol,
                condition_type=condition_type,
                condition_value=condition_value,
                alert_type=alert_type,
                status="active",
            )
            db.add(alert)
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def get_user_alerts(user_id: int, status: Optional[str] = None) -> List[Alert]:
        """Return a user's alerts, newest first by ``created_at``.

        Args:
            user_id: Owner whose alerts are listed.
            status: If truthy, only alerts with exactly this status are
                returned. ``None`` or ``""`` disables the filter.
        """
        with SQLiteSessionLocal() as db:
            query = db.query(Alert).filter(Alert.user_id == user_id)
            if status:
                query = query.filter(Alert.status == status)
            return query.order_by(Alert.created_at.desc()).all()

    @staticmethod
    def get_alert_by_id(alert_id: int, user_id: int) -> Optional[Alert]:
        """Return alert ``alert_id`` if it belongs to ``user_id``, otherwise None."""
        with SQLiteSessionLocal() as db:
            return AlertService._find_user_alert(db, alert_id, user_id)

    @staticmethod
    def update_alert(
        alert_id: int,
        user_id: int,
        condition_value: Optional[float] = None,
        status: Optional[str] = None,
    ) -> Optional[Alert]:
        """Update the threshold and/or status of a user's alert.

        Arguments left as ``None`` are not changed. ``updated_at`` is always
        refreshed when the alert exists.

        Returns:
            The updated alert, or None if no alert ``alert_id`` belongs to
            ``user_id``.
        """
        with SQLiteSessionLocal() as db:
            alert = AlertService._find_user_alert(db, alert_id, user_id)
            if not alert:
                return None

            if condition_value is not None:
                alert.condition_value = condition_value
            if status is not None:
                alert.status = status
                # Re-activating an alert discards its previous trigger time.
                if status == "active" and alert.triggered_at:
                    alert.triggered_at = None

            alert.updated_at = datetime.utcnow()
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def delete_alert(alert_id: int, user_id: int) -> bool:
        """Delete a user's alert.

        Returns:
            True if the alert existed and was deleted, False if no alert
            ``alert_id`` belongs to ``user_id``.
        """
        with SQLiteSessionLocal() as db:
            alert = AlertService._find_user_alert(db, alert_id, user_id)
            if not alert:
                return False

            db.delete(alert)
            db.commit()
            return True

    @staticmethod
    def trigger_alert(alert_id: int) -> Optional[Alert]:
        """Mark alert ``alert_id`` as triggered, regardless of owner or current status.

        ``triggered_at`` and ``updated_at`` receive the same timestamp.

        Returns:
            The updated alert, or None if it does not exist.
        """
        with SQLiteSessionLocal() as db:
            alert = db.query(Alert).filter(Alert.id == alert_id).first()
            if not alert:
                return None

            AlertService._mark_triggered(alert, datetime.utcnow())
            db.commit()
            db.refresh(alert)
            return alert

    @staticmethod
    def check_price_alerts(symbol: str, price: float) -> List[Alert]:
        """Evaluate active price alerts for ``symbol`` and trigger those that match.

        Args:
            symbol: Instrument symbol.
            price: Current price.

        Returns:
            The alerts that were triggered by this call (possibly empty). All
            of them share one trigger timestamp.
        """
        with SQLiteSessionLocal() as db:
            candidates = db.query(Alert).filter(
                Alert.symbol == symbol,
                Alert.status == "active",
                Alert.alert_type == "price",
            ).all()

            triggered = [
                alert for alert in candidates
                if AlertService._condition_met(alert, price)
            ]
            if not triggered:
                return triggered

            # One timestamp for the whole batch keeps triggered_at/updated_at consistent.
            now = datetime.utcnow()
            for alert in triggered:
                AlertService._mark_triggered(alert, now)
            db.commit()

            # Reload committed state before the session closes.
            for alert in triggered:
                db.refresh(alert)
            return triggered
|