|
|
|
|
|
# backend/app/services/alert_engine.py
|
|
|
|
|
|
"""
|
|
|
|
|
|
智能告警引擎
|
|
|
|
|
|
支持实时数据 + 规则 → 布尔值判断,并行计算,100+ 规则/用户
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import json
|
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
from datetime import datetime, time as dt_time
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from sqlalchemy import and_
|
|
|
|
|
|
from app.models.alert import AlertRule, AlertHistory, AlertType, AlertOperator
|
|
|
|
|
|
from app.services.cache_service import cache_service
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
# Per-module logger so this component's records can be filtered/configured by name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AlertEngine:
    """
    Rule-based alert engine.

    Responsibilities:
    - Rule caching: enabled rules are held in memory, grouped per user.
    - Rule evaluation: live quote data + rule condition -> bool.
    - Fan-out with asyncio.gather over every rule relevant to a symbol. The
      per-rule work (including the DB write) is synchronous, so in practice
      rules are evaluated one after another on the event loop.
    - Trigger gating: optional active time window and per-rule repeat interval.

    Indexes:
    - rules_cache:  user_id -> enabled rules of that user
    - symbol_rules: symbol  -> enabled rules bound to that symbol
    - global_rules: enabled rules with no symbol; they are added to the rule
      set of every symbol passed to check_alert
    """

    # Supported comparison operators for rule conditions. "eq"/"ne" use an
    # absolute tolerance of 1e-4 because quote values are floats.
    _OPERATORS = {
        "gt": lambda a, b: a > b,
        "lt": lambda a, b: a < b,
        "ge": lambda a, b: a >= b,
        "le": lambda a, b: a <= b,
        "eq": lambda a, b: abs(a - b) < 0.0001,
        "ne": lambda a, b: abs(a - b) >= 0.0001,
    }

    def __init__(self):
        # user_id -> enabled rules of that user
        self.rules_cache: Dict[int, List[AlertRule]] = {}
        # symbol -> enabled rules bound to that symbol (only non-empty buckets)
        self.symbol_rules: Dict[str, List[AlertRule]] = {}
        # enabled rules without a symbol
        self.global_rules: List[AlertRule] = []
        # rule_id -> last in-memory trigger time, used for repeat-interval gating.
        # NOTE: keyed by rule id only, so a global rule's cooldown is shared by
        # all symbols; the map is also not persisted across restarts.
        self.trigger_times: Dict[int, datetime] = {}
        # cumulative counters reported by get_statistics()
        self.total_checks = 0
        self.total_triggers = 0

    # ------------------------------------------------------------------ index

    def _index_rule(self, rule: AlertRule) -> None:
        """Add *rule* to the symbol index, or to global_rules if it has no symbol."""
        if rule.symbol:
            self.symbol_rules.setdefault(rule.symbol, []).append(rule)
        else:
            self.global_rules.append(rule)

    def _drop_user_from_index(self, user_id: int) -> None:
        """Remove every rule owned by *user_id* from the symbol and global indexes."""
        for symbol in list(self.symbol_rules):
            kept = [r for r in self.symbol_rules[symbol] if r.user_id != user_id]
            if kept:
                self.symbol_rules[symbol] = kept
            else:
                # Keep the invariant that only non-empty buckets exist, so
                # check_all_symbols does not poll symbols without rules.
                del self.symbol_rules[symbol]
        # Slice-assign so the list object is mutated in place (as load_all_rules does).
        self.global_rules[:] = [r for r in self.global_rules if r.user_id != user_id]

    async def load_all_rules(self, db: Session):
        """
        Rebuild every in-memory index from all enabled rules in the database.

        Args:
            db: database session
        """
        rules = db.query(AlertRule).filter(AlertRule.enabled == True).all()  # noqa: E712 (SQL expression)

        self.rules_cache.clear()
        self.symbol_rules.clear()
        self.global_rules.clear()

        for rule in rules:
            self.rules_cache.setdefault(rule.user_id, []).append(rule)
            self._index_rule(rule)

        logger.info(f"✅ AlertEngine 加载 {len(rules)} 条规则")

    async def load_user_rules(self, db: Session, user_id: int):
        """
        Reload one user's enabled rules and re-index them.

        The user's previously indexed rules are removed first. This means rules
        that were disabled, deleted or moved to another symbol stop being
        checked, and reloaded instances are not indexed twice.

        Args:
            db: database session
            user_id: user ID
        """
        rules = db.query(AlertRule).filter(
            and_(
                AlertRule.user_id == user_id,
                AlertRule.enabled == True,  # noqa: E712 (SQL expression)
            )
        ).all()

        self.rules_cache[user_id] = rules

        self._drop_user_from_index(user_id)
        for rule in rules:
            self._index_rule(rule)

        logger.info(f"✅ AlertEngine 加载用户 {user_id} 的 {len(rules)} 条规则")

    # ------------------------------------------------------------- evaluation

    def evaluate_condition(self, condition: dict, current_value: float) -> bool:
        """
        Evaluate a single condition against a value.

        Args:
            condition: condition dict, e.g. {"field": "price", "operator": "gt", "value": 3900}.
                "value" may be a number or a numeric string; a missing "value"
                defaults to 0.
            current_value: current value of the field

        Returns:
            bool: True if the condition holds. False for an unknown operator or a
            non-numeric threshold (both are logged).
        """
        operator = condition.get("operator")
        op_func = self._OPERATORS.get(operator)
        if op_func is None:
            logger.warning(f"⚠️ AlertEngine 未知的操作符: {operator}")
            return False

        raw_threshold = condition.get("value", 0)
        try:
            # Thresholds stored as JSON strings (e.g. "3900") would otherwise
            # make the comparison raise TypeError.
            threshold = float(raw_threshold)
        except (TypeError, ValueError):
            logger.warning(f"⚠️ AlertEngine 无效的阈值: {raw_threshold!r}")
            return False

        return op_func(current_value, threshold)

    @staticmethod
    def _in_time_window(start: dt_time, end: dt_time, current: dt_time) -> bool:
        """
        Return True if *current* lies in the inclusive window [start, end].

        A window whose start is later than its end wraps past midnight
        (e.g. 21:00-02:30 for a night session).
        """
        if start <= end:
            return start <= current <= end
        return current >= start or current <= end

    def _get_field_value(self, data: dict, field: str) -> float:
        """
        Read a numeric field from a quote dict.

        Field names are used verbatim as keys.

        Args:
            data: quote data dict
            field: field name (e.g. "price", "volume", "change_percent")

        Returns:
            float: the field value; 0.0 if the key is absent or its value is None.
        """
        value = data.get(field)
        return 0.0 if value is None else float(value)

    @staticmethod
    def _parse_channels(raw) -> list:
        """Normalize rule.channels (list, JSON string, or empty/None) to a list."""
        if not raw:
            return []
        if isinstance(raw, list):
            return raw
        return json.loads(raw)

    async def check_alert(self, db: Session, symbol: str, data: dict) -> List[dict]:
        """
        Check every rule relevant to *symbol* (symbol-bound plus global rules).

        Args:
            db: database session, shared by all rule checks. This is only safe
                because no check yields to the event loop mid-transaction (the
                DB calls are synchronous).
            symbol: instrument code
            data: quote data

        Returns:
            List[dict]: trigger info for every rule that fired
        """
        # Concatenation creates a snapshot, so concurrent re-indexing cannot
        # change the list while it is being processed.
        all_rules = self.symbol_rules.get(symbol, []) + self.global_rules

        triggered_alerts: List[dict] = []
        if all_rules:
            results = await asyncio.gather(
                *(self._check_single_rule(db, rule, data, symbol) for rule in all_rules)
            )
            triggered_alerts = [r for r in results if r is not None]

        self.total_checks += len(all_rules)
        self.total_triggers += len(triggered_alerts)

        return triggered_alerts

    async def _check_single_rule(
        self,
        db: Session,
        rule: AlertRule,
        data: dict,
        symbol: Optional[str] = None,
    ) -> Optional[dict]:
        """
        Check one rule against quote data and record it if it fires.

        Args:
            db: database session
            rule: alert rule
            data: quote data
            symbol: symbol being checked. It is recorded for global rules, which
                have no symbol of their own.

        Returns:
            Optional[dict]: trigger info if the rule fired, otherwise None. Any
            exception is logged with its traceback and yields None.
        """
        try:
            now = datetime.now()

            # 1. Active time window (supports windows crossing midnight).
            if rule.start_time and rule.end_time:
                if not self._in_time_window(rule.start_time, rule.end_time, now.time()):
                    return None

            # 2. Repeat interval (seconds) since the last in-memory trigger.
            repeat_interval = rule.repeat_interval or 0
            if repeat_interval > 0:
                last_triggered = self.trigger_times.get(rule.id)
                if last_triggered and (now - last_triggered).total_seconds() < repeat_interval:
                    return None

            # 3. Resolve the condition and the current value.
            condition = rule.condition
            if isinstance(condition, str):
                condition = json.loads(condition)

            field = condition.get("field")
            # A missing field must not be treated as 0, otherwise conditions
            # such as "price lt 3900" would fire on incomplete quotes.
            if not field or data.get(field) is None:
                return None
            current_value = self._get_field_value(data, field)

            # 4. Evaluate.
            if not self.evaluate_condition(condition, current_value):
                return None

            # 5. Build the trigger info.
            trigger_info = {
                "rule": rule,
                "rule_id": rule.id,
                "user_id": rule.user_id,
                "symbol": rule.symbol or symbol,
                "name": rule.name,
                "type": rule.type,
                "trigger_value": current_value,
                "trigger_condition": f"{field} {condition.get('operator')} {condition.get('value')}",
                "trigger_time": now,
                "channels": self._parse_channels(rule.channels),
            }

            # 6. Remember the trigger time for repeat-interval gating.
            self.trigger_times[rule.id] = now

            # 7. Persist history.
            await self._create_history(db, rule, trigger_info)

            return trigger_info

        except Exception as e:
            logger.exception(f"❌ AlertEngine 检查规则 {rule.id} 失败: {e}")
            return None

    async def _create_history(self, db: Session, rule: AlertRule, trigger_info: dict):
        """
        Persist an AlertHistory row and update the rule's trigger counters.

        The timestamp is taken from trigger_info["trigger_time"], so the history
        row, rule.last_triggered_at and the returned trigger info agree. On any
        failure the error is logged with its traceback and the session is
        rolled back.

        Args:
            db: database session
            rule: alert rule
            trigger_info: trigger info built by _check_single_rule
        """
        trigger_time = trigger_info.get("trigger_time") or datetime.now()
        try:
            history = AlertHistory(
                rule_id=rule.id,
                user_id=rule.user_id,
                symbol=trigger_info.get("symbol", rule.symbol),
                trigger_value=trigger_info["trigger_value"],
                trigger_condition=trigger_info["trigger_condition"],
                notified=False,
                trigger_time=trigger_time,
            )
            db.add(history)

            rule.trigger_count = (rule.trigger_count or 0) + 1
            rule.last_triggered_at = trigger_time

            db.commit()

        except Exception as e:
            logger.exception(f"❌ AlertEngine 创建历史失败: {e}")
            db.rollback()

    def get_statistics(self) -> dict:
        """Return cumulative counters and cache sizes."""
        return {
            "total_checks": self.total_checks,
            "total_triggers": self.total_triggers,
            "cached_users": len(self.rules_cache),
            "cached_rules": sum(len(rules) for rules in self.rules_cache.values()),
            "symbol_count": len(self.symbol_rules),
        }

    async def check_all_symbols(self, db: Session) -> Dict[str, List[dict]]:
        """
        Check every symbol that has at least one symbol-bound rule.

        The latest quote for each symbol is read from cache_service; symbols
        without a cached quote are skipped.

        NOTE: global rules are only evaluated for the symbols visited here,
        i.e. symbols that have at least one symbol-bound rule.

        Args:
            db: database session

        Returns:
            Dict[str, List[dict]]: triggered alerts per symbol (symbols that
            fired only)
        """
        results: Dict[str, List[dict]] = {}

        # Iterate over a snapshot of the keys. Awaiting the quote cache
        # suspends this coroutine, and a concurrent re-index (e.g.
        # load_user_rules) mutating the live dict would raise RuntimeError.
        for symbol in list(self.symbol_rules):
            quote_data = await cache_service.get_latest_quote(symbol)
            if not quote_data:
                continue
            triggered = await self.check_alert(db, symbol, quote_data)
            if triggered:
                results[symbol] = triggered

        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Process-wide engine instance; alert_checker_task (and any importer of this
# module) operates on this single shared object.
alert_engine: AlertEngine = AlertEngine()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============== 定时任务 ==============
|
|
|
|
|
|
|
|
|
|
|
|
async def alert_checker_task(db: Session):
    """
    Run one full alert-check pass: reload rules, then evaluate every indexed symbol.

    The original note says this task runs once per minute; the scheduling
    itself is configured elsewhere.

    Args:
        db: database session used for loading rules and writing alert history.

    Returns:
        Mapping of symbol -> triggered alert dicts (only symbols that fired).
    """
    logger.info("🔄 AlertEngine 开始检查...")

    # Refresh the in-memory rule index from the database before checking.
    await alert_engine.load_all_rules(db)
    triggered_by_symbol = await alert_engine.check_all_symbols(db)

    total_triggered = 0
    for fired in triggered_by_symbol.values():
        total_triggered += len(fired)
    logger.info(f"📊 AlertEngine 检查完成: 触发 {total_triggered} 条告警")

    return triggered_by_symbol
|