# backend/app/services/alert_engine.py
"""
Intelligent alert engine.

Evaluates real-time quote data against user-defined rules to a boolean
result, checking rules concurrently (designed for 100+ rules per user).
"""
import asyncio
import json
import logging
from datetime import datetime, time as dt_time
from typing import Dict, List, Optional

from sqlalchemy import and_
from sqlalchemy.orm import Session

from app.models.alert import AlertRule, AlertHistory, AlertType, AlertOperator
from app.services.cache_service import cache_service

logger = logging.getLogger(__name__)

# Absolute tolerance used by the "eq" / "ne" operators for float comparison.
_FLOAT_TOLERANCE = 0.0001

# Operator name -> comparison function (current_value, threshold) -> bool.
# Built once at import time instead of on every evaluation.
_OPERATORS = {
    "gt": lambda a, b: a > b,
    "lt": lambda a, b: a < b,
    "ge": lambda a, b: a >= b,
    "le": lambda a, b: a <= b,
    "eq": lambda a, b: abs(a - b) < _FLOAT_TOLERANCE,
    "ne": lambda a, b: abs(a - b) >= _FLOAT_TOLERANCE,
}

# Rule field name -> key in the quote data dict. Unknown fields fall back to
# being used as the key directly.
_FIELD_MAPPING = {
    "price": "price",
    "change": "change",
    "change_percent": "change_percent",
    "volume": "volume",
    "high": "high",
    "low": "low",
    "open": "open",
    "close": "close",
}


