You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

357 lines
11 KiB

# backend/app/services/alert_engine.py
"""
智能告警引擎
支持实时数据 + 规则 布尔值判断并行计算100+ 规则/用户
"""
import asyncio
import json
from typing import Dict, List, Optional
from datetime import datetime, time as dt_time
from sqlalchemy.orm import Session
from sqlalchemy import and_
from app.models.alert import AlertRule, AlertHistory, AlertType, AlertOperator
from app.services.cache_service import cache_service
import logging
logger = logging.getLogger(__name__)
class AlertEngine:
    """
    Alert engine: evaluates quote data against cached alert rules.

    Features:
    - Rule cache: enabled rules are kept in memory, grouped per user.
    - Rule evaluation: one quote field compared with a threshold yields a bool.
    - Fan-out: all rules relevant to a symbol are checked via asyncio.gather.
    - Trigger gating: active time window and minimum repeat interval.

    In-memory indexes:
    - rules_cache: user_id -> rules of that user
    - symbol_rules: symbol -> rules bound to that symbol
    - global_rules: rules whose symbol is empty
    """

    # Comparison operators accepted in a rule condition. Built once at class
    # creation instead of on every evaluate_condition() call.
    _OPERATORS = {
        "gt": lambda a, b: a > b,
        "lt": lambda a, b: a < b,
        "ge": lambda a, b: a >= b,
        "le": lambda a, b: a <= b,
        "eq": lambda a, b: abs(a - b) < 0.0001,  # float equality with tolerance
        "ne": lambda a, b: abs(a - b) >= 0.0001,
    }

    # Rule field name -> key in the quote dict (currently an identity mapping;
    # unknown field names are passed through unchanged).
    _FIELD_MAPPING = {
        "price": "price",
        "change": "change",
        "change_percent": "change_percent",
        "volume": "volume",
        "high": "high",
        "low": "low",
        "open": "open",
        "close": "close",
    }

    def __init__(self):
        # user_id -> List[AlertRule] (rule cache)
        self.rules_cache: Dict[int, List[AlertRule]] = {}
        # symbol -> List[AlertRule] (symbol index)
        self.symbol_rules: Dict[str, List[AlertRule]] = {}
        # Global rules (rules whose symbol is empty)
        self.global_rules: List[AlertRule] = []
        # rule_id -> time the rule last triggered (used for the repeat interval)
        self.trigger_times: Dict[int, datetime] = {}
        # Evaluation counters
        self.total_checks = 0
        self.total_triggers = 0

    def _index_rule(self, rule: AlertRule) -> None:
        """Add one rule to the symbol index, or to the global list if it has no symbol."""
        if rule.symbol:
            self.symbol_rules.setdefault(rule.symbol, []).append(rule)
        else:
            self.global_rules.append(rule)

    def _unindex_user(self, user_id: int) -> None:
        """Remove every rule owned by ``user_id`` from the symbol index and the global list."""
        # Iterate over a snapshot of the keys because empty entries are deleted.
        for symbol in list(self.symbol_rules):
            kept = [r for r in self.symbol_rules[symbol] if r.user_id != user_id]
            if kept:
                self.symbol_rules[symbol] = kept
            else:
                del self.symbol_rules[symbol]
        # Slice assignment keeps the same list object for anyone holding a reference.
        self.global_rules[:] = [r for r in self.global_rules if r.user_id != user_id]

    @staticmethod
    def _in_time_window(start, end, current) -> bool:
        """
        Tell whether ``current`` lies inside the rule's active window.

        Args:
            start: Window start time, or a falsy value when unset.
            end: Window end time, or a falsy value when unset.
            current: Time of day to test.

        Returns:
            bool: True when no complete window is configured, or when
            ``current`` is inside it (both bounds inclusive). A window whose
            start is later than its end is treated as wrapping past midnight.
        """
        if not (start and end):
            return True
        if start <= end:
            return start <= current <= end
        # Overnight window, e.g. 21:00 -> 02:30.
        return current >= start or current <= end

    async def load_all_rules(self, db: Session):
        """
        Reload every enabled rule and rebuild all in-memory indexes.

        Args:
            db: Database session.
        """
        # `== True` is a SQL column expression here, not a Python truth test.
        rules = db.query(AlertRule).filter(AlertRule.enabled == True).all()  # noqa: E712
        # Drop the previous cache contents.
        self.rules_cache.clear()
        self.symbol_rules.clear()
        self.global_rules.clear()
        # Group by user and by symbol.
        for rule in rules:
            self.rules_cache.setdefault(rule.user_id, []).append(rule)
            self._index_rule(rule)
        logger.info(f"✅ AlertEngine 加载 {len(rules)} 条规则")

    async def load_user_rules(self, db: Session, user_id: int):
        """
        Reload the enabled rules of one user.

        The user's previous entries are first removed from the symbol index
        and the global list, so rules that were disabled, deleted or moved to
        another symbol stop being evaluated; the fresh rules (including the
        ones without a symbol) are then indexed again.

        Args:
            db: Database session.
            user_id: User ID.
        """
        rules = db.query(AlertRule).filter(
            and_(
                AlertRule.user_id == user_id,
                AlertRule.enabled == True  # noqa: E712 (SQL column expression)
            )
        ).all()
        self._unindex_user(user_id)
        self.rules_cache[user_id] = rules
        for rule in rules:
            self._index_rule(rule)
        logger.info(f"✅ AlertEngine 加载用户 {user_id}: {len(rules)} 条规则")

    def evaluate_condition(self, condition: dict, current_value: float) -> bool:
        """
        Evaluate a single condition.

        Args:
            condition: Condition dict, e.g.
                {"field": "price", "operator": "gt", "value": 3900}.
                A missing "value" defaults to 0; numeric strings are accepted.
            current_value: Current value of the field.

        Returns:
            bool: Whether the condition holds. False (with a warning) when the
            operator is unknown or the threshold is not numeric.
        """
        operator = condition.get("operator")
        op_func = self._OPERATORS.get(operator)
        if not op_func:
            logger.warning(f"⚠️ AlertEngine 未知的操作符: {operator}")
            return False
        threshold = condition.get("value", 0)
        try:
            # Conditions may come from JSON where the threshold is a string.
            threshold = float(threshold)
        except (TypeError, ValueError):
            logger.warning(f"⚠️ AlertEngine 无效的阈值: {threshold!r}")
            return False
        return op_func(current_value, threshold)

    def _get_field_value(self, data: dict, field: str) -> float:
        """
        Read a field from the quote data as a float.

        Args:
            data: Quote data dict.
            field: Rule field name.

        Returns:
            float: The field value; 0.0 when the key is missing or its value is None.
        """
        raw = data.get(self._FIELD_MAPPING.get(field, field))
        return 0.0 if raw is None else float(raw)

    async def check_alert(self, db: Session, symbol: str, data: dict) -> List[dict]:
        """
        Check every rule that applies to a symbol.

        Args:
            db: Database session.
            symbol: Symbol code.
            data: Quote data.

        Returns:
            List[dict]: Trigger info of each rule that fired.
        """
        # Symbol-specific rules plus the global ones.
        candidates = self.symbol_rules.get(symbol, []) + self.global_rules
        triggered_alerts = []
        if candidates:
            # The per-rule coroutines await nothing that suspends, so gather
            # effectively runs them one after another on the shared session.
            outcomes = await asyncio.gather(
                *(self._check_single_rule(db, rule, data, symbol) for rule in candidates)
            )
            triggered_alerts = [info for info in outcomes if info is not None]
        # Update counters.
        self.total_checks += len(candidates)
        self.total_triggers += len(triggered_alerts)
        return triggered_alerts

    async def _check_single_rule(
        self,
        db: Session,
        rule: AlertRule,
        data: dict,
        symbol: Optional[str] = None,
    ) -> Optional[dict]:
        """
        Check one rule against the quote data.

        Args:
            db: Database session.
            rule: Alert rule.
            data: Quote data.
            symbol: Symbol the quote belongs to; recorded on the alert when the
                rule itself has no symbol (global rule).

        Returns:
            Optional[dict]: Trigger info when the rule fired, otherwise None.
            Any exception is logged and reported as None.
        """
        try:
            # 1. Active time window (supports windows that wrap past midnight).
            now = datetime.now()
            if not self._in_time_window(rule.start_time, rule.end_time, now.time()):
                return None
            # 2. Repeat interval; an unset (None) interval means no throttling.
            # NOTE(review): keyed by rule id only, so a global rule that fires
            # for one symbol also suppresses other symbols for the interval.
            interval = rule.repeat_interval or 0
            if interval > 0:
                last_triggered = self.trigger_times.get(rule.id)
                if last_triggered and (now - last_triggered).total_seconds() < interval:
                    return None
            # 3. Current value.
            condition = rule.condition
            if isinstance(condition, str):
                condition = json.loads(condition)
            field = condition.get("field")
            # A quote that lacks the field must not be evaluated as 0, which
            # would make every "lt"/"le" rule fire.
            if data.get(self._FIELD_MAPPING.get(field, field)) is None:
                return None
            current_value = self._get_field_value(data, field)
            # 4. Evaluate.
            if not self.evaluate_condition(condition, current_value):
                return None
            # 5. Build the trigger info.
            channels = rule.channels
            if not isinstance(channels, list):
                # Stored as JSON text; an empty/None value means no channels.
                channels = json.loads(channels) if channels else []
            trigger_info = {
                "rule": rule,
                "rule_id": rule.id,
                "user_id": rule.user_id,
                "symbol": rule.symbol or symbol,
                "name": rule.name,
                "type": rule.type,
                "trigger_value": current_value,
                "trigger_condition": f"{field} {condition.get('operator')} {condition.get('value')}",
                "trigger_time": now,
                "channels": channels,
            }
            # 6. Remember the trigger time for the repeat interval.
            self.trigger_times[rule.id] = now
            # 7. Persist the history record.
            await self._create_history(db, rule, trigger_info)
            return trigger_info
        except Exception as e:
            logger.error(f"❌ AlertEngine 检查规则 {rule.id} 失败: {e}")
            return None

    async def _create_history(self, db: Session, rule: AlertRule, trigger_info: dict):
        """
        Persist an alert history row and update the rule's trigger bookkeeping.

        Failures are logged and the session is rolled back; nothing is raised.

        Args:
            db: Database session.
            rule: Alert rule.
            trigger_info: Trigger info; "trigger_value" and "trigger_condition"
                are required, "symbol" and "trigger_time" are used when present.
        """
        try:
            # Reuse the trigger timestamp so history, rule and alert agree.
            triggered_at = trigger_info.get("trigger_time") or datetime.now()
            history = AlertHistory(
                rule_id=rule.id,
                user_id=rule.user_id,
                symbol=trigger_info.get("symbol") or rule.symbol,
                trigger_value=trigger_info["trigger_value"],
                trigger_condition=trigger_info["trigger_condition"],
                notified=False,
                trigger_time=triggered_at
            )
            db.add(history)
            # Update the rule's counters; the count may still be NULL.
            rule.trigger_count = (rule.trigger_count or 0) + 1
            rule.last_triggered_at = triggered_at
            db.commit()
        except Exception as e:
            logger.error(f"❌ AlertEngine 创建历史失败: {e}")
            db.rollback()

    def get_statistics(self) -> dict:
        """Return evaluation counters and cache sizes."""
        return {
            "total_checks": self.total_checks,
            "total_triggers": self.total_triggers,
            "cached_users": len(self.rules_cache),
            "cached_rules": sum(len(rules) for rules in self.rules_cache.values()),
            "symbol_count": len(self.symbol_rules),
        }

    async def check_all_symbols(self, db: Session) -> Dict[str, List[dict]]:
        """
        Check every symbol that has at least one symbol-specific rule.

        Only keys of the symbol index are visited, so global rules are
        evaluated only alongside those symbols.

        Args:
            db: Database session.

        Returns:
            Dict[str, List[dict]]: Triggered alerts per symbol; symbols with
            no quote or no triggered alert are omitted.
        """
        results = {}
        # Snapshot the keys: the index may be rebuilt while this coroutine is
        # suspended in an await below.
        for symbol in list(self.symbol_rules):
            # Latest quote from the cache.
            quote_data = await cache_service.get_latest_quote(symbol)
            if not quote_data:
                continue
            triggered = await self.check_alert(db, symbol, quote_data)
            if triggered:
                results[symbol] = triggered
        return results
# Module-level shared alert engine instance
alert_engine = AlertEngine()
# ============== Scheduled task ==============
async def alert_checker_task(db: Session):
    """
    Run one complete alert-check pass.

    Reloads every enabled rule into the shared engine, checks all indexed
    symbols and logs how many alerts fired. Meant to be scheduled once per
    minute; the scheduling itself is not done here.

    Args:
        db: Database session.

    Returns:
        The per-symbol mapping of triggered alerts produced by
        ``alert_engine.check_all_symbols``.
    """
    logger.info("🔄 AlertEngine 开始检查...")

    # Refresh the rule cache, then evaluate every symbol.
    await alert_engine.load_all_rules(db)
    alerts_by_symbol = await alert_engine.check_all_symbols(db)

    # Count the alerts that fired across all symbols.
    fired = 0
    for alerts in alerts_by_symbol.values():
        fired += len(alerts)
    logger.info(f"📊 AlertEngine 检查完成: 触发 {fired} 条告警")

    return alerts_by_symbol