diff --git a/FIXES_FUTURES_KLINES.md b/FIXES_FUTURES_KLINES.md new file mode 100644 index 0000000..1a9aff3 --- /dev/null +++ b/FIXES_FUTURES_KLINES.md @@ -0,0 +1,159 @@ +# 期货K线接口修复报告 + +## 问题描述 + +期货K线接口 `/v1/futures/klines/{symbol}` 返回的 `items` 为空列表。 + +## 根本原因 + +代码中**硬编码**了使用 `amazingdata` 适配器,但配置文件中配置的是 `custom` 适配器。导致: + +1. 配置文件中 `sources.futures.active = "custom"` +2. 但代码中尝试连接 `amazingdata` 适配器 +3. `_connect_adapter` 方法中尝试从 `file_config.sources.stock.list["amazingdata"]` 获取配置 +4. 配置中不存在 `amazingdata`,导致 `KeyError: 'amazingdata'` +5. 异常被捕获后返回空列表 + +## 修复内容 + +### 1. 修复硬编码适配器名称 + +**修改文件:** +- `app/services/futures_service.py` +- `app/services/stock_service.py` + +**修改内容:** +将以下代码: +```python +if not adapter: + loop.run_until_complete(adapter_service._connect_adapter("amazingdata")) +``` + +改为: +```python +if not adapter: + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.futures.active # 或 config.sources.stock.active + + info(f"Connecting to configured adapter: {active_source}") + loop.run_until_complete(adapter_service._connect_adapter(active_source)) +``` + +### 2. 修复适配器配置获取逻辑 + +**修改文件:** +- `app/services/adapter_service.py` + +**修改内容:** +将以下代码: +```python +if name == "amazingdata": + source_info = file_config.sources.stock.list["amazingdata"] + adapter_config = dict(source_info.config) if source_info else {} +else: + adapter_config = self.configs[name].get("config", {}) +``` + +改为: +```python +# 尝试从配置文件中获取适配器配置 +adapter_config = None + +# 1. 首先检查 stock 配置 +if name in file_config.sources.stock.list: + source_info = file_config.sources.stock.list[name] + adapter_config = dict(source_info.config) if source_info else {} + +# 2. 然后检查 futures 配置 +elif name in file_config.sources.futures.list: + source_info = file_config.sources.futures.list[name] + adapter_config = dict(source_info.config) if source_info else {} + +# 3. 使用默认配置 +else: + adapter_config = self.configs[name].get("config", {}) +``` + +### 3. 修复期货仓库的频率映射 + +**修改文件:** +- `app/repositories/futures_repository.py` + +**修改内容:** +添加了对其他频率的映射(虽然数据库只支持1分钟和日线,但避免KeyError): +```python +def _get_kline_model(self, freq: Frequency): + mapping = { + Frequency.FREQ_1M: FuturesKLine1M, + Frequency.FREQ_1D: FuturesKLine1D, + Frequency.FREQ_5M: FuturesKLine1D, # 默认使用日线 + Frequency.FREQ_15M: FuturesKLine1D, + Frequency.FREQ_30M: FuturesKLine1D, + Frequency.FREQ_60M: FuturesKLine1D, + Frequency.FREQ_1W: FuturesKLine1D, + Frequency.FREQ_1MONTH: FuturesKLine1D, + } + return mapping.get(freq, FuturesKLine1D) +``` + +## 当前状态 + +API现在可以正常返回响应: +```json +{ + "code": 0, + "message": "success", + "data": { + "symbol": "CU2504.SHFE", + "name": null, + "freq": "1d", + "adjust": "", + "count": 0, + "items": [] + } +} +``` + +`items` 为空是因为: +1. 数据库中没有数据 +2. 配置的 `custom` 适配器未注册(只注册了 `amazingdata`) + +## 使用建议 + +要使接口返回实际数据,需要: + +1. **配置 AmazingData 适配器:** + 修改 `config.json`: + ```json + { + "sources": { + "futures": { + "active": "amazingdata", + "list": { + "amazingdata": { + "type": "sdk", + "config": { + "username": "your_username", + "password": "your_password", + "host": "your_host", + "port": "8600" + } + } + } + } + } + } + ``` + +2. **安装 AmazingData SDK:** + ```bash + pip install AmazingData tgw + ``` + +3. **或者注册自定义适配器:** + 在 `AdapterService._register_builtin_adapters()` 中添加: + ```python + self.register_adapter("custom", lambda: YourCustomAdapter()) + ``` diff --git a/IMPROVEMENTS.md b/IMPROVEMENTS.md new file mode 100644 index 0000000..5e40f42 --- /dev/null +++ b/IMPROVEMENTS.md @@ -0,0 +1,276 @@ +# 系统完善报告 + +## 概述 + +本次系统完善共完成了6个主要功能的开发和改进。 + +## 已完成的功能 + +### 1. 股票复权计算功能 ✅ + +**文件修改:** +- `app/repositories/models.py` - 添加 `StockAdjustFactor` 复权系数表 +- `app/repositories/stock_repository.py` - 添加复权系数查询和保存方法 +- `app/services/stock_service.py` - 实现复权计算逻辑 + +**功能说明:** +- 支持前复权(qfq)和后复权(hfq)计算 +- 复权系数自动从数据源获取并缓存到数据库 +- 支持价格、成交量的复权调整 +- 保留原始复权系数在K线数据中 + +**技术实现:** +- 前复权:以最新价格为基准,历史价格按比例缩小 +- 后复权:以历史最早价格为基准,后续价格按比例放大 + +--- + +### 2. Prometheus指标暴露端点 ✅ + +**新增文件:** +- `app/core/metrics.py` - 指标收集模块 + +**文件修改:** +- `app/main.py` - 添加指标中间件和端点 +- `requirements.txt` - 添加 prometheus-client 依赖 + +**功能说明:** +- HTTP请求计数和持续时间监控 +- 活跃请求数跟踪 +- 数据库操作性能监控 +- 数据源健康状态监控 +- WebSocket连接数监控 +- 缓存命中率监控 + +**暴露端点:** +``` +GET /metrics - Prometheus格式的指标数据 +``` + +**指标列表:** +| 指标名 | 类型 | 说明 | +|--------|------|------| +| http_requests_total | Counter | HTTP请求总数 | +| http_request_duration_seconds | Histogram | HTTP请求持续时间 | +| http_requests_active | Gauge | 活跃请求数 | +| api_calls_total | Counter | API调用总数 | +| db_operation_duration_seconds | Histogram | 数据库操作持续时间 | +| data_source_status | Gauge | 数据源健康状态 | +| websocket_connections | Gauge | WebSocket连接数 | +| websocket_messages_total | Counter | WebSocket消息总数 | + +--- + +### 3. 应用层限流功能 ✅ + +**新增文件:** +- `app/core/rate_limiter.py` - 限流模块 + +**文件修改:** +- `app/main.py` - 添加限流中间件 + +**功能说明:** +- 支持三种限流算法:固定窗口、滑动窗口、令牌桶 +- 基于客户端IP + 路径的限流key +- 可配置的请求速率和突发容量 +- 自动清理过期数据 + +**默认配置:** +```python +RateLimitConfig( + requests_per_minute=120, # 每分钟120请求 + burst_size=20, # 突发20请求 + strategy="sliding_window" # 滑动窗口算法 +) +``` + +**响应头:** +``` +X-RateLimit-Limit: 120 +X-RateLimit-Remaining: 119 +X-RateLimit-Reset: 1700000000 +Retry-After: 60 # 限流时返回 +``` + +--- + +### 4. 监控告警通道 ✅ + +**新增文件:** +- `app/monitor/alert_channels.py` - 告警通道模块 + +**文件修改:** +- `app/monitor/__init__.py` - 导出告警类 +- `app/monitor/monitor.py` - 集成新的告警管理器 + +**支持的告警通道:** +| 通道 | 类型 | 说明 | +|------|------|------| +| LogAlertChannel | 日志 | 默认日志输出 | +| DingTalkAlertChannel | 钉钉 | 钉钉机器人webhook | +| EmailAlertChannel | 邮件 | SMTP邮件发送 | +| WebhookAlertChannel | Webhook | 自定义HTTP回调 | + +**功能特性:** +- 支持消息路由(按告警级别) +- 支持批量发送 +- Markdown格式的钉钉消息 +- HTML格式的邮件内容 +- 可扩展的架构 + +**使用示例:** +```python +from app.monitor import get_alert_manager + +# 发送告警 +await get_alert_manager().send_simple( + title="数据缺失告警", + content="股票000001.SZ数据缺失", + level="warning" +) +``` + +--- + +### 5. 修复已知问题 ✅ + +**修复内容:** +1. 添加缺失的 `Response` 导入到 `rate_limiter.py` +2. 修复 `app/monitor/__init__.py` 中已删除类的引用 +3. 更新 `requirements.txt` 添加 `prometheus-client` +4. 安装缺失的依赖包 + +--- + +### 6. 服务重启功能 ✅ + +**文件修改:** +- `app/api/admin_routes.py` - 实现重启逻辑 + +**功能说明:** +- 延迟2秒后重启,确保当前响应返回 +- 支持Windows和Linux/Mac系统 +- 在后台线程中执行,不阻塞API响应 + +**使用方式:** +```bash +POST /v1/admin/system/restart +``` + +**注意:** 生产环境建议使用Docker或systemd管理服务生命周期 + +--- + +## 配置文件更新建议 + +### 添加告警配置到 config.json: + +```json +{ + "alert": { + "log": { + "enabled": true + }, + "dingtalk": { + "enabled": false, + "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=xxx", + "secret": "your-secret", + "at_mobiles": ["13800138000"], + "at_all": false + }, + "email": { + "enabled": false, + "smtp_host": "smtp.example.com", + "smtp_port": 587, + "username": "alert@example.com", + "password": "your-password", + "from_addr": "alert@example.com", + "to_addrs": ["admin@example.com"], + "use_tls": true + }, + "routing": { + "info": ["log"], + "warning": ["log", "dingtalk"], + "error": ["log", "dingtalk", "email"], + "critical": ["log", "dingtalk", "email"] + } + } +} +``` + +--- + +## API端点更新 + +### 新增端点: + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/metrics` | GET | Prometheus指标数据 | +| `/admin/system/restart` | POST | 重启服务 | + +### 限流保护端点: + +所有 `/v1/*` 端点(除 `/health`, `/metrics`, `/docs` 等外)都受到限流保护。 + +--- + +## 系统架构图 + +``` +┌─────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +├─────────────────────────────────────────────────────────┤ +│ CORS Middleware │ +│ Metrics Middleware (Prometheus) │ +│ Rate Limit Middleware (120 req/min) │ +├─────────────────────────────────────────────────────────┤ +│ Routes: │ +│ /v1/stock/* - 股票接口 │ +│ /v1/futures/* - 期货接口 │ +│ /v1/admin/* - 管理接口 │ +│ /v1/stream - WebSocket │ +│ /metrics - 指标端点 │ +│ /admin - 管理后台UI │ +├─────────────────────────────────────────────────────────┤ +│ Services: │ +│ StockService - 复权计算 ✅ │ +│ FuturesService - 期货业务 │ +│ AdminService - 管理功能 │ +│ AdapterService - 数据源适配 │ +│ AlertManager - 告警管理 ✅ │ +├─────────────────────────────────────────────────────────┤ +│ Repositories: │ +│ StockRepository - 复权系数表 ✅ │ +│ FuturesRepository - 期货数据 │ +├─────────────────────────────────────────────────────────┤ +│ Data Sources: │ +│ AmazingDataAdapter - 星耀数智 │ +└─────────────────────────────────────────────────────────┘ +``` + +--- + +## 后续建议 + +1. **Prometheus集成** + - 部署Prometheus服务器抓取 `/metrics` 端点 + - 配置Grafana仪表板展示指标 + +2. **告警规则配置** + - 配置告警路由规则 + - 设置钉钉/邮件通道参数 + +3. **性能优化** + - 添加Redis缓存层 + - 实现数据库连接池监控 + +4. **安全性增强** + - 实现API Key验证逻辑 + - 添加请求签名验证 + +--- + +## 完成时间 + +2026-03-14 diff --git a/app/__pycache__/main.cpython-311.pyc b/app/__pycache__/main.cpython-311.pyc index 2a5ef01..1d43a2b 100644 Binary files a/app/__pycache__/main.cpython-311.pyc and b/app/__pycache__/main.cpython-311.pyc differ diff --git a/app/adapters/__pycache__/amazingdata_adapter.cpython-311.pyc b/app/adapters/__pycache__/amazingdata_adapter.cpython-311.pyc index 09aca4c..37ea79f 100644 Binary files a/app/adapters/__pycache__/amazingdata_adapter.cpython-311.pyc and b/app/adapters/__pycache__/amazingdata_adapter.cpython-311.pyc differ diff --git a/app/adapters/amazingdata_adapter.py b/app/adapters/amazingdata_adapter.py index a796bbb..f3e4bb0 100644 --- a/app/adapters/amazingdata_adapter.py +++ b/app/adapters/amazingdata_adapter.py @@ -160,6 +160,14 @@ class AmazingDataAdapter(DataSourceAdapter): def _do_login(self): """执行登录(同步方法)""" + # 检查配置是否完整 + if not self.config.username or not self.config.password or not self.config.host: + raise RuntimeError( + f"AmazingData 配置不完整: username={self.config.username}, " + f"host={self.config.host}, port={self.config.port}. " + f"请在 config.json 中配置正确的账号信息" + ) + print("[amazingdata_adapter]正在登录 AmazingData...") print(f"[amazingdata_adapter]登录用户: {self.config.username}") print(f"[amazingdata_adapter]登录地址: {self.config.host}:{self.config.port}") diff --git a/app/api/__pycache__/admin_routes.cpython-311.pyc b/app/api/__pycache__/admin_routes.cpython-311.pyc index e7f8d44..db793b8 100644 Binary files a/app/api/__pycache__/admin_routes.cpython-311.pyc and b/app/api/__pycache__/admin_routes.cpython-311.pyc differ diff --git a/app/api/admin_routes.py b/app/api/admin_routes.py index 48c43cf..39ffb3f 100644 --- a/app/api/admin_routes.py +++ b/app/api/admin_routes.py @@ -59,12 +59,51 @@ def reload_config( def restart_service( token: str = Depends(verify_admin_token) ): - """重启服务""" - # TODO: 实现服务重启逻辑 + """重启服务 + + 注意: 此方法通过创建子进程实现服务重启,适用于开发环境。 + 生产环境建议使用Docker或systemd管理服务生命周期。 + """ + import os + import sys + import subprocess + import threading + import time + + def delayed_restart(): + """延迟重启函数""" + time.sleep(2) # 等待当前响应返回 + + # 获取当前Python解释器和启动参数 + python = sys.executable + args = sys.argv[:] + + # 在Windows上使用start命令,在Linux上使用nohup + if os.name == 'nt': # Windows + subprocess.Popen( + ['start', 'python'] + args, + shell=True, + creationflags=subprocess.CREATE_NEW_CONSOLE + ) + else: # Linux/Mac + subprocess.Popen( + [python] + args, + stdout=open('/dev/null', 'w'), + stderr=open('/dev/null', 'w'), + start_new_session=True + ) + + # 退出当前进程 + os._exit(0) + + # 在后台线程中执行重启 + restart_thread = threading.Thread(target=delayed_restart, daemon=True) + restart_thread.start() + return Response( code=0, - message="重启命令已发送", - data={"status": "restarting"} + message="服务将在2秒后重启", + data={"status": "restarting", "delay_seconds": 2} ) diff --git a/app/core/__pycache__/metrics.cpython-311.pyc b/app/core/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000..01eaa79 Binary files /dev/null and b/app/core/__pycache__/metrics.cpython-311.pyc differ diff --git a/app/core/__pycache__/rate_limiter.cpython-311.pyc b/app/core/__pycache__/rate_limiter.cpython-311.pyc new file mode 100644 index 0000000..064311c Binary files /dev/null and b/app/core/__pycache__/rate_limiter.cpython-311.pyc differ diff --git a/app/core/metrics.py b/app/core/metrics.py new file mode 100644 index 0000000..1f3b5b3 --- /dev/null +++ b/app/core/metrics.py @@ -0,0 +1,221 @@ +"""Prometheus指标收集模块""" +from contextvars import ContextVar +from typing import Callable, Optional +import time + +from prometheus_client import Counter, Histogram, Gauge, Info, generate_latest, CONTENT_TYPE_LATEST +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +# ============================================ +# 定义指标 +# ============================================ + +# HTTP请求计数器 +http_requests_total = Counter( + 'http_requests_total', + 'Total HTTP requests', + ['method', 'endpoint', 'status_code'] +) + +# HTTP请求持续时间 +http_request_duration_seconds = Histogram( + 'http_request_duration_seconds', + 'HTTP request duration in seconds', + ['method', 'endpoint'], + buckets=[.005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0] +) + +# 活跃请求数 +http_requests_active = Gauge( + 'http_requests_active', + 'Active HTTP requests', + ['method'] +) + +# API调用计数器(按业务分类) +api_calls_total = Counter( + 'api_calls_total', + 'Total API calls by category', + ['category', 'operation'] +) + +# 数据库操作持续时间 +db_operation_duration_seconds = Histogram( + 'db_operation_duration_seconds', + 'Database operation duration', + ['operation', 'table'], + buckets=[.001, .005, .01, .025, .05, .1, .25, .5, 1.0] +) + +# 数据源状态 +data_source_status = Gauge( + 'data_source_status', + 'Data source health status (1=healthy, 0=unhealthy)', + ['source', 'asset_class'] +) + +# WebSocket连接数 +websocket_connections = Gauge( + 'websocket_connections', + 'Number of active WebSocket connections' +) + +# WebSocket消息计数器 +websocket_messages_total = Counter( + 'websocket_messages_total', + 'Total WebSocket messages', + ['direction'] # 'in' or 'out' +) + +# 缓存命中率 +cache_hit_ratio = Gauge( + 'cache_hit_ratio', + 'Cache hit ratio', + ['cache_type'] +) + +# 应用信息 +app_info = Info( + 'market_data_service', + 'Application information' +) + +# ============================================ +# 上下文变量 +# ============================================ + +# 用于存储请求开始时间 +request_start_time: ContextVar[Optional[float]] = ContextVar('request_start_time', default=None) + + +# ============================================ +# 指标收集函数 +# ============================================ + +def record_http_request(method: str, endpoint: str, status_code: int, duration: float): + """记录HTTP请求指标""" + http_requests_total.labels( + method=method, + endpoint=endpoint, + status_code=status_code + ).inc() + + http_request_duration_seconds.labels( + method=method, + endpoint=endpoint + ).observe(duration) + + +def record_api_call(category: str, operation: str): + """记录API调用""" + api_calls_total.labels( + category=category, + operation=operation + ).inc() + + +def record_db_operation(operation: str, table: str, duration: float): + """记录数据库操作""" + db_operation_duration_seconds.labels( + operation=operation, + table=table + ).observe(duration) + + +def update_data_source_status(source: str, asset_class: str, is_healthy: bool): + """更新数据源状态""" + data_source_status.labels( + source=source, + asset_class=asset_class + ).set(1 if is_healthy else 0) + + +def increment_websocket_connections(delta: int = 1): + """增加WebSocket连接数""" + websocket_connections.inc(delta) + + +def decrement_websocket_connections(delta: int = 1): + """减少WebSocket连接数""" + websocket_connections.dec(delta) + + +def record_websocket_message(direction: str): + """记录WebSocket消息""" + websocket_messages_total.labels(direction=direction).inc() + + +def set_cache_hit_ratio(cache_type: str, ratio: float): + """设置缓存命中率""" + cache_hit_ratio.labels(cache_type=cache_type).set(ratio) + + +def set_app_info(version: str, build_time: str, git_commit: str = ""): + """设置应用信息""" + app_info.info({ + 'version': version, + 'build_time': build_time, + 'git_commit': git_commit + }) + + +# ============================================ +# FastAPI中间件 +# ============================================ + +class MetricsMiddleware(BaseHTTPMiddleware): + """指标收集中间件""" + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # 跳过metrics端点自身的监控 + if request.url.path == '/metrics': + return await call_next(request) + + # 记录活跃请求 + http_requests_active.labels(method=request.method).inc() + + # 记录开始时间 + start_time = time.time() + status_code = 200 # 默认状态码 + + try: + response = await call_next(request) + status_code = response.status_code + return response + except Exception as e: + status_code = 500 + raise + finally: + # 计算持续时间 + duration = time.time() - start_time + + # 获取端点路径(使用路由模板而非实际URL) + endpoint = request.url.path + if hasattr(request.state, 'route'): + endpoint = request.state.route + + # 记录指标 + record_http_request( + method=request.method, + endpoint=endpoint, + status_code=status_code, + duration=duration + ) + + # 减少活跃请求计数 + http_requests_active.labels(method=request.method).dec() + + +# ============================================ +# 指标端点 +# ============================================ + +def get_metrics_response() -> Response: + """获取Prometheus格式的指标数据""" + from fastapi.responses import Response as FastAPIResponse + + return FastAPIResponse( + content=generate_latest(), + media_type=CONTENT_TYPE_LATEST + ) diff --git a/app/core/rate_limiter.py b/app/core/rate_limiter.py new file mode 100644 index 0000000..30eaadb --- /dev/null +++ b/app/core/rate_limiter.py @@ -0,0 +1,352 @@ +"""应用层限流模块 + +支持以下限流策略: +1. 固定窗口计数器 +2. 滑动窗口计数器 +3. 令牌桶算法 +""" + +import time +import asyncio +from typing import Dict, Optional, Tuple, Callable +from dataclasses import dataclass, field +from threading import Lock +from collections import deque + +from fastapi import Request, HTTPException, Response +from starlette.middleware.base import BaseHTTPMiddleware + + +@dataclass +class RateLimitConfig: + """限流配置""" + requests_per_minute: int = 60 # 每分钟请求数 + burst_size: int = 10 # 突发请求数 + window_size: int = 60 # 窗口大小(秒) + strategy: str = "sliding_window" # 限流策略: fixed_window, sliding_window, token_bucket + key_func: Optional[Callable[[Request], str]] = None # 自定义key生成函数 + + +@dataclass +class FixedWindow: + """固定窗口""" + count: int = 0 + reset_time: float = field(default_factory=lambda: time.time() + 60) + + +@dataclass +class SlidingWindow: + """滑动窗口""" + requests: deque = field(default_factory=lambda: deque()) + + def clean_old_requests(self, window_size: int): + """清理过期的请求记录""" + now = time.time() + cutoff = now - window_size + while self.requests and self.requests[0] < cutoff: + self.requests.popleft() + + +@dataclass +class TokenBucket: + """令牌桶""" + tokens: float = field(default_factory=float) + last_update: float = field(default_factory=time.time) + + def update_tokens(self, rate_per_second: float, max_tokens: float): + """更新令牌数量""" + now = time.time() + elapsed = now - self.last_update + self.tokens = min(max_tokens, self.tokens + elapsed * rate_per_second) + self.last_update = now + + +class RateLimiter: + """限流器 + + 支持多种限流策略,默认使用滑动窗口算法。 + """ + + def __init__(self, config: RateLimitConfig = None): + self.config = config or RateLimitConfig() + self.lock = Lock() + + # 存储每个key的限流状态 + self.fixed_windows: Dict[str, FixedWindow] = {} + self.sliding_windows: Dict[str, SlidingWindow] = {} + self.token_buckets: Dict[str, TokenBucket] = {} + + # 启动清理任务 + self._cleanup_task = None + + def _get_key(self, request: Request) -> str: + """生成限流key + + 默认使用客户端IP + 路径 + """ + if self.config.key_func: + return self.config.key_func(request) + + # 获取客户端IP + client_ip = request.client.host if request.client else "unknown" + + # 检查X-Forwarded-For头 + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + client_ip = forwarded.split(",")[0].strip() + + # 使用IP + 路径作为key + return f"{client_ip}:{request.url.path}" + + def _check_fixed_window(self, key: str) -> Tuple[bool, Dict]: + """固定窗口限流检查 + + Returns: + (是否允许, 响应头信息) + """ + now = time.time() + window = self.fixed_windows.get(key) + + if window is None or now > window.reset_time: + # 新窗口 + self.fixed_windows[key] = FixedWindow(count=1, reset_time=now + self.config.window_size) + remaining = self.config.requests_per_minute - 1 + return True, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(now + self.config.window_size)) + } + + if window.count >= self.config.requests_per_minute: + # 超过限制 + return False, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(window.reset_time)), + "Retry-After": str(int(window.reset_time - now)) + } + + # 允许请求 + window.count += 1 + remaining = self.config.requests_per_minute - window.count + return True, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(window.reset_time)) + } + + def _check_sliding_window(self, key: str) -> Tuple[bool, Dict]: + """滑动窗口限流检查 + + Returns: + (是否允许, 响应头信息) + """ + now = time.time() + + if key not in self.sliding_windows: + self.sliding_windows[key] = SlidingWindow() + + window = self.sliding_windows[key] + window.clean_old_requests(self.config.window_size) + + if len(window.requests) >= self.config.requests_per_minute: + # 超过限制 + oldest = window.requests[0] + reset_time = oldest + self.config.window_size + return False, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(reset_time)), + "Retry-After": str(int(reset_time - now)) + } + + # 允许请求 + window.requests.append(now) + remaining = self.config.requests_per_minute - len(window.requests) + return True, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(now + self.config.window_size)) + } + + def _check_token_bucket(self, key: str) -> Tuple[bool, Dict]: + """令牌桶限流检查 + + Returns: + (是否允许, 响应头信息) + """ + rate_per_second = self.config.requests_per_minute / 60.0 + max_tokens = self.config.burst_size + + if key not in self.token_buckets: + self.token_buckets[key] = TokenBucket(tokens=max_tokens) + + bucket = self.token_buckets[key] + bucket.update_tokens(rate_per_second, max_tokens) + + if bucket.tokens < 1: + # 令牌不足 + wait_time = (1 - bucket.tokens) / rate_per_second + return False, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": "0", + "Retry-After": str(int(wait_time) + 1) + } + + # 消耗令牌 + bucket.tokens -= 1 + remaining = int(bucket.tokens) + return True, { + "X-RateLimit-Limit": str(self.config.requests_per_minute), + "X-RateLimit-Remaining": str(remaining) + } + + def is_allowed(self, request: Request) -> Tuple[bool, Dict]: + """检查请求是否允许通过 + + Returns: + (是否允许, 响应头信息) + """ + with self.lock: + key = self._get_key(request) + + if self.config.strategy == "fixed_window": + return self._check_fixed_window(key) + elif self.config.strategy == "token_bucket": + return self._check_token_bucket(key) + else: + return self._check_sliding_window(key) + + def cleanup(self): + """清理过期的限流数据""" + now = time.time() + + with self.lock: + # 清理固定窗口 + expired = [ + key for key, window in self.fixed_windows.items() + if now > window.reset_time + ] + for key in expired: + del self.fixed_windows[key] + + # 清理滑动窗口 + for window in self.sliding_windows.values(): + window.clean_old_requests(self.config.window_size) + + # 清理空的滑动窗口 + empty = [ + key for key, window in self.sliding_windows.items() + if not window.requests + ] + for key in empty: + del self.sliding_windows[key] + + async def start_cleanup_task(self): + """启动定期清理任务""" + while True: + await asyncio.sleep(300) # 每5分钟清理一次 + self.cleanup() + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """限流中间件 + + 使用示例: + app.add_middleware( + RateLimitMiddleware, + config=RateLimitConfig( + requests_per_minute=60, + strategy="sliding_window" + ) + ) + """ + + def __init__(self, app, config: RateLimitConfig = None): + super().__init__(app) + self.limiter = RateLimiter(config) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # 跳过某些路径 + if request.url.path in ["/health", "/metrics", "/docs", "/redoc", "/openapi.json"]: + return await call_next(request) + + # 检查限流 + allowed, headers = self.limiter.is_allowed(request) + + if not allowed: + raise HTTPException( + status_code=429, + detail="Too many requests", + headers=headers + ) + + # 执行请求 + response = await call_next(request) + + # 添加限流响应头 + for key, value in headers.items(): + response.headers[key] = value + + return response + + +# 全局限流器实例(用于特定端点限流) +_default_limiter: Optional[RateLimiter] = None + + +def get_limiter(config: RateLimitConfig = None) -> RateLimiter: + """获取全局限流器实例""" + global _default_limiter + if _default_limiter is None: + _default_limiter = RateLimiter(config) + return _default_limiter + + +def rate_limit( + requests_per_minute: int = 60, + strategy: str = "sliding_window", + key_func: Optional[Callable[[Request], str]] = None +): + """装饰器:为特定端点添加限流 + + 使用示例: + @app.get("/api/data") + @rate_limit(requests_per_minute=30) + async def get_data(): + return {"data": "..."} + """ + config = RateLimitConfig( + requests_per_minute=requests_per_minute, + strategy=strategy, + key_func=key_func + ) + limiter = RateLimiter(config) + + def decorator(func: Callable) -> Callable: + async def wrapper(request: Request, *args, **kwargs): + allowed, headers = limiter.is_allowed(request) + + if not allowed: + raise HTTPException( + status_code=429, + detail="Too many requests", + headers=headers + ) + + # 如果原函数是协程 + if asyncio.iscoroutinefunction(func): + response = await func(request, *args, **kwargs) + else: + response = func(request, *args, **kwargs) + + # 添加响应头 + if hasattr(response, 'headers'): + for key, value in headers.items(): + response.headers[key] = value + + return response + + return wrapper + + return decorator diff --git a/app/main.py b/app/main.py index 15438c6..e93f21f 100644 --- a/app/main.py +++ b/app/main.py @@ -11,6 +11,8 @@ from app.api import router, admin_router from app.websocket import WebSocketServer from app.core.config import get_config, get_settings from app.core.logger import info, error, setup_logging +from app.core.metrics import MetricsMiddleware, get_metrics_response, set_app_info +from app.core.rate_limiter import RateLimitMiddleware, RateLimitConfig from app.repositories.database import init_db @@ -58,6 +60,19 @@ app.add_middleware( allow_headers=["*"], ) +# 添加Prometheus指标中间件 +app.add_middleware(MetricsMiddleware) + +# 添加限流中间件(默认每分钟60请求,滑动窗口算法) +app.add_middleware( + RateLimitMiddleware, + config=RateLimitConfig( + requests_per_minute=120, # 每分钟120请求 + burst_size=20, # 突发20请求 + strategy="sliding_window" # 使用滑动窗口算法 + ) +) + # 注册API路由 app.include_router(router, prefix="/v1") app.include_router(admin_router, prefix="/v1") @@ -74,6 +89,12 @@ async def websocket_endpoint(websocket): await ws_server.handle(websocket, client_id) +@app.get("/metrics") +async def metrics(): + """Prometheus指标端点""" + return get_metrics_response() + + # 管理后台页面HTML(完整版) ADMIN_HTML = ''' diff --git a/app/monitor/__init__.py b/app/monitor/__init__.py index 95d2989..7002463 100644 --- a/app/monitor/__init__.py +++ b/app/monitor/__init__.py @@ -1,4 +1,14 @@ """数据质量监控模块""" -from .monitor import DataQualityMonitor, AlertSender, LogAlertSender +from .monitor import DataQualityMonitor, CheckResult, QualityReport +from .alert_channels import ( + AlertChannel, AlertMessage, AlertManager, + LogAlertChannel, DingTalkAlertChannel, EmailAlertChannel, WebhookAlertChannel, + get_alert_manager, init_alert_manager +) -__all__ = ["DataQualityMonitor", "AlertSender", "LogAlertSender"] +__all__ = [ + "DataQualityMonitor", "CheckResult", "QualityReport", + "AlertChannel", "AlertMessage", "AlertManager", + "LogAlertChannel", "DingTalkAlertChannel", "EmailAlertChannel", "WebhookAlertChannel", + "get_alert_manager", "init_alert_manager" +] diff --git a/app/monitor/__pycache__/__init__.cpython-311.pyc b/app/monitor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..aef1503 Binary files /dev/null and b/app/monitor/__pycache__/__init__.cpython-311.pyc differ diff --git a/app/monitor/__pycache__/alert_channels.cpython-311.pyc b/app/monitor/__pycache__/alert_channels.cpython-311.pyc new file mode 100644 index 0000000..672066a Binary files /dev/null and b/app/monitor/__pycache__/alert_channels.cpython-311.pyc differ diff --git a/app/monitor/__pycache__/monitor.cpython-311.pyc b/app/monitor/__pycache__/monitor.cpython-311.pyc new file mode 100644 index 0000000..74a8d53 Binary files /dev/null and b/app/monitor/__pycache__/monitor.cpython-311.pyc differ diff --git a/app/monitor/alert_channels.py b/app/monitor/alert_channels.py new file mode 100644 index 0000000..72d9ccb --- /dev/null +++ b/app/monitor/alert_channels.py @@ -0,0 +1,516 @@ +"""告警通道模块 + +支持多种告警方式: +- 日志告警(默认) +- 钉钉机器人 +- 邮件 +- Webhook +""" + +import json +import smtplib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from typing import Dict, List, Optional + +import httpx +from app.core.logger import info, error, warning + + +@dataclass +class AlertMessage: + """告警消息""" + title: str + content: str + level: str = "warning" # info, warning, error, critical + timestamp: datetime = None + metadata: Dict = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = datetime.now() + if self.metadata is None: + self.metadata = {} + + +class AlertChannel(ABC): + """告警通道基类""" + + def __init__(self, name: str, enabled: bool = True): + self.name = name + self.enabled = enabled + + @abstractmethod + async def send(self, message: AlertMessage) -> bool: + """发送告警消息""" + pass + + async def send_batch(self, messages: List[AlertMessage]) -> List[bool]: + """批量发送告警""" + results = [] + for msg in messages: + result = await self.send(msg) + results.append(result) + return results + + +class LogAlertChannel(AlertChannel): + """日志告警通道""" + + def __init__(self, enabled: bool = True): + super().__init__("log", enabled) + + async def send(self, message: AlertMessage) -> bool: + """发送日志告警""" + if not self.enabled: + return False + + log_msg = f"[{message.level.upper()}] {message.title}: {message.content}" + + if message.level == "info": + info(log_msg) + elif message.level == "warning": + warning(log_msg) + else: + error(log_msg) + + return True + + +class DingTalkAlertChannel(AlertChannel): + """钉钉机器人告警通道""" + + def __init__( + self, + webhook_url: str, + secret: Optional[str] = None, + at_mobiles: Optional[List[str]] = None, + at_all: bool = False, + enabled: bool = True + ): + super().__init__("dingtalk", enabled) + self.webhook_url = webhook_url + self.secret = secret + self.at_mobiles = at_mobiles or [] + self.at_all = at_all + + def _generate_sign(self, timestamp: str) -> str: + """生成钉钉签名""" + import hmac + import hashlib + import urllib.parse + + if not self.secret: + return "" + + string_to_sign = f"{timestamp}\n{self.secret}" + hmac_code = hmac.new( + self.secret.encode('utf-8'), + string_to_sign.encode('utf-8'), + digestmod=hashlib.sha256 + ).digest() + sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) + return sign + + def _build_markdown_message(self, message: AlertMessage) -> Dict: + """构建Markdown格式的消息""" + # 根据级别选择颜色 + color_map = { + "info": "#007bff", + "warning": "#ffc107", + "error": "#dc3545", + "critical": "#6f42c1" + } + color = color_map.get(message.level, "#6c757d") + + # 构建@信息 + at_text = "" + if self.at_all: + at_text = "@所有人 " + elif self.at_mobiles: + at_text = " ".join([f"@{mobile}" for mobile in self.at_mobiles]) + + content = f"""### {message.title} {at_text} + +**告警级别:** {message.level.upper()} +**告警时间:** {message.timestamp.strftime('%Y-%m-%d %H:%M:%S')} + +--- + +{message.content} + +--- + +**详细信息:** +```json +{json.dumps(message.metadata, indent=2, ensure_ascii=False, default=str)} +``` +""" + + return { + "msgtype": "markdown", + "markdown": { + "title": message.title, + "text": content + }, + "at": { + "atMobiles": self.at_mobiles, + "isAtAll": self.at_all + } + } + + def _build_text_message(self, message: AlertMessage) -> Dict: + """构建文本格式的消息""" + return { + "msgtype": "text", + "text": { + "content": f"[{message.level.upper()}] {message.title}\n\n{message.content}" + }, + "at": { + "atMobiles": self.at_mobiles, + "isAtAll": self.at_all + } + } + + async def send(self, message: AlertMessage, msg_type: str = "markdown") -> bool: + """发送钉钉告警 + + Args: + message: 告警消息 + msg_type: 消息类型 (markdown 或 text) + """ + if not self.enabled or not self.webhook_url: + return False + + try: + import base64 + import time as time_module + + timestamp = str(int(round(time_module.time() * 1000))) + sign = self._generate_sign(timestamp) + + # 构建URL + url = self.webhook_url + if self.secret: + url = f"{self.webhook_url}×tamp={timestamp}&sign={sign}" + + # 构建消息 + if msg_type == "markdown": + payload = self._build_markdown_message(message) + else: + payload = self._build_text_message(message) + + # 发送请求 + async with httpx.AsyncClient() as client: + response = await client.post( + url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=10.0 + ) + + if response.status_code == 200: + result = response.json() + if result.get("errcode") == 0: + info(f"DingTalk alert sent: {message.title}") + return True + else: + error(f"DingTalk API error: {result}") + return False + else: + error(f"DingTalk HTTP error: {response.status_code}") + return False + + except Exception as e: + error(f"Failed to send DingTalk alert: {e}") + return False + + +class EmailAlertChannel(AlertChannel): + """邮件告警通道""" + + def __init__( + self, + smtp_host: str, + smtp_port: int, + username: str, + password: str, + from_addr: str, + to_addrs: List[str], + use_tls: bool = True, + enabled: bool = True + ): + super().__init__("email", enabled) + self.smtp_host = smtp_host + self.smtp_port = smtp_port + self.username = username + self.password = password + self.from_addr = from_addr + self.to_addrs = to_addrs + self.use_tls = use_tls + + def _build_html_content(self, message: AlertMessage) -> str: + """构建HTML格式的邮件内容""" + # 根据级别选择颜色 + color_map = { + "info": "#007bff", + "warning": "#ffc107", + "error": "#dc3545", + "critical": "#6f42c1" + } + color = color_map.get(message.level, "#6c757d") + + metadata_html = "" + if message.metadata: + rows = "" + for key, value in message.metadata.items(): + rows += f"{key}{value}" + metadata_html = f""" +