class AlertEngine:
    """
    Alert engine.

    Features:
    - Rule cache (enabled rules are kept in memory per user)
    - Rule evaluation (real-time data + rule -> boolean)
    - Concurrent evaluation (asyncio.gather)
    - Trigger decision (whether the alert condition is met)

    Performance optimizations:
    - In-memory rule cache
    - Concurrent evaluation
    - Symbol index
    """

    def __init__(self):
        # user_id -> List[AlertRule] (rule cache)
        self.rules_cache: Dict[int, List[AlertRule]] = {}

        # symbol -> List[AlertRule] (symbol index)
        self.symbol_rules: Dict[str, List[AlertRule]] = {}

        # Global rules (rules whose symbol is empty)
        self.global_rules: List[AlertRule] = []

        # Last trigger time cache (rule_id -> last triggered time)
        self.trigger_times: Dict[int, datetime] = {}

        # Evaluation statistics
        self.total_checks = 0
        self.total_triggers = 0

    def _index_by_symbol(self, rule: AlertRule) -> None:
        """Add a rule to the symbol index, or to the global list if it has no symbol."""
        if rule.symbol:
            self.symbol_rules.setdefault(rule.symbol, []).append(rule)
        else:
            self.global_rules.append(rule)

    def _remove_from_indexes(self, stale_rules: List[AlertRule]) -> None:
        """Remove the given rule objects from the symbol index and the global list.

        Matching is done by object identity so that no attribute of a cached
        (possibly expired or detached) ORM instance has to be loaded.
        """
        if not stale_rules:
            return

        stale_ids = {id(rule) for rule in stale_rules}

        # Iterate over a snapshot of the keys because emptied entries are deleted.
        for symbol in list(self.symbol_rules):
            remaining = [
                rule for rule in self.symbol_rules[symbol]
                if id(rule) not in stale_ids
            ]
            if remaining:
                self.symbol_rules[symbol] = remaining
            else:
                del self.symbol_rules[symbol]

        self.global_rules[:] = [
            rule for rule in self.global_rules if id(rule) not in stale_ids
        ]

    async def load_all_rules(self, db: Session) -> None:
        """
        Load every enabled rule and rebuild all caches from scratch.

        Args:
            db: Database session
        """
        rules = db.query(AlertRule).filter(AlertRule.enabled == True).all()  # noqa: E712

        # Reset the caches before re-populating them
        self.rules_cache.clear()
        self.symbol_rules.clear()
        self.global_rules.clear()

        for rule in rules:
            # Group by user
            self.rules_cache.setdefault(rule.user_id, []).append(rule)
            # Group by symbol (or register as global)
            self._index_by_symbol(rule)

        logger.info(f"✅ AlertEngine 加载 {len(rules)} 条规则")

    async def load_user_rules(self, db: Session, user_id: int) -> None:
        """
        Reload the enabled rules of a single user.

        The user's previously cached rules are first removed from the symbol
        index and the global list, so rules that were disabled or deleted stop
        being evaluated and reloaded rules are never indexed twice.

        Args:
            db: Database session
            user_id: User ID
        """
        rules = db.query(AlertRule).filter(
            and_(
                AlertRule.user_id == user_id,
                AlertRule.enabled == True  # noqa: E712
            )
        ).all()

        # Drop whatever was indexed for this user before the reload
        self._remove_from_indexes(self.rules_cache.get(user_id, []))

        self.rules_cache[user_id] = rules

        # Rebuild the symbol index / global list entries for this user
        for rule in rules:
            self._index_by_symbol(rule)

        logger.info(f"✅ AlertEngine 加载用户 {user_id} 的 {len(rules)} 条规则")

    def evaluate_condition(self, condition: dict, current_value: float) -> bool:
        """
        Evaluate a single condition.

        Args:
            condition: Condition dict, e.g.
                {"field": "price", "operator": "gt", "value": 3900}
            current_value: Current value of the field

        Returns:
            bool: Whether the condition is satisfied. An unknown operator is
            logged and evaluates to False.

        Raises:
            TypeError, ValueError: If the threshold cannot be converted to float.
        """
        operator = condition.get("operator")

        op_func = _OPERATORS.get(operator)
        if not op_func:
            logger.warning(f"⚠️ AlertEngine 未知的操作符: {operator}")
            return False

        # Coerce so thresholds stored as numeric strings compare correctly
        threshold = float(condition.get("value", 0))

        return op_func(current_value, threshold)

    def _get_field_value(self, data: dict, field: str) -> float:
        """
        Read a field value from the quote data.

        Args:
            data: Quote data dict
            field: Field name

        Returns:
            float: The field value; 0.0 when the key is absent.
        """
        data_field = _FIELD_MAPPING.get(field, field)
        return float(data.get(data_field, 0))

    async def check_alert(self, db: Session, symbol: str, data: dict) -> List[dict]:
        """
        Check every alert rule that applies to a symbol.

        Args:
            db: Database session
            symbol: Symbol code
            data: Quote data

        Returns:
            List[dict]: The alerts that were triggered
        """
        # Rules bound to this symbol plus the global rules
        # NOTE(review): alerts from global rules carry rule.symbol (empty), not
        # the symbol whose quote triggered them - confirm this is intended.
        all_rules = self.symbol_rules.get(symbol, []) + self.global_rules

        triggered_alerts: List[dict] = []
        if all_rules:
            results = await asyncio.gather(
                *(self._check_single_rule(db, rule, data) for rule in all_rules)
            )
            triggered_alerts = [result for result in results if result is not None]

        # Update statistics
        self.total_checks += len(all_rules)
        self.total_triggers += len(triggered_alerts)

        return triggered_alerts

    @staticmethod
    def _in_time_window(start: dt_time, end: dt_time, current: dt_time) -> bool:
        """Return True if current lies in [start, end], allowing windows that cross midnight."""
        if start <= end:
            return start <= current <= end
        # start > end means the window wraps past midnight (e.g. 21:00-02:30)
        return current >= start or current <= end

    async def _check_single_rule(self, db: Session, rule: AlertRule, data: dict) -> Optional[dict]:
        """
        Check a single rule against the quote data.

        Args:
            db: Database session
            rule: Alert rule
            data: Quote data

        Returns:
            Optional[dict]: Trigger information if the rule fired, otherwise
            None. Any exception is logged and also results in None.
        """
        try:
            # 1. Check the active time window
            now = datetime.now()
            current_time = now.time()

            if rule.start_time and rule.end_time:
                if not self._in_time_window(rule.start_time, rule.end_time, current_time):
                    return None

            # 2. Check the repeat interval (an unset interval means no throttling)
            if rule.repeat_interval and rule.repeat_interval > 0:
                last_triggered = self.trigger_times.get(rule.id)
                if last_triggered:
                    elapsed = (now - last_triggered).total_seconds()
                    if elapsed < rule.repeat_interval:
                        return None

            # 3. Fetch the current value
            condition = rule.condition
            if isinstance(condition, str):
                condition = json.loads(condition)

            field = condition.get("field")

            # A field absent from the quote must not be treated as 0, otherwise
            # "lt" / "le" rules would fire on missing data.
            if data.get(_FIELD_MAPPING.get(field, field)) is None:
                logger.debug(
                    "AlertEngine rule %s skipped: field %r missing from data",
                    rule.id, field,
                )
                return None

            current_value = self._get_field_value(data, field)

            # 4. Evaluate the rule
            if not self.evaluate_condition(condition, current_value):
                return None

            # 5. Build the trigger information
            raw_channels = rule.channels
            if isinstance(raw_channels, list):
                channels = raw_channels
            elif not raw_channels:
                # No channels configured
                channels = []
            else:
                channels = json.loads(raw_channels)

            trigger_info = {
                "rule": rule,
                "rule_id": rule.id,
                "user_id": rule.user_id,
                "symbol": rule.symbol,
                "name": rule.name,
                "type": rule.type,
                "trigger_value": current_value,
                "trigger_condition": f"{field} {condition.get('operator')} {condition.get('value')}",
                "trigger_time": now,
                "channels": channels,
            }

            # 6. Record the trigger time
            self.trigger_times[rule.id] = now

            # 7. Create the history record
            await self._create_history(db, rule, trigger_info)

            return trigger_info

        except Exception as e:
            logger.exception(f"❌ AlertEngine 检查规则 {rule.id} 失败: {e}")
            return None

    async def _create_history(self, db: Session, rule: AlertRule, trigger_info: dict):
        """
        Persist an alert history row and update the rule's trigger counters.

        Failures are logged and the session is rolled back; nothing is raised.

        Args:
            db: Database session
            rule: Alert rule
            trigger_info: Trigger information
        """
        try:
            # Use the same timestamp as the alert itself so the history row,
            # the rule and the returned alert agree.
            triggered_at = trigger_info.get("trigger_time") or datetime.now()

            history = AlertHistory(
                rule_id=rule.id,
                user_id=rule.user_id,
                symbol=rule.symbol,
                trigger_value=trigger_info["trigger_value"],
                trigger_condition=trigger_info["trigger_condition"],
                notified=False,
                trigger_time=triggered_at
            )

            db.add(history)

            # Update the rule's trigger counters (count may be unset on old rows)
            rule.trigger_count = (rule.trigger_count or 0) + 1
            rule.last_triggered_at = triggered_at

            db.commit()

        except Exception as e:
            logger.exception(f"❌ AlertEngine 创建历史失败: {e}")
            db.rollback()

    def get_statistics(self) -> dict:
        """Return evaluation and cache statistics."""
        return {
            "total_checks": self.total_checks,
            "total_triggers": self.total_triggers,
            "cached_users": len(self.rules_cache),
            "cached_rules": sum(len(rules) for rules in self.rules_cache.values()),
            "symbol_count": len(self.symbol_rules),
        }

    async def check_all_symbols(self, db: Session) -> Dict[str, List[dict]]:
        """
        Check every symbol that has at least one indexed rule.

        Args:
            db: Database session

        Returns:
            Dict[str, List[dict]]: Triggered alerts keyed by symbol
        """
        results = {}

        # NOTE(review): global rules are only evaluated alongside symbols that
        # have symbol-specific rules; with an empty symbol index they never run.
        # Iterate over a snapshot: the index may be modified by a rule reload
        # while this coroutine is suspended in an await.
        for symbol in list(self.symbol_rules):
            # Fetch the latest quote from the cache
            quote_data = await cache_service.get_latest_quote(symbol)

            if quote_data:
                triggered = await self.check_alert(db, symbol, quote_data)
                if triggered:
                    results[symbol] = triggered

        return results


# Global alert engine instance
alert_engine = AlertEngine()


# ============== Scheduled task ==============

async def alert_checker_task(db: Session):
    """
    Scheduled alert check task.

    Reloads all rules, checks every indexed symbol and returns the triggered
    alerts keyed by symbol. Intended to run once per minute.
    """
    logger.info("🔄 AlertEngine 开始检查...")

    # Load all rules
    await alert_engine.load_all_rules(db)

    # Check all symbols
    results = await alert_engine.check_all_symbols(db)

    # Count the triggered alerts
    total_triggered = sum(len(alerts) for alerts in results.values())

    logger.info(f"📊 AlertEngine 检查完成: 触发 {total_triggered} 条告警")

    return results