You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

357 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# backend/app/services/alert_engine.py
"""
智能告警引擎
支持实时数据 + 规则 → 布尔值判断并行计算100+ 规则/用户
"""
import asyncio
import json
from typing import Dict, List, Optional
from datetime import datetime, time as dt_time
from sqlalchemy.orm import Session
from sqlalchemy import and_
from app.models.alert import AlertRule, AlertHistory, AlertType, AlertOperator
from app.services.cache_service import cache_service
import logging
# Module-level logger, named after this module; all log output below uses it.
logger = logging.getLogger(__name__)
class AlertEngine:
    """
    Alert engine: evaluates cached alert rules against live quote data.

    Responsibilities:
    - Rule caching: enabled rules are kept in memory, grouped by user.
    - Symbol index: rules bound to a symbol are indexed by that symbol;
      rules without a symbol live in ``global_rules`` and are evaluated
      for every symbol that gets checked.
    - Evaluation: quote data + rule condition -> bool.
    - Triggering: active time window, repeat-interval throttling, and one
      ``AlertHistory`` row per trigger.

    Note: ``check_alert`` fans the per-rule checks out with
    ``asyncio.gather``, but a check only awaits ``_create_history``, which
    itself contains no ``await`` (it commits synchronously). The checks
    therefore run one after another on the event loop rather than truly
    in parallel.
    """

    # Supported condition operators. eq/ne compare floats with an absolute
    # tolerance of 1e-4 instead of exact equality.
    _OPERATORS = {
        "gt": lambda a, b: a > b,
        "lt": lambda a, b: a < b,
        "ge": lambda a, b: a >= b,
        "le": lambda a, b: a <= b,
        "eq": lambda a, b: abs(a - b) < 0.0001,
        "ne": lambda a, b: abs(a - b) >= 0.0001,
    }

    def __init__(self):
        # user_id -> enabled rules of that user (every cached rule is here)
        self.rules_cache: Dict[int, List[AlertRule]] = {}
        # symbol -> enabled rules bound to that symbol
        self.symbol_rules: Dict[str, List[AlertRule]] = {}
        # enabled rules with no symbol; applied to every checked symbol
        self.global_rules: List[AlertRule] = []
        # throttle key -> last trigger time. The key is rule.id for
        # symbol-bound rules and (rule.id, symbol) for global rules, so a
        # global rule is throttled independently for each symbol.
        self.trigger_times: Dict[object, datetime] = {}
        # counters reported by get_statistics()
        self.total_checks = 0
        self.total_triggers = 0

    def _index_rule(self, rule: AlertRule) -> None:
        """Add *rule* to the symbol index, or to ``global_rules`` if it has no symbol."""
        if rule.symbol:
            self.symbol_rules.setdefault(rule.symbol, []).append(rule)
        else:
            self.global_rules.append(rule)

    async def load_all_rules(self, db: Session):
        """
        Reload every enabled rule and rebuild all in-memory indexes.

        ``trigger_times`` is deliberately left untouched so repeat-interval
        throttling survives a reload.

        Args:
            db: database session
        """
        rules = db.query(AlertRule).filter(AlertRule.enabled == True).all()  # noqa: E712 - SQLAlchemy expression

        self.rules_cache.clear()
        self.symbol_rules.clear()
        self.global_rules.clear()

        for rule in rules:
            self.rules_cache.setdefault(rule.user_id, []).append(rule)
            self._index_rule(rule)

        logger.info(f"✅ AlertEngine 加载 {len(rules)} 条规则")

    async def load_user_rules(self, db: Session, user_id: int):
        """
        Reload the enabled rules of a single user.

        The user's previously cached rules are first removed from the shared
        symbol index and the global-rule list, so rules that were disabled,
        deleted, or moved to another symbol stop being evaluated.

        Args:
            db: database session
            user_id: ID of the user whose rules are reloaded
        """
        rules = db.query(AlertRule).filter(
            and_(
                AlertRule.user_id == user_id,
                AlertRule.enabled == True,  # noqa: E712 - SQLAlchemy expression
            )
        ).all()

        # Purge this user's stale entries from the shared indexes.
        for symbol in list(self.symbol_rules):
            remaining = [r for r in self.symbol_rules[symbol] if r.user_id != user_id]
            if remaining:
                self.symbol_rules[symbol] = remaining
            else:
                del self.symbol_rules[symbol]
        # Slice-assign to keep the same list object (load_all_rules clears it in place).
        self.global_rules[:] = [r for r in self.global_rules if r.user_id != user_id]

        self.rules_cache[user_id] = rules
        for rule in rules:
            self._index_rule(rule)

        logger.info(f"✅ AlertEngine 加载用户 {user_id}{len(rules)} 条规则")

    def evaluate_condition(self, condition: dict, current_value: float) -> bool:
        """
        Evaluate one rule condition against a value.

        Args:
            condition: condition dict, e.g. {"field": "price", "operator": "gt", "value": 3900}
            current_value: the current value of the condition's field
        Returns:
            bool: True if the condition holds. False for an unknown operator
            or a threshold that cannot be converted to float.
        """
        operator = condition.get("operator")
        op_func = self._OPERATORS.get(operator)
        if not op_func:
            logger.warning(f"⚠️ AlertEngine 未知的操作符: {operator}")
            return False

        # The threshold comes from JSON and may be a numeric string
        # (e.g. "3900") or null; coerce it so the comparison cannot raise.
        raw_threshold = condition.get("value", 0)
        try:
            threshold = float(raw_threshold)
        except (TypeError, ValueError):
            logger.warning(f"⚠️ AlertEngine 无效的阈值: {raw_threshold!r}")
            return False
        return op_func(current_value, threshold)

    def _get_field_value(self, data: dict, field: str, default: Optional[float] = 0.0) -> Optional[float]:
        """
        Read a field from quote data as a float.

        Args:
            data: quote data dict
            field: field name taken from the rule condition (e.g. "price")
            default: returned when the field is missing, null, or not
                numeric. Defaults to 0.0 for backward compatibility; pass
                None to let the caller tell "no data" apart from a real 0.
        Returns:
            The field value as float, or *default*.
        """
        raw = data.get(field)
        if raw is None:
            return default
        try:
            return float(raw)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _in_active_window(rule: AlertRule, current_time: dt_time) -> bool:
        """Return True if *current_time* lies in the rule's start/end window (or it has none)."""
        start, end = rule.start_time, rule.end_time
        if not (start and end):
            return True
        if start <= end:
            return start <= current_time <= end
        # The window wraps past midnight (e.g. 21:00-02:30).
        return current_time >= start or current_time <= end

    @staticmethod
    def _throttle_key(rule: AlertRule, symbol: Optional[str]) -> object:
        """Key into ``trigger_times``: the rule id, plus the symbol for global rules."""
        if rule.symbol or symbol is None:
            return rule.id
        return (rule.id, symbol)

    async def check_alert(self, db: Session, symbol: str, data: dict) -> List[dict]:
        """
        Check all rules that apply to one symbol.

        Args:
            db: database session
            symbol: symbol code
            data: quote data for the symbol
        Returns:
            List[dict]: trigger info of every rule that fired
        """
        # Rules bound to this symbol plus all global rules.
        all_rules = self.symbol_rules.get(symbol, []) + self.global_rules
        if not all_rules:
            return []

        results = await asyncio.gather(
            *(self._check_single_rule(db, rule, data, symbol) for rule in all_rules)
        )
        triggered_alerts = [r for r in results if r is not None]

        self.total_checks += len(all_rules)
        self.total_triggers += len(triggered_alerts)
        return triggered_alerts

    async def _check_single_rule(
        self,
        db: Session,
        rule: AlertRule,
        data: dict,
        symbol: Optional[str] = None,
    ) -> Optional[dict]:
        """
        Check one rule and, if it fires, record the trigger.

        Args:
            db: database session
            rule: the alert rule
            data: quote data
            symbol: the symbol being checked. Used for global rules, which
                carry no symbol of their own.
        Returns:
            Optional[dict]: trigger info if the rule fired, otherwise None.
            Any error is logged and yields None.
        """
        try:
            now = datetime.now()

            # 1. Active time window
            if not self._in_active_window(rule, now.time()):
                return None

            # 2. Repeat-interval throttling (a null interval means no throttling)
            throttle_key = self._throttle_key(rule, symbol)
            repeat_interval = rule.repeat_interval or 0
            if repeat_interval > 0:
                last_triggered = self.trigger_times.get(throttle_key)
                if last_triggered is not None and (now - last_triggered).total_seconds() < repeat_interval:
                    return None

            # 3. Current value of the condition's field
            condition = rule.condition
            if isinstance(condition, str):
                condition = json.loads(condition)
            field = condition.get("field")
            current_value = self._get_field_value(data, field, default=None)
            if current_value is None:
                # The quote lacks this field. Evaluating against a made-up 0
                # would make every "lt"/"le" rule fire, so skip instead.
                return None

            # 4. Evaluate
            if not self.evaluate_condition(condition, current_value):
                return None

            # 5. Build the trigger info
            channels = rule.channels
            if channels is None:
                channels = []
            elif not isinstance(channels, list):
                channels = json.loads(channels)

            trigger_info = {
                "rule": rule,
                "rule_id": rule.id,
                "user_id": rule.user_id,
                # Global rules have no symbol; report the one that fired them.
                "symbol": rule.symbol or symbol,
                "name": rule.name,
                "type": rule.type,
                "trigger_value": current_value,
                "trigger_condition": f"{field} {condition.get('operator')} {condition.get('value')}",
                "trigger_time": now,
                "channels": channels,
            }

            # 6. Remember the trigger time for throttling
            self.trigger_times[throttle_key] = now

            # 7. Persist the history row
            await self._create_history(db, rule, trigger_info)
            return trigger_info
        except Exception as e:
            logger.exception(f"❌ AlertEngine 检查规则 {rule.id} 失败: {e}")
            return None

    async def _create_history(self, db: Session, rule: AlertRule, trigger_info: dict):
        """
        Insert an AlertHistory row and bump the rule's trigger counters.

        Uses the trigger time from *trigger_info*, so the history row, the
        rule's ``last_triggered_at`` and the returned trigger info share one
        timestamp. Commits on success; logs and rolls back on failure.

        Args:
            db: database session
            rule: the rule that fired
            trigger_info: trigger info built by _check_single_rule
        """
        trigger_time = trigger_info.get("trigger_time") or datetime.now()
        try:
            history = AlertHistory(
                rule_id=rule.id,
                user_id=rule.user_id,
                symbol=trigger_info.get("symbol", rule.symbol),
                trigger_value=trigger_info["trigger_value"],
                trigger_condition=trigger_info["trigger_condition"],
                notified=False,
                trigger_time=trigger_time,
            )
            db.add(history)

            # A null counter (e.g. a freshly inserted row) counts as 0.
            rule.trigger_count = (rule.trigger_count or 0) + 1
            rule.last_triggered_at = trigger_time

            db.commit()
        except Exception as e:
            logger.exception(f"❌ AlertEngine 创建历史失败: {e}")
            db.rollback()

    def get_statistics(self) -> dict:
        """Return check/trigger counters and cache sizes."""
        return {
            "total_checks": self.total_checks,
            "total_triggers": self.total_triggers,
            "cached_users": len(self.rules_cache),
            # rules_cache holds every cached rule, including global ones.
            "cached_rules": sum(len(rules) for rules in self.rules_cache.values()),
            "symbol_count": len(self.symbol_rules),
        }

    async def check_all_symbols(self, db: Session) -> Dict[str, List[dict]]:
        """
        Check every indexed symbol against its latest cached quote.

        NOTE: only symbols with at least one symbol-bound rule are visited,
        so global rules are evaluated only for those symbols.

        Args:
            db: database session
        Returns:
            Dict[str, List[dict]]: symbol -> triggered alerts (symbols with
            no triggers are omitted)
        """
        results = {}
        # Iterate over a snapshot: if get_latest_quote yields to the event
        # loop, another task (e.g. load_user_rules) may mutate symbol_rules,
        # which would otherwise raise "dictionary changed size during iteration".
        for symbol in list(self.symbol_rules):
            quote_data = await cache_service.get_latest_quote(symbol)
            if not quote_data:
                continue
            triggered = await self.check_alert(db, symbol, quote_data)
            if triggered:
                results[symbol] = triggered
        return results
# Module-level AlertEngine instance; alert_checker_task() loads rules into it
# and runs the checks through it.
alert_engine: AlertEngine = AlertEngine()
# ============== Scheduled task ==============
async def alert_checker_task(db: Session):
    """
    Run one full alert-evaluation pass.

    Reloads every enabled rule into the shared ``alert_engine``, evaluates
    them against the latest cached quotes, and logs how many alerts fired.
    Meant to be invoked periodically by a scheduler (the original notes say
    once per minute); the scheduling itself happens elsewhere.

    Args:
        db: database session used to load rules and write history rows
    Returns:
        Dict[str, List[dict]]: symbol -> triggered alerts, as returned by
        ``alert_engine.check_all_symbols``
    """
    logger.info("🔄 AlertEngine 开始检查...")

    engine = alert_engine
    await engine.load_all_rules(db)  # refresh the rule caches before checking
    results = await engine.check_all_symbols(db)

    total_triggered = 0
    for alerts in results.values():
        total_triggered += len(alerts)
    logger.info(f"📊 AlertEngine 检查完成: 触发 {total_triggered} 条告警")

    return results