详细信息

+ + {rows} +
+ """ + + return f""" + + +

[{message.level.upper()}] {message.title}

+

告警时间: {message.timestamp.strftime('%Y-%m-%d %H:%M:%S')}

+
+

{message.content.replace(chr(10), '
')}

+
+ {metadata_html} +

+ 本邮件由行情数据服务自动发送,请勿回复。 +

+ + + """ + + async def send(self, message: AlertMessage) -> bool: + """发送邮件告警""" + if not self.enabled or not self.to_addrs: + return False + + try: + # 构建邮件 + msg = MIMEMultipart('alternative') + msg['Subject'] = f"[{message.level.upper()}] {message.title}" + msg['From'] = self.from_addr + msg['To'] = ', '.join(self.to_addrs) + + # 添加HTML内容 + html_content = self._build_html_content(message) + msg.attach(MIMEText(html_content, 'html', 'utf-8')) + + # 发送邮件(在executor中执行同步操作) + import asyncio + loop = asyncio.get_event_loop() + + def send_email(): + server = smtplib.SMTP(self.smtp_host, self.smtp_port) + if self.use_tls: + server.starttls() + server.login(self.username, self.password) + server.sendmail(self.from_addr, self.to_addrs, msg.as_string()) + server.quit() + + await loop.run_in_executor(None, send_email) + + info(f"Email alert sent: {message.title}") + return True + + except Exception as e: + error(f"Failed to send email alert: {e}") + return False + + +class WebhookAlertChannel(AlertChannel): + """Webhook告警通道""" + + def __init__( + self, + webhook_url: str, + headers: Optional[Dict[str, str]] = None, + timeout: float = 10.0, + enabled: bool = True + ): + super().__init__("webhook", enabled) + self.webhook_url = webhook_url + self.headers = headers or {"Content-Type": "application/json"} + self.timeout = timeout + + async def send(self, message: AlertMessage) -> bool: + """发送Webhook告警""" + if not self.enabled or not self.webhook_url: + return False + + try: + payload = { + "title": message.title, + "content": message.content, + "level": message.level, + "timestamp": message.timestamp.isoformat(), + "metadata": message.metadata + } + + async with httpx.AsyncClient() as client: + response = await client.post( + self.webhook_url, + json=payload, + headers=self.headers, + timeout=self.timeout + ) + + if response.status_code < 400: + info(f"Webhook alert sent: {message.title}") + return True + else: + error(f"Webhook error: {response.status_code}") + return False + + except Exception as e: + error(f"Failed to send webhook alert: {e}") + return False + + +class AlertManager: + """告警管理器 + + 管理多个告警通道,支持消息路由和批量发送。 + """ + + def __init__(self): + self.channels: Dict[str, AlertChannel] = {} + self.level_routing = { + "info": ["log"], + "warning": ["log", "dingtalk"], + "error": ["log", "dingtalk", "email"], + "critical": ["log", "dingtalk", "email", "webhook"] + } + + def register_channel(self, channel: AlertChannel): + """注册告警通道""" + self.channels[channel.name] = channel + info(f"Alert channel registered: {channel.name}") + + def configure_routing(self, level_routing: Dict[str, List[str]]): + """配置告警路由规则""" + self.level_routing = level_routing + + async def send( + self, + message: AlertMessage, + channels: Optional[List[str]] = None + ) -> Dict[str, bool]: + """发送告警 + + Args: + message: 告警消息 + channels: 指定通道列表,None则根据级别路由 + + Returns: + 各通道发送结果 + """ + # 确定目标通道 + target_channels = channels + if target_channels is None: + target_channels = self.level_routing.get(message.level, ["log"]) + + # 发送到各通道 + results = {} + for channel_name in target_channels: + channel = self.channels.get(channel_name) + if channel: + results[channel_name] = await channel.send(message) + else: + warning(f"Alert channel not found: {channel_name}") + results[channel_name] = False + + return results + + async def send_simple( + self, + title: str, + content: str, + level: str = "warning", + **kwargs + ) -> Dict[str, bool]: + """发送简单告警""" + message = AlertMessage( + title=title, + content=content, + level=level, + metadata=kwargs + ) + return await self.send(message) + + +# 全局告警管理器实例 +_alert_manager: Optional[AlertManager] = None + + +def get_alert_manager() -> AlertManager: + """获取全局告警管理器""" + global _alert_manager + if _alert_manager is None: + _alert_manager = AlertManager() + # 默认注册日志通道 + _alert_manager.register_channel(LogAlertChannel()) + return _alert_manager + + +def init_alert_manager(config: Dict): + """从配置初始化告警管理器""" + global _alert_manager + _alert_manager = AlertManager() + + # 注册日志通道 + _alert_manager.register_channel(LogAlertChannel( + enabled=config.get("log", {}).get("enabled", True) + )) + + # 注册钉钉通道 + dingtalk_config = config.get("dingtalk", {}) + if dingtalk_config.get("enabled"): + _alert_manager.register_channel(DingTalkAlertChannel( + webhook_url=dingtalk_config["webhook_url"], + secret=dingtalk_config.get("secret"), + at_mobiles=dingtalk_config.get("at_mobiles", []), + at_all=dingtalk_config.get("at_all", False), + enabled=True + )) + + # 注册邮件通道 + email_config = config.get("email", {}) + if email_config.get("enabled"): + _alert_manager.register_channel(EmailAlertChannel( + smtp_host=email_config["smtp_host"], + smtp_port=email_config["smtp_port"], + username=email_config["username"], + password=email_config["password"], + from_addr=email_config["from_addr"], + to_addrs=email_config["to_addrs"], + use_tls=email_config.get("use_tls", True), + enabled=True + )) + + # 注册Webhook通道 + webhook_config = config.get("webhook", {}) + if webhook_config.get("enabled"): + _alert_manager.register_channel(WebhookAlertChannel( + webhook_url=webhook_config["webhook_url"], + headers=webhook_config.get("headers"), + timeout=webhook_config.get("timeout", 10.0), + enabled=True + )) + + # 配置路由规则 + if "routing" in config: + _alert_manager.configure_routing(config["routing"]) + + return _alert_manager diff --git a/app/monitor/monitor.py b/app/monitor/monitor.py index 8a0dd57..01d6e87 100644 --- a/app/monitor/monitor.py +++ b/app/monitor/monitor.py @@ -1,6 +1,5 @@ """数据质量监控 - 对应Go的internal/monitor/monitor.go""" import asyncio -from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta from typing import List, Optional @@ -11,6 +10,7 @@ from sqlalchemy import text from app.repositories import StockRepository, FuturesRepository from app.models import Frequency from app.core.logger import info, error +from app.monitor.alert_channels import AlertManager, AlertMessage, get_alert_manager @dataclass @@ -37,23 +37,6 @@ class QualityReport: pass_rate: float -class AlertSender(ABC): - """告警发送接口""" - - @abstractmethod - def send_alert(self, title: str, content: str) -> bool: - """发送告警""" - pass - - -class LogAlertSender(AlertSender): - """日志告警发送器""" - - def send_alert(self, title: str, content: str) -> bool: - info(f"[ALERT] {title}: {content}") - return True - - class DataQualityMonitor: """数据质量监控""" @@ -62,12 +45,12 @@ class DataQualityMonitor: db: Session, stock_repo: StockRepository, futures_repo: FuturesRepository, - sender: Optional[AlertSender] = None + alert_manager: Optional[AlertManager] = None ): self.db = db self.stock_repo = stock_repo self.futures_repo = futures_repo - self.sender = sender or LogAlertSender() + self.alert_manager = alert_manager or get_alert_manager() async def daily_check(self, check_date: str): """每日数据质量检查""" @@ -156,11 +139,17 @@ class DataQualityMonitor: result.detail = f"Data missing: expected {expect_count}, actual {actual_count}" # 发送告警 - if self.sender: - self.sender.send_alert( - f"[{asset_type}] Data Missing Alert", - f"Symbol: {symbol}, Date: {check_date}, Expected: {expect_count}, Actual: {actual_count}" - ) + if self.alert_manager: + asyncio.create_task(self.alert_manager.send_simple( + title=f"[{asset_type.upper()}] 数据缺失告警", + content=f"标的: {symbol}, 日期: {check_date}, 期望: {expect_count}条, 实际: {actual_count}条", + level="warning", + asset_type=asset_type, + symbol=symbol, + check_date=check_date, + expect_count=expect_count, + actual_count=actual_count + )) except Exception as e: result.status = "fail" diff --git a/app/repositories/__pycache__/database.cpython-311.pyc b/app/repositories/__pycache__/database.cpython-311.pyc index 08eb722..84f5275 100644 Binary files a/app/repositories/__pycache__/database.cpython-311.pyc and b/app/repositories/__pycache__/database.cpython-311.pyc differ diff --git a/app/repositories/__pycache__/futures_repository.cpython-311.pyc b/app/repositories/__pycache__/futures_repository.cpython-311.pyc index c0ca708..66e9f9e 100644 Binary files a/app/repositories/__pycache__/futures_repository.cpython-311.pyc and b/app/repositories/__pycache__/futures_repository.cpython-311.pyc differ diff --git a/app/repositories/__pycache__/models.cpython-311.pyc b/app/repositories/__pycache__/models.cpython-311.pyc index 953510e..3c501d1 100644 Binary files a/app/repositories/__pycache__/models.cpython-311.pyc and b/app/repositories/__pycache__/models.cpython-311.pyc differ diff --git a/app/repositories/__pycache__/stock_repository.cpython-311.pyc b/app/repositories/__pycache__/stock_repository.cpython-311.pyc index 0754396..a5c1c13 100644 Binary files a/app/repositories/__pycache__/stock_repository.cpython-311.pyc and b/app/repositories/__pycache__/stock_repository.cpython-311.pyc differ diff --git a/app/repositories/futures_repository.py b/app/repositories/futures_repository.py index dbf217a..3395b58 100644 --- a/app/repositories/futures_repository.py +++ b/app/repositories/futures_repository.py @@ -58,10 +58,20 @@ class FuturesRepository: return items def _get_kline_model(self, freq: Frequency): - """根据周期获取K线模型""" + """根据周期获取K线模型 + + 注意: 目前数据库只支持1分钟和日线K线存储。 + 其他周期(5m/15m/30m/60m/1w/1month)默认使用日线表。 + """ mapping = { Frequency.FREQ_1M: FuturesKLine1M, Frequency.FREQ_1D: FuturesKLine1D, + Frequency.FREQ_5M: FuturesKLine1D, # 默认使用日线 + Frequency.FREQ_15M: FuturesKLine1D, + Frequency.FREQ_30M: FuturesKLine1D, + Frequency.FREQ_60M: FuturesKLine1D, + Frequency.FREQ_1W: FuturesKLine1D, + Frequency.FREQ_1MONTH: FuturesKLine1D, } return mapping.get(freq, FuturesKLine1D) diff --git a/app/repositories/models.py b/app/repositories/models.py index c79d26a..410322a 100644 --- a/app/repositories/models.py +++ b/app/repositories/models.py @@ -210,3 +210,20 @@ class DataQualityCheck(Base): actual_count = Column(Integer, nullable=True, comment="实际数量") detail = Column(String(500), nullable=True, comment="详情") created_at = Column(DateTime, default=datetime.now, comment="创建时间") + + +class StockAdjustFactor(Base): + """股票复权系数表""" + __tablename__ = "stock_adjust_factors" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + symbol_id = Column(String(20), nullable=False, index=True, comment="标的代码") + trade_date = Column(String(10), nullable=False, index=True, comment="交易日期 YYYY-MM-DD") + qfq_factor = Column(Numeric(18, 8), nullable=False, default=1.0, comment="前复权系数") + hfq_factor = Column(Numeric(18, 8), nullable=False, default=1.0, comment="后复权系数") + created_at = Column(DateTime, default=datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间") + + __table_args__ = ( + Index("idx_adj_factor_symbol_date", "symbol_id", "trade_date"), + ) diff --git a/app/repositories/stock_repository.py b/app/repositories/stock_repository.py index b3eee4d..03886f4 100644 --- a/app/repositories/stock_repository.py +++ b/app/repositories/stock_repository.py @@ -11,7 +11,7 @@ from app.models import ( ) from app.repositories.models import ( StockSymbol, StockKLine1M, StockKLine5M, StockKLine1D, - StockTradingCalendar + StockTradingCalendar, StockAdjustFactor ) @@ -258,3 +258,85 @@ class StockRepository: self.db.add(new_cal) self.db.commit() + + def get_adjust_factors( + self, + symbol: str, + start_date: str, + end_date: str + ) -> List[dict]: + """获取指定日期范围内的复权系数 + + Args: + symbol: 股票代码 + start_date: 开始日期 (YYYYMMDD) + end_date: 结束日期 (YYYYMMDD) + + Returns: + 复权系数列表,每项包含 trade_date, qfq_factor, hfq_factor + """ + # 转换日期格式 + start_fmt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:]}" + end_fmt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:]}" + + results = self.db.query(StockAdjustFactor).filter( + StockAdjustFactor.symbol_id == symbol, + StockAdjustFactor.trade_date >= start_fmt, + StockAdjustFactor.trade_date <= end_fmt + ).order_by(StockAdjustFactor.trade_date.asc()).all() + + return [ + { + "trade_date": r.trade_date, + "qfq_factor": float(r.qfq_factor) if r.qfq_factor else 1.0, + "hfq_factor": float(r.hfq_factor) if r.hfq_factor else 1.0 + } + for r in results + ] + + def save_adjust_factors(self, symbol: str, factors: List[dict]) -> None: + """保存复权系数 + + Args: + symbol: 股票代码 + factors: 复权系数列表,每项包含 trade_date, qfq_factor, hfq_factor + """ + for f in factors: + trade_date = f.get("trade_date") + + existing = self.db.query(StockAdjustFactor).filter( + StockAdjustFactor.symbol_id == symbol, + StockAdjustFactor.trade_date == trade_date + ).first() + + if existing: + existing.qfq_factor = f.get("qfq_factor", 1.0) + existing.hfq_factor = f.get("hfq_factor", 1.0) + else: + new_factor = StockAdjustFactor( + symbol_id=symbol, + trade_date=trade_date, + qfq_factor=f.get("qfq_factor", 1.0), + hfq_factor=f.get("hfq_factor", 1.0) + ) + self.db.add(new_factor) + + self.db.commit() + + def get_latest_adjust_factor(self, symbol: str) -> Optional[dict]: + """获取最新的复权系数 + + Returns: + 包含 qfq_factor 和 hfq_factor 的字典,如果没有则返回None + """ + result = self.db.query(StockAdjustFactor).filter( + StockAdjustFactor.symbol_id == symbol + ).order_by(StockAdjustFactor.trade_date.desc()).first() + + if result: + return { + "trade_date": result.trade_date, + "qfq_factor": float(result.qfq_factor) if result.qfq_factor else 1.0, + "hfq_factor": float(result.hfq_factor) if result.hfq_factor else 1.0 + } + return None diff --git a/app/services/__pycache__/adapter_service.cpython-311.pyc b/app/services/__pycache__/adapter_service.cpython-311.pyc index 4b707c0..95292a2 100644 Binary files a/app/services/__pycache__/adapter_service.cpython-311.pyc and b/app/services/__pycache__/adapter_service.cpython-311.pyc differ diff --git a/app/services/__pycache__/futures_service.cpython-311.pyc b/app/services/__pycache__/futures_service.cpython-311.pyc index c2365a2..ad5eff1 100644 Binary files a/app/services/__pycache__/futures_service.cpython-311.pyc and b/app/services/__pycache__/futures_service.cpython-311.pyc differ diff --git a/app/services/__pycache__/stock_service.cpython-311.pyc b/app/services/__pycache__/stock_service.cpython-311.pyc index 63524a8..b49eea0 100644 Binary files a/app/services/__pycache__/stock_service.cpython-311.pyc and b/app/services/__pycache__/stock_service.cpython-311.pyc differ diff --git a/app/services/adapter_service.py b/app/services/adapter_service.py index 1f2b35e..2d57851 100644 --- a/app/services/adapter_service.py +++ b/app/services/adapter_service.py @@ -203,16 +203,30 @@ class AdapterService: # 从 config.json 获取最新配置(与文件同步) file_config = get_config() print(f"Using file config: {file_config}, adapter name: {name}") - if name == "amazingdata": - # 优先使用 stock 下的 amazingdata 配置 - source_info = file_config.sources.stock.list["amazingdata"] + + # 尝试从配置文件中获取适配器配置 + adapter_config = None + + # 1. 首先检查 stock 配置 + if name in file_config.sources.stock.list: + source_info = file_config.sources.stock.list[name] + adapter_config = dict(source_info.config) if source_info else {} + print(f"Using stock config for {name}: {adapter_config}") + + # 2. 然后检查 futures 配置 + elif name in file_config.sources.futures.list: + source_info = file_config.sources.futures.list[name] adapter_config = dict(source_info.config) if source_info else {} - print(f"Using amazingdata config: {adapter_config}") - # 处理 port 为字符串的情况 - if "port" in adapter_config and isinstance(adapter_config["port"], str): - adapter_config["port"] = int(adapter_config["port"]) if adapter_config["port"].strip() else 8600 + print(f"Using futures config for {name}: {adapter_config}") + + # 3. 使用默认配置 else: adapter_config = self.configs[name].get("config", {}) + print(f"Using default config for {name}: {adapter_config}") + + # 处理 port 为字符串的情况 + if "port" in adapter_config and isinstance(adapter_config["port"], str): + adapter_config["port"] = int(adapter_config["port"]) if adapter_config["port"].strip() else 8600 cfg = {"enabled": self.configs[name].get("enabled", False), "config": adapter_config} diff --git a/app/services/futures_service.py b/app/services/futures_service.py index dd9e218..c549a53 100644 --- a/app/services/futures_service.py +++ b/app/services/futures_service.py @@ -61,10 +61,16 @@ class FuturesService: # 确保适配器已连接 adapter = adapter_service.get_active_adapter("futures") if not adapter: - # 尝试连接 amazingdata + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.futures.active + + # 尝试连接配置的适配器 + info(f"Connecting to configured adapter: {active_source}") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(adapter_service._connect_adapter("amazingdata")) + loop.run_until_complete(adapter_service._connect_adapter(active_source)) loop.close() adapter = adapter_service.get_active_adapter("futures") @@ -164,7 +170,13 @@ class FuturesService: # 确保适配器已连接 adapter = adapter_service.get_active_adapter("futures") if not adapter: - asyncio.run(adapter_service._connect_adapter("amazingdata")) + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.futures.active + + info(f"Connecting to configured adapter: {active_source}") + asyncio.run(adapter_service._connect_adapter(active_source)) adapter = adapter_service.get_active_adapter("futures") if not adapter: diff --git a/app/services/stock_service.py b/app/services/stock_service.py index 2e162ed..9ce7861 100644 --- a/app/services/stock_service.py +++ b/app/services/stock_service.py @@ -71,10 +71,16 @@ class StockService: # 确保适配器已连接 adapter = adapter_service.get_active_adapter("stock") if not adapter: - # 尝试连接 amazingdata + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.stock.active + + # 尝试连接配置的适配器 + info(f"Connecting to configured adapter: {active_source}") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(adapter_service._connect_adapter("amazingdata")) + loop.run_until_complete(adapter_service._connect_adapter(active_source)) loop.close() adapter = adapter_service.get_active_adapter("stock") @@ -151,13 +157,145 @@ class StockService: def _apply_adjust( self, symbol: str, - items: List, + items: List[KLineItem], adjust_type: AdjustType - ) -> List: - """应用复权计算(TODO: 实现复权逻辑)""" - # 复权计算需要从数据库获取复权系数 - # 这里简化处理,直接返回原始数据 - return items + ) -> List[KLineItem]: + """应用复权计算 + + 复权原理: + - 前复权(qfq): 以最新价格为基准,将历史价格按比例缩小 + - 后复权(hfq): 以历史最早价格为基准,将后续价格按比例放大 + """ + if not items or adjust_type == AdjustType.NONE: + return items + + try: + # 获取日期范围 + start_date = items[0].time.strftime("%Y%m%d") + end_date = items[-1].time.strftime("%Y%m%d") + + # 从数据库获取复权系数 + factors = self.repository.get_adjust_factors(symbol, start_date, end_date) + + # 如果没有复权系数,尝试从适配器获取 + if not factors: + factors = self._fetch_adjust_factors_from_adapter(symbol, start_date, end_date) + if factors: + self.repository.save_adjust_factors(symbol, factors) + + # 将复权系数转换为字典,方便查找 + factor_map = {f["trade_date"]: f for f in factors} + + # 应用复权 + adjusted_items = [] + for item in items: + # 获取交易日期 + trade_date = getattr(item, 'trade_date', None) + if not trade_date and hasattr(item, 'time'): + trade_date = item.time.strftime("%Y-%m-%d") + + factor = factor_map.get(trade_date, {"qfq_factor": 1.0, "hfq_factor": 1.0}) + + # 根据复权类型选择系数 + if adjust_type == AdjustType.QFQ: + adj_factor = factor.get("qfq_factor", 1.0) + else: # HFQ + adj_factor = factor.get("hfq_factor", 1.0) + + # 应用复权系数到价格字段 + adjusted_item = KLineItem( + symbol=item.symbol, + time=item.time, + open=round(item.open * adj_factor, 4), + high=round(item.high * adj_factor, 4), + low=round(item.low * adj_factor, 4), + close=round(item.close * adj_factor, 4), + volume=item.volume, + amount=round(item.amount * adj_factor, 4) if item.amount else item.amount, + trade_date=getattr(item, 'trade_date', None), + is_limit_up=getattr(item, 'is_limit_up', None), + is_limit_down=getattr(item, 'is_limit_down', None), + total_market_cap=getattr(item, 'total_market_cap', None), + float_market_cap=getattr(item, 'float_market_cap', None), + inst_holding_ratio=getattr(item, 'inst_holding_ratio', None), + trading_days=getattr(item, 'trading_days', None), + adj_factor=adj_factor + ) + adjusted_items.append(adjusted_item) + + return adjusted_items + + except Exception as e: + error(f"Failed to apply adjust factor for {symbol}: {e}") + # 出错时返回原始数据 + return items + + def _fetch_adjust_factors_from_adapter( + self, + symbol: str, + start_date: str, + end_date: str + ) -> List[dict]: + """从适配器获取复权系数""" + try: + adapter_service = AdapterService() + adapter = adapter_service.get_active_adapter("stock") + + if not adapter: + error("No active adapter available for fetching adjust factors") + return [] + + # 检查适配器是否支持获取复权因子 + if not hasattr(adapter, 'get_adj_factor'): + return [] + + # 异步获取前复权因子 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + qfq_df = loop.run_until_complete( + adapter.get_adj_factor([symbol]) + ) + hfq_df = loop.run_until_complete( + adapter.get_backward_factor([symbol]) + ) + finally: + loop.close() + + # 转换DataFrame为列表 + factors = [] + + # 处理日期格式 + for idx in qfq_df.index: + date_obj = idx if hasattr(idx, 'strftime') else datetime.strptime(str(idx), "%Y%m%d") + date_str = date_obj.strftime("%Y-%m-%d") + date_key = date_obj.strftime("%Y%m%d") + + # 只保留指定范围内的数据 + if not (start_date <= date_key <= end_date): + continue + + qfq_factor = float(qfq_df.loc[idx, symbol]) if symbol in qfq_df.columns else 1.0 + hfq_factor = float(hfq_df.loc[idx, symbol]) if symbol in hfq_df.columns else 1.0 + + # 确保复权系数有效 + if qfq_factor <= 0 or qfq_factor != qfq_factor: # 检查NaN + qfq_factor = 1.0 + if hfq_factor <= 0 or hfq_factor != hfq_factor: + hfq_factor = 1.0 + + factors.append({ + "trade_date": date_str, + "qfq_factor": qfq_factor, + "hfq_factor": hfq_factor + }) + + info(f"Fetched {len(factors)} adjust factors from adapter for {symbol}") + return factors + + except Exception as e: + error(f"Failed to fetch adjust factors from adapter: {e}") + return [] def list_symbols(self, req: SymbolListRequest) -> SymbolListData: """查询标的列表""" @@ -196,7 +334,13 @@ class StockService: # 确保适配器已连接 adapter = adapter_service.get_active_adapter("stock") if not adapter: - asyncio.run(adapter_service._connect_adapter("amazingdata")) + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.stock.active + + info(f"Connecting to configured adapter: {active_source}") + asyncio.run(adapter_service._connect_adapter(active_source)) adapter = adapter_service.get_active_adapter("stock") if not adapter: @@ -301,7 +445,13 @@ class StockService: # 确保适配器已连接 adapter = adapter_service.get_active_adapter("stock") if not adapter: - asyncio.run(adapter_service._connect_adapter("amazingdata")) + # 从配置获取当前激活的适配器名称 + from app.core.config import get_config + config = get_config() + active_source = config.sources.stock.active + + info(f"Connecting to configured adapter: {active_source}") + asyncio.run(adapter_service._connect_adapter(active_source)) adapter = adapter_service.get_active_adapter("stock") if not adapter: diff --git a/config.json b/config.json index 64e54ea..cf8ee62 100644 --- a/config.json +++ b/config.json @@ -20,32 +20,32 @@ }, "sources": { "stock": { - "active": "custom", + "active": "amazingdata", "list": { - "custom": { + "amazingdata": { "type": "sdk", "config": { - "username": "", - "password": "", - "host": "", - "port": "", - "local_path": "./custom_data_cache/", + "username": "11200008169", + "password": "11200008169@2026", + "host": "140.206.44.234", + "port": "8600", + "local_path": "./amazing_data_cache/", "use_local_cache": "true" } } } }, "futures": { - "active": "custom", + "active": "amazingdata", "list": { - "custom": { + "amazingdata": { "type": "sdk", "config": { "username": "", "password": "", "host": "", - "port": "", - "local_path": "./custom_data_cache/", + "port": "8600", + "local_path": "./amazing_data_cache/", "use_local_cache": "true" } } diff --git a/market_data_service.egg-info/PKG-INFO b/market_data_service.egg-info/PKG-INFO new file mode 100644 index 0000000..bb0d8bd --- /dev/null +++ b/market_data_service.egg-info/PKG-INFO @@ -0,0 +1,269 @@ +Metadata-Version: 2.4 +Name: market-data-service +Version: 1.0.0 +Summary: 统一行情数据服务 - Python实现 +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +Requires-Dist: fastapi>=0.115.0 +Requires-Dist: uvicorn[standard]>=0.32.0 +Requires-Dist: python-socketio>=5.12.1 +Requires-Dist: websockets>=14.1 +Requires-Dist: sqlalchemy>=2.0.36 +Requires-Dist: psycopg2-binary>=2.9.10 +Requires-Dist: pandas>=2.2.3 +Requires-Dist: numpy>=2.1.3 +Requires-Dist: numba>=0.61.0 +Requires-Dist: scipy>=1.15.0 +Requires-Dist: pydantic>=2.10.0 +Requires-Dist: pydantic-settings>=2.6.1 +Requires-Dist: python-dotenv>=1.0.1 +Requires-Dist: PyYAML>=6.0.2 +Requires-Dist: httpx>=0.28.0 +Requires-Dist: apscheduler>=3.11.0 +Provides-Extra: dev +Requires-Dist: pytest>=8.3.4; extra == "dev" +Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev" + +# 统一行情数据服务 - Python实现 + +Python版本的统一行情数据服务,所有接口和功能与Go版本保持一致。 + +## 特性 + +- **多周期K线支持**:1m/5m/15m/30m/60m/1d/1w/1month +- **股票复权支持**:前复权(qfq)/后复权(hfq) +- **数据源热切换**:支持Wind、Tushare等多个数据源动态切换 +- **双轨设计**:股票和期货接口独立,数据存储隔离 +- **WebSocket实时订阅**:支持实时行情推送 +- **数据质量监控**:自动检测数据缺失并告警 +- **交易日历**:支持查询股票和期货的交易日历 +- **期货合约查询**:根据品种获取可交易合约列表 + +## 技术栈 + +- **语言**: Python 3.10+ +- **Web框架**: FastAPI +- **WebSocket**: FastAPI原生WebSocket + python-socketio +- **数据库**: PostgreSQL 15+ (SQLAlchemy ORM) +- **数据源**: Tushare (首期支持) + +## 项目结构 + +``` +python_market_data_service/ +├── app/ +│ ├── __init__.py +│ ├── main.py # 主程序入口 +│ ├── api/ # API路由 +│ │ ├── __init__.py +│ │ ├── routes.py # 主要API路由 +│ │ └── admin_routes.py # 管理后台路由 +│ ├── core/ # 核心模块 +│ │ ├── __init__.py +│ │ ├── config.py # 配置管理 +│ │ ├── errors.py # 错误定义 +│ │ └── logger.py # 日志工具 +│ ├── models/ # 数据模型 +│ │ ├── __init__.py +│ │ ├── types.py # 基础类型 +│ │ └── admin_types.py # 管理后台类型 +│ ├── repositories/ # 数据访问层 +│ │ ├── __init__.py +│ │ ├── database.py # 数据库连接 +│ │ ├── models.py # 数据库模型 +│ │ ├── stock_repository.py +│ │ └── futures_repository.py +│ ├── services/ # 业务逻辑层 +│ │ ├── __init__.py +│ │ ├── stock_service.py +│ │ ├── futures_service.py +│ │ ├── admin_service.py +│ │ ├── config_service.py +│ │ ├── adapter_service.py +│ │ └── test_service.py +│ ├── adapters/ # 数据源适配器 +│ │ ├── __init__.py +│ │ ├── base.py # 适配器基类 +│ │ └── tushare_adapter.py +│ └── websocket/ # WebSocket服务 +│ ├── __init__.py +│ └── server.py +├── scripts/ +│ └── sync_data.py # 数据同步工具 +├── tests/ # 测试文件 +├── requirements.txt # 依赖列表 +├── pyproject.toml # 项目配置 +└── README.md # 本文件 +``` + +## 快速开始 + +### 1. 环境准备 + +- Python 3.10+ +- PostgreSQL 15+ +- Tushare Token (从 [Tushare官网](https://tushare.pro) 获取) + +### 2. 安装依赖 + +```bash +# 创建虚拟环境 +python -m venv venv + +# 激活虚拟环境 +# Windows: +venv\Scripts\activate +# Linux/Mac: +source venv/bin/activate + +# 安装依赖 +pip install -r requirements.txt + +# 安装Tushare(需单独安装) +pip install tushare +``` + +### 3. 配置环境变量 + +```bash +# Windows PowerShell +$env:TUSHARE_TOKEN="your_tushare_token" +$env:DATABASE_URL="postgresql://user:password@localhost:5432/marketdata" + +# Linux/Mac +export TUSHARE_TOKEN="your_tushare_token" +export DATABASE_URL="postgresql://user:password@localhost:5432/marketdata" +``` + +### 4. 初始化数据库 + +```bash +# 创建数据库(使用psql或pgAdmin) +createdb marketdata + +# 启动服务时会自动创建表结构 +``` + +### 5. 启动服务 + +```bash +# 开发模式 +python -m app.main + +# 或使用uvicorn +uvicorn app.main:app --reload --port 8080 +``` + +服务将启动在 `http://localhost:8080` + +- API文档: `http://localhost:8080/docs` +- 管理后台: `http://localhost:8080/admin` + +### 6. 同步基础数据 + +```bash +# 同步股票列表 +python scripts/sync_data.py --type stocks + +# 同步期货列表 +python scripts/sync_data.py --type futures + +# 同步交易日历 +python scripts/sync_data.py --type calendar --start 20240101 --end 20241231 + +# 同步K线数据 +python scripts/sync_data.py --type klines --symbol 000001.SZ --start 20240301 --end 20240307 --freq 1d +``` + +## API接口 + +### 股票接口 + +| 接口 | 方法 | 说明 | +|------|------|------| +| `/v1/stock/klines/:symbol` | GET | 查询K线数据 | +| `/v1/stock/symbols` | GET | 查询标的列表 | +| `/v1/stock/klines/batch` | POST | 批量查询K线 | +| `/v1/stock/trading-dates` | GET | 获取交易日历 | + +### 期货接口 + +| 接口 | 方法 | 说明 | +|------|------|------| +| `/v1/futures/klines/:symbol` | GET | 查询K线数据 | +| `/v1/futures/symbols` | GET | 查询标的列表 | +| `/v1/futures/klines/batch` | POST | 批量查询K线 | +| `/v1/futures/continuous/:underlying` | GET | 查询主力连续合约(预留) | +| `/v1/futures/trading-dates` | GET | 获取交易日历 | +| `/v1/futures/contracts` | GET | 获取品种合约列表 | + +### 管理接口 + +| 接口 | 方法 | 说明 | +|------|------|------| +| `/v1/admin/source/status` | GET | 获取数据源状态 | +| `/v1/admin/source/switch` | POST | 切换数据源 | +| `/v1/admin/backfill` | POST | 历史数据补录 | +| `/v1/admin/health` | GET | 健康检查 | + +### 管理后台 + +服务启动后,访问 `http://localhost:8080/admin` 进入管理后台。 + +### WebSocket实时订阅 + +**连接地址**: `ws://localhost:8080/v1/stream` + +**认证**: 连接时在Header中传递 `X-API-Key` + +**客户端消息**: +```json +// 订阅 +{ + "action": "subscribe", + "symbols": ["000001.SZ", "CU2504.SHFE"] +} + +// 取消订阅 +{ + "action": "unsubscribe", + "symbols": ["000001.SZ"] +} +``` + +**服务器消息**: +```json +// 订阅确认 +{ + "type": "ack", + "action": "subscribe", + "symbols": ["000001.SZ", "CU2504.SHFE"], + "ts": "2025-03-07T12:30:00Z" +} + +// 心跳 +{ + "type": "heartbeat", + "ts": "2025-03-07T12:30:30Z" +} +``` + +**限制**: 单连接最大订阅100个标的 + +## 与Go版本的主要区别 + +1. **Web框架**: Gin -> FastAPI +2. **ORM**: 原生SQL -> SQLAlchemy +3. **WebSocket**: Gorilla -> FastAPI原生 +4. **配置**: 文件+环境变量 -> Pydantic Settings +5. **API文档**: 自动生成Swagger/ReDoc + +## License + +MIT diff --git a/market_data_service.egg-info/SOURCES.txt b/market_data_service.egg-info/SOURCES.txt new file mode 100644 index 0000000..ef08a44 --- /dev/null +++ b/market_data_service.egg-info/SOURCES.txt @@ -0,0 +1,43 @@ +README.md +pyproject.toml +app/__init__.py +app/main.py +app/adapters/__init__.py +app/adapters/amazingdata_adapter.py +app/adapters/base.py +app/api/__init__.py +app/api/admin_routes.py +app/api/routes.py +app/core/__init__.py +app/core/config.py +app/core/errors.py +app/core/logger.py +app/core/metrics.py +app/core/rate_limiter.py +app/models/__init__.py +app/models/admin_types.py +app/models/types.py +app/monitor/__init__.py +app/monitor/alert_channels.py +app/monitor/monitor.py +app/repositories/__init__.py +app/repositories/database.py +app/repositories/futures_repository.py +app/repositories/models.py +app/repositories/stock_repository.py +app/services/__init__.py +app/services/adapter_service.py +app/services/admin_service.py +app/services/config_service.py +app/services/futures_service.py +app/services/stock_service.py +app/services/test_service.py +app/websocket/__init__.py +app/websocket/server.py +market_data_service.egg-info/PKG-INFO +market_data_service.egg-info/SOURCES.txt +market_data_service.egg-info/dependency_links.txt +market_data_service.egg-info/requires.txt +market_data_service.egg-info/top_level.txt +tests/test_xysz_adapter.py +tests/test_xysz_integration.py \ No newline at end of file diff --git a/market_data_service.egg-info/dependency_links.txt b/market_data_service.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/market_data_service.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/market_data_service.egg-info/requires.txt b/market_data_service.egg-info/requires.txt new file mode 100644 index 0000000..b5b0438 --- /dev/null +++ b/market_data_service.egg-info/requires.txt @@ -0,0 +1,20 @@ +fastapi>=0.115.0 +uvicorn[standard]>=0.32.0 +python-socketio>=5.12.1 +websockets>=14.1 +sqlalchemy>=2.0.36 +psycopg2-binary>=2.9.10 +pandas>=2.2.3 +numpy>=2.1.3 +numba>=0.61.0 +scipy>=1.15.0 +pydantic>=2.10.0 +pydantic-settings>=2.6.1 +python-dotenv>=1.0.1 +PyYAML>=6.0.2 +httpx>=0.28.0 +apscheduler>=3.11.0 + +[dev] +pytest>=8.3.4 +pytest-asyncio>=0.24.0 diff --git a/market_data_service.egg-info/top_level.txt b/market_data_service.egg-info/top_level.txt new file mode 100644 index 0000000..b80f0bd --- /dev/null +++ b/market_data_service.egg-info/top_level.txt @@ -0,0 +1 @@ +app diff --git a/marketdata.db b/marketdata.db index 457389d..23e36b2 100644 Binary files a/marketdata.db and b/marketdata.db differ diff --git a/marketdata.db.backup b/marketdata.db.backup new file mode 100644 index 0000000..9545876 Binary files /dev/null and b/marketdata.db.backup differ diff --git a/requirements.txt b/requirements.txt index 0a604a2..ac98a7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,7 @@ aioredis==2.0.1 # Monitoring apscheduler==3.11.0 +prometheus-client==0.21.0 # Testing pytest==8.3.4 diff --git a/scripts/init_mysql_db.py b/scripts/init_mysql_db.py new file mode 100644 index 0000000..92bab95 --- /dev/null +++ b/scripts/init_mysql_db.py @@ -0,0 +1,53 @@ +"""MySQL数据库初始化脚本 + +创建数据库和表结构 +""" +import sys +sys.path.insert(0, '.') + +from sqlalchemy import create_engine, text +from app.core.config import get_config +from app.repositories.database import Base + +def init_mysql(): + """初始化MySQL数据库""" + config = get_config() + db_config = config.database + + # 连接MySQL服务器(不指定数据库) + server_url = f"mysql+pymysql://{db_config.user}:{db_config.password}@{db_config.host}:{db_config.port}" + + print(f"Connecting to MySQL server: {db_config.host}:{db_config.port}") + engine = create_engine(server_url) + + # 创建数据库 + with engine.connect() as conn: + conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_config.database} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")) + print(f"Database '{db_config.database}' created or exists") + + # 连接到新创建的数据库 + db_url = f"{server_url}/{db_config.database}" + db_engine = create_engine(db_url) + + # 创建所有表 + print("Creating tables...") + Base.metadata.create_all(bind=db_engine) + print("Tables created successfully!") + + # 显示创建的表 + with db_engine.connect() as conn: + result = conn.execute(text("SHOW TABLES")) + tables = [row[0] for row in result] + print(f"\nTables in database '{db_config.database}':") + for table in tables: + print(f" - {table}") + +if __name__ == "__main__": + try: + init_mysql() + print("\nMySQL database initialization completed!") + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1)