You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

535 lines
20 KiB

"""测试服务 - 对应Go的internal/service/test.go"""
import asyncio
import json
import time
from datetime import datetime, timedelta
from threading import RLock
from typing import List, Optional
from urllib.parse import quote, urlencode

import httpx
import websockets

from app.core.logger import info, error
from app.models import (
    APITestListData, APITestCategory, APITestCase,
    APITestRequest, APITestResult,
    WSTestListData, WSTestCase, WSTestRequest, WSTestResult, WSMessage,
    TestHistoryRequest, TestHistoryData
)
class TestService:
    """Runner for the built-in HTTP API and WebSocket test cases.

    Exposes a fixed catalogue of predefined test cases, executes a selected
    case against a caller-supplied URL, and keeps the most recent results
    (at most ``history_size`` per type) in in-memory history lists guarded by
    a re-entrant lock.
    """

    def __init__(self, api_key: str = "test-api-key", history_size: int = 100):
        """Create a service with empty history.

        Args:
            api_key: Value sent in the ``X-API-Key`` header of every HTTP and
                WebSocket test request.
            history_size: Maximum number of results retained per history list.
        """
        self.lock = RLock()
        self.api_history: List[APITestResult] = []
        self.ws_history: List[WSTestResult] = []
        self.history_size = history_size
        self.api_key = api_key
        # Last result id handed out by _next_id(); 0 means none issued yet.
        self._last_id = 0

    # ------------------------------------------------------------------ #
    # HTTP API tests
    # ------------------------------------------------------------------ #

    def get_api_test_list(self) -> APITestListData:
        """Return the predefined HTTP API test cases, grouped by category.

        ``base_url`` in the returned data is left empty.
        """
        # Fixed trading window 2026-03-02 .. 2026-03-06, formatted as YYYYMMDD.
        start = datetime(2026, 3, 2).strftime("%Y%m%d")
        end = datetime(2026, 3, 6).strftime("%Y%m%d")

        # NOTE(review): futures codes such as CU2504 / RB2505 look like
        # 2025 delivery months, which would precede the 2026 window above —
        # confirm the backend still serves data for them.
        categories = [
            APITestCategory(
                name="股票接口",
                items=[
                    APITestCase(
                        id="stock_klines",
                        name="查询股票K线",
                        method="GET",
                        path="/v1/stock/klines/{symbol}",
                        description="查询指定股票的K线数据",
                        params={
                            "symbol": "000001.SZ",
                            "start": start,
                            "end": end,
                            "freq": "1d",
                            "adjust": "qfq",
                        },
                    ),
                    APITestCase(
                        id="stock_symbols",
                        name="查询股票列表",
                        method="GET",
                        path="/v1/stock/symbols",
                        description="获取所有可用股票标的",
                        params={"page": "1", "size": "20"},
                    ),
                    APITestCase(
                        id="stock_batch",
                        name="批量查询股票K线",
                        method="POST",
                        path="/v1/stock/klines/batch",
                        description="批量查询多只股票K线",
                        body={
                            "symbols": ["000001.SZ", "000002.SZ"],
                            "start": start,
                            "end": end,
                            "freq": "1d",
                        },
                    ),
                    APITestCase(
                        id="stock_calendar",
                        name="查询交易日历",
                        method="GET",
                        path="/v1/stock/trading-dates",
                        description="查询股票交易日历",
                        params={"start": start, "end": end},
                    ),
                ],
            ),
            APITestCategory(
                name="期货接口",
                items=[
                    APITestCase(
                        id="futures_klines",
                        name="查询期货K线",
                        method="GET",
                        path="/v1/futures/klines/{symbol}",
                        description="查询指定期货合约的K线数据",
                        params={
                            "symbol": "CU2504.SHFE",
                            "start": start,
                            "end": end,
                            "freq": "1d",
                        },
                    ),
                    APITestCase(
                        id="futures_symbols",
                        name="查询期货列表",
                        method="GET",
                        path="/v1/futures/symbols",
                        description="获取所有可用期货标的",
                        params={"page": "1", "size": "20"},
                    ),
                    APITestCase(
                        id="futures_batch",
                        name="批量查询期货K线",
                        method="POST",
                        path="/v1/futures/klines/batch",
                        description="批量查询多个期货合约K线",
                        body={
                            "symbols": ["CU2504.SHFE", "RB2505.SHFE"],
                            "start": start,
                            "end": end,
                            "freq": "1d",
                        },
                    ),
                    APITestCase(
                        id="futures_contracts",
                        name="查询合约列表",
                        method="GET",
                        path="/v1/futures/contracts",
                        description="根据品种查询可交易合约",
                        params={"underlying": "CU", "exchange": "SHFE"},
                    ),
                    APITestCase(
                        id="futures_calendar",
                        name="查询期货交易日历",
                        method="GET",
                        path="/v1/futures/trading-dates",
                        description="查询期货交易日历",
                        params={"start": start, "end": end},
                    ),
                ],
            ),
            APITestCategory(
                name="管理接口",
                items=[
                    APITestCase(
                        id="admin_health",
                        name="健康检查",
                        method="GET",
                        path="/v1/admin/health",
                        description="检查服务健康状态",
                        params={},
                    ),
                    APITestCase(
                        id="admin_source_status",
                        name="数据源状态",
                        method="GET",
                        path="/v1/admin/source/status",
                        description="获取当前数据源状态",
                        params={},
                    ),
                    APITestCase(
                        id="admin_source_switch",
                        name="切换数据源",
                        method="POST",
                        path="/v1/admin/source/switch",
                        description="切换到指定数据源amazingdata",
                        body={
                            "asset_class": "all",
                            "source": "amazingdata",
                            "sync_backfill": False,
                        },
                    ),
                    APITestCase(
                        id="admin_system_status",
                        name="系统状态",
                        method="GET",
                        path="/v1/admin/system/status",
                        description="获取系统运行状态和资源使用情况",
                        params={},
                    ),
                    APITestCase(
                        id="admin_config_list",
                        name="查询配置列表",
                        method="GET",
                        path="/v1/admin/config",
                        description="获取所有配置项列表",
                        params={},
                    ),
                    APITestCase(
                        id="admin_config_update",
                        name="更新配置",
                        method="PUT",
                        path="/v1/admin/config",
                        description="更新系统配置",
                        body={
                            "key": "server.mode",
                            "value": "debug",
                            "description": "服务器运行模式",
                        },
                    ),
                    APITestCase(
                        id="admin_reload_config",
                        name="热加载配置",
                        method="POST",
                        path="/v1/admin/system/reload",
                        description="重新加载配置文件",
                        body={},
                    ),
                ],
            ),
            APITestCategory(
                name="适配器管理",
                items=[
                    APITestCase(
                        id="admin_adapters_list",
                        name="适配器列表",
                        method="GET",
                        path="/v1/admin/adapters",
                        description="获取所有数据源适配器列表",
                        params={},
                    ),
                    APITestCase(
                        id="admin_adapter_toggle",
                        name="切换适配器状态",
                        method="POST",
                        path="/v1/admin/adapters/toggle",
                        description="启用或禁用适配器",
                        body={"name": "amazingdata", "enable": True},
                    ),
                    APITestCase(
                        id="admin_adapter_config",
                        name="更新适配器配置",
                        method="PUT",
                        path="/v1/admin/adapters/config",
                        description="更新适配器配置参数",
                        body={"name": "amazingdata", "config": {"timeout": "60"}},
                    ),
                ],
            ),
            APITestCategory(
                name="测试管理",
                items=[
                    APITestCase(
                        id="admin_test_history",
                        name="测试历史",
                        method="GET",
                        path="/v1/admin/tests/history",
                        description="获取测试执行历史记录",
                        params={"type": "api", "limit": "20"},
                    ),
                ],
            ),
        ]
        return APITestListData(categories=categories, base_url="")

    async def run_api_test(self, base_url: str, req: APITestRequest) -> APITestResult:
        """Execute one predefined HTTP test case and record the result.

        Request parameters are the case defaults overlaid with ``req.params``.
        Parameters named by a ``{placeholder}`` in the case path are
        substituted into the path. For GET, the remaining parameters become
        the query string. The body is ``req.body`` when given, otherwise the
        case default, and is sent as JSON for every method except GET.

        Args:
            base_url: Prefix prepended to the case path (scheme + host).
            req: Selects the case by ``id`` and may override params/body.

        Returns:
            The recorded result. Request failures do not raise. They produce
            a result with ``success=False`` and ``error`` set.

        Raises:
            ValueError: If no test case has id ``req.id``.
        """
        test_case = self._find_api_case(req.id)
        if test_case is None:
            raise ValueError(f"Test case not found: {req.id}")

        # ``or {}`` tolerates body-only cases whose params may be unset.
        params = dict(test_case.params or {})
        if req.params:
            params.update(req.params)

        url = self._build_api_url(base_url, test_case.path, test_case.method, params)
        body = req.body if req.body is not None else test_case.body
        request_info = {"method": test_case.method, "url": url, "body": body}

        started = time.perf_counter()
        try:
            async with httpx.AsyncClient() as client:
                # request() dispatches any verb, so the PUT cases
                # (admin_config_update, admin_adapter_config) are runnable.
                response = await client.request(
                    test_case.method,
                    url,
                    json=None if test_case.method == "GET" else body,
                    headers={"X-API-Key": self.api_key},
                    timeout=30,
                )
                latency = self._elapsed_ms(started)
            result = APITestResult(
                id=self._next_id(),
                case_id=req.id,
                name=test_case.name,
                success=200 <= response.status_code < 300,
                status_code=response.status_code,
                latency=latency,
                request=request_info,
                response=self._decode_response(response),
                timestamp=datetime.now(),
            )
        except Exception as e:
            # Deliberate best-effort: an unreachable or misbehaving endpoint
            # is a test outcome, so it is recorded rather than raised.
            result = APITestResult(
                id=self._next_id(),
                case_id=req.id,
                name=test_case.name,
                success=False,
                latency=self._elapsed_ms(started),
                request=request_info,
                error=str(e),
                timestamp=datetime.now(),
            )
        self._add_api_history(result)
        return result

    def _find_api_case(self, case_id: str) -> Optional[APITestCase]:
        """Return the API test case whose id equals *case_id*, or None."""
        for category in self.get_api_test_list().categories:
            for case in category.items:
                if case.id == case_id:
                    return case
        return None

    @staticmethod
    def _build_api_url(base_url: str, path: str, method: str, params: dict) -> str:
        """Expand ``{name}`` placeholders in *path* and, for GET, append the query.

        Path values are percent-encoded with no safe characters, so a value
        such as ``"a/b"`` cannot change the route. Query values are
        form-encoded. Parameters consumed by the path are excluded from the
        query string.
        """
        in_path = {key for key in params if f"{{{key}}}" in path}
        for key in in_path:
            path = path.replace(f"{{{key}}}", quote(str(params[key]), safe=""))
        url = base_url + path
        if method == "GET":
            query = {key: value for key, value in params.items() if key not in in_path}
            if query:
                url += "?" + urlencode(query)
        return url

    @staticmethod
    def _decode_response(response: httpx.Response):
        """Return the parsed JSON body for JSON responses, otherwise the text.

        A body labelled JSON that fails to parse falls back to the raw text,
        so the result still carries the real status code.
        """
        if response.headers.get("content-type", "").startswith("application/json"):
            try:
                return response.json()
            except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
                pass
        return response.text

    # ------------------------------------------------------------------ #
    # WebSocket tests
    # ------------------------------------------------------------------ #

    def get_ws_test_list(self) -> WSTestListData:
        """Return the predefined WebSocket test cases (``ws_url`` left empty)."""
        cases = [
            WSTestCase(
                id="ws_subscribe_stock",
                name="订阅股票行情",
                description="订阅单只股票实时行情",
                action="subscribe",
                symbols=["000001.SZ"],
            ),
            WSTestCase(
                id="ws_subscribe_futures",
                name="订阅期货行情",
                description="订阅单个期货合约实时行情",
                action="subscribe",
                symbols=["CU2504.SHFE"],
            ),
            WSTestCase(
                id="ws_subscribe_multi",
                name="批量订阅",
                description="同时订阅多个标的",
                action="subscribe",
                symbols=["000001.SZ", "000002.SZ", "CU2504.SHFE"],
            ),
            WSTestCase(
                id="ws_subscribe_many",
                name="压力测试-大量订阅",
                description="订阅大量标的测试性能",
                action="subscribe",
                symbols=[
                    "000001.SZ", "000002.SZ", "000063.SZ", "000333.SZ",
                    "000538.SZ", "000568.SZ", "000651.SZ", "000725.SZ",
                    "000768.SZ", "000858.SZ",
                ],
            ),
            WSTestCase(
                id="ws_unsubscribe",
                name="取消订阅",
                description="取消订阅标的",
                action="unsubscribe",
                symbols=["000001.SZ"],
            ),
            WSTestCase(
                id="ws_unsubscribe_all",
                name="取消全部订阅",
                description="取消所有已订阅标的",
                action="unsubscribe",
                symbols=["000001.SZ", "000002.SZ", "CU2504.SHFE"],
            ),
            WSTestCase(
                id="ws_heartbeat",
                name="心跳检测",
                description="测试WebSocket连接心跳",
                action="subscribe",
                symbols=["000001.SZ"],
            ),
            WSTestCase(
                id="ws_invalid_symbol",
                name="无效标的测试",
                description="测试订阅无效标的的错误处理",
                action="subscribe",
                symbols=["INVALID.CODE"],
            ),
            WSTestCase(
                id="ws_empty_symbols",
                name="空订阅测试",
                description="测试空标的列表的处理",
                action="subscribe",
                symbols=[],
            ),
            WSTestCase(
                id="ws_resubscribe",
                name="重新订阅",
                description="取消后重新订阅同一标的",
                action="subscribe",
                symbols=["000001.SZ"],
            ),
        ]
        return WSTestListData(cases=cases, ws_url="")

    async def run_ws_test(self, ws_url: str, req: WSTestRequest) -> WSTestResult:
        """Execute one predefined WebSocket test case and record the result.

        Connects to *ws_url* and sends ``{"action": ..., "symbols": [...]}``
        as JSON. It then collects up to three JSON messages, stopping early
        when none arrives within 5 seconds.

        ``latency`` is the time taken to open the connection, or to fail.
        ``success`` becomes True once connected and is reset to False if any
        later step raises. Messages received before the failure are kept.

        Raises:
            ValueError: If no test case has id ``req.id``.
        """
        test_case = next(
            (case for case in self.get_ws_test_list().cases if case.id == req.id),
            None,
        )
        if test_case is None:
            raise ValueError(f"Test case not found: {req.id}")

        # Symbols supplied in the request win; empty/None falls back to the case's.
        symbols = req.symbols if req.symbols else test_case.symbols
        result = WSTestResult(
            id=f"ws_{self._next_id()}",
            case_id=req.id,
            timestamp=datetime.now(),
            messages=[],
        )

        started = time.perf_counter()
        try:
            # NOTE(review): ``extra_headers`` is the keyword of the legacy
            # websockets client; the newer asyncio client (default from
            # websockets 14) calls it ``additional_headers`` — confirm
            # against the pinned websockets version.
            async with websockets.connect(
                ws_url,
                extra_headers={"X-API-Key": self.api_key},
            ) as ws:
                result.latency = self._elapsed_ms(started)
                result.success = True
                await ws.send(json.dumps({"action": test_case.action, "symbols": symbols}))
                # Collect at most three replies; a 5 s silence ends the test.
                for _ in range(3):
                    try:
                        raw = await asyncio.wait_for(ws.recv(), timeout=5)
                    except asyncio.TimeoutError:
                        break
                    result.messages.append(
                        WSMessage(type="received", data=json.loads(raw), timestamp=datetime.now())
                    )
        except Exception as e:
            # Connection or protocol errors are test outcomes, not service errors.
            result.latency = self._elapsed_ms(started)
            result.success = False
            result.error = str(e)
        self._add_ws_history(result)
        return result

    # ------------------------------------------------------------------ #
    # History
    # ------------------------------------------------------------------ #

    def get_test_history(self, req: TestHistoryRequest) -> TestHistoryData:
        """Return the most recent results of each requested type, oldest first.

        ``req.type`` of ``"api"`` or ``"ws"`` selects one list. An empty or
        None type returns both, and any other value returns neither.
        ``req.limit`` caps each list; a missing, zero or negative limit
        falls back to 20.
        """
        # A negative limit would turn ``[-limit:]`` into ``[abs(limit):]``,
        # returning nearly everything instead of a tail.
        limit = req.limit if req.limit and req.limit > 0 else 20
        want_api = not req.type or req.type == "api"
        want_ws = not req.type or req.type == "ws"
        with self.lock:
            # Slices are copies, so later history appends don't leak into them.
            api_tests = self.api_history[-limit:] if want_api else []
            ws_tests = self.ws_history[-limit:] if want_ws else []
        return TestHistoryData(api_tests=api_tests, ws_tests=ws_tests)

    def _add_api_history(self, result: APITestResult):
        """Record an API test result, evicting the oldest beyond history_size."""
        self._append_bounded(self.api_history, result)

    def _add_ws_history(self, result: WSTestResult):
        """Record a WebSocket test result, evicting the oldest beyond history_size."""
        self._append_bounded(self.ws_history, result)

    def _append_bounded(self, history: list, result) -> None:
        """Append *result* to *history* in place and drop the oldest overflow."""
        with self.lock:
            history.append(result)
            # Explicit overflow count: ``history[-0:]`` would keep everything
            # when history_size is 0.
            overflow = len(history) - self.history_size
            if overflow > 0:
                del history[:overflow]

    # ------------------------------------------------------------------ #
    # Small utilities
    # ------------------------------------------------------------------ #

    def _next_id(self) -> int:
        """Return a unique, strictly increasing result id.

        Ids stay based on Unix time in whole seconds, but are bumped past the
        previous id so results produced within the same second don't collide.
        """
        with self.lock:
            self._last_id = max(int(time.time()), self._last_id + 1)
            return self._last_id

    @staticmethod
    def _elapsed_ms(started: float) -> int:
        """Milliseconds elapsed since *started*, a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - started) * 1000)