You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

495 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
后端服务单元测试套件
测试覆盖:
- 认证模块(登录/令牌/API Key
- K 线数据服务
- 实时行情服务
- 告警服务
- 订阅服务
- 用户管理
"""
import pytest
from datetime import datetime, timedelta
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch, MagicMock
from app.main import app
from app.config import settings
from app.services.auth_service import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
generate_api_key,
hash_api_key
)
from app.services.kline_service import KlineService
from app.services.alert_service import AlertService
from app.services.subscription_service import SubscriptionService
# Shared synchronous test client bound to the FastAPI application; used by the
# API, middleware, integration and performance tests below.
client = TestClient(app=app)
# ==================== 认证服务测试 ====================
class TestAuthService:
    """Unit tests for the password, JWT and API-key helpers in auth_service."""

    def test_password_hashing(self):
        """Hashing yields a long string that differs from the plaintext."""
        password = "test_password_123"
        hashed = get_password_hash(password)
        assert hashed is not None
        assert hashed != password
        assert len(hashed) > 50

    def test_password_verification_success(self):
        """The original password verifies against its own hash."""
        password = "test_password_123"
        hashed = get_password_hash(password)
        assert verify_password(password, hashed) is True

    def test_password_verification_failure(self):
        """A different password does not verify against the hash."""
        password = "test_password_123"
        wrong_password = "wrong_password"
        hashed = get_password_hash(password)
        assert verify_password(wrong_password, hashed) is False

    def test_create_access_token(self):
        """An access token round-trips its claims and carries an ``exp`` claim."""
        data = {"sub": "testuser", "user_id": 1}
        token = create_access_token(data)
        assert token is not None
        assert len(token) > 50
        # Decode and check that the claims survived the round trip.
        payload = decode_token(token)
        assert payload["sub"] == "testuser"
        assert payload["user_id"] == 1
        assert "exp" in payload

    def test_create_access_token_with_expiry(self):
        """A custom ``expires_delta`` is reflected in the token's ``exp`` claim."""
        import time

        data = {"sub": "testuser"}
        expires_delta = timedelta(hours=2)
        issued_at = time.time()
        token = create_access_token(data, expires_delta=expires_delta)
        payload = decode_token(token)
        assert payload["exp"] is not None
        # Checking presence alone would not catch an ignored expires_delta.
        # Compare against "now + delta" with a generous tolerance for
        # whole-second truncation and slow CI runners.
        # NOTE(review): assumes decode_token returns ``exp`` as a POSIX timestamp.
        expected_exp = issued_at + expires_delta.total_seconds()
        assert abs(payload["exp"] - expected_exp) < 60

    def test_create_refresh_token(self):
        """A refresh token keeps ``sub`` and is tagged with ``type == "refresh"``."""
        data = {"sub": "testuser", "user_id": 1}
        token = create_refresh_token(data)
        assert token is not None
        payload = decode_token(token)
        assert payload["sub"] == "testuser"
        assert payload["type"] == "refresh"

    def test_decode_invalid_token(self):
        """Decoding a malformed token raises."""
        invalid_token = "invalid.token.here"
        with pytest.raises(Exception):
            decode_token(invalid_token)

    def test_decode_expired_token(self):
        """Decoding a token whose expiry is already in the past raises."""
        data = {"sub": "testuser"}
        expires_delta = timedelta(seconds=-1)  # already expired
        token = create_access_token(data, expires_delta=expires_delta)
        with pytest.raises(Exception):
            decode_token(token)

    def test_generate_api_key(self):
        """Generated API keys are 64 characters long."""
        api_key = generate_api_key()
        assert api_key is not None
        assert len(api_key) == 64  # e.g. length of a hex-encoded 256-bit value

    def test_hash_api_key(self):
        """Hashing an API key returns a value different from the key itself."""
        api_key = "test_api_key_123456"
        hashed = hash_api_key(api_key)
        assert hashed is not None
        assert hashed != api_key
# ==================== K 线数据服务测试 ====================
class TestKlineService:
    """Unit tests for KlineService with the Timescale session factory patched out."""

    @staticmethod
    def _wire_db(session_factory):
        """Make ``with session_factory() as db`` yield a new MagicMock; return it."""
        fake_db = MagicMock()
        session_factory.return_value.__enter__.return_value = fake_db
        return fake_db

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_kline_data(self, mock_session):
        """Two fetched rows come back as two records; the first has symbol and open."""
        fake_db = self._wire_db(mock_session)
        # Row layout: (time, open, high, low, close, volume, amount, open_interest)
        rows = [
            (datetime(2024, 1, 1, 10, 0), 4000.0, 4050.0, 3980.0, 4020.0, 1000, 4000000.0, 500),
            (datetime(2024, 1, 1, 10, 5), 4020.0, 4080.0, 4010.0, 4060.0, 1200, 4800000.0, 520),
        ]
        fake_db.execute.return_value.fetchall.return_value = rows

        window_start = datetime(2024, 1, 1, 10, 0)
        window_end = datetime(2024, 1, 1, 12, 0)
        klines = KlineService.get_kline_data("IF2406", "5m", window_start, window_end)

        assert klines is not None
        assert len(klines) == 2
        first = klines[0]
        assert first["symbol"] == "IF2406"
        assert first["open"] == 4000.0

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_latest_kline(self, mock_session):
        """The single row returned by fetchone() is exposed with its close price."""
        fake_db = self._wire_db(mock_session)
        latest_row = (datetime(2024, 1, 1, 12, 0), 4100.0, 4150.0, 4080.0, 4120.0, 1500, 6000000.0, 600)
        fake_db.execute.return_value.fetchone.return_value = latest_row

        latest = KlineService.get_latest_kline("IF2406", "5m")

        assert latest is not None
        assert latest["close"] == 4120.0

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_symbols(self, mock_session):
        """Single-column symbol rows are flattened into a list of symbols."""
        fake_db = self._wire_db(mock_session)
        fake_db.execute.return_value.fetchall.return_value = [("IF2406",), ("IC2406",), ("IH2406",)]

        symbols = KlineService.get_symbols()

        assert len(symbols) == 3
        assert "IF2406" in symbols

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_periods(self, mock_session):
        """Single-column period rows are flattened into a list of periods."""
        fake_db = self._wire_db(mock_session)
        fake_db.execute.return_value.fetchall.return_value = [("1m",), ("5m",), ("1h",), ("1d",)]

        periods = KlineService.get_periods()

        assert len(periods) == 4
        assert "5m" in periods
# ==================== 告警服务测试 ====================
class TestAlertService:
    """Unit tests for AlertService with the SQLite session factory patched out."""

    @patch('app.services.alert_service.SQLiteSessionLocal')
    def test_create_alert(self, mock_session):
        """Creating an alert adds a row to the session and commits it."""
        fake_db = MagicMock()
        mock_session.return_value.__enter__.return_value = fake_db

        stored_alert = Mock(
            id=1,
            user_id=1,
            symbol="IF2406",
            condition_type="greater_than",
            condition_value=4000.0,
            status="active",
        )
        fake_db.add, fake_db.commit, fake_db.refresh = Mock(), Mock(), Mock()
        # Let a lookup after the insert, via query(...).filter(...).first(),
        # find the alert.
        fake_db.query.return_value.filter.return_value.first.return_value = stored_alert

        AlertService.create_alert(
            user_id=1,
            symbol="IF2406",
            condition_type="greater_than",
            condition_value=4000.0,
        )

        fake_db.add.assert_called()
        fake_db.commit.assert_called()
# ==================== 订阅服务测试 ====================
class TestSubscriptionService:
    """Unit tests for SubscriptionService with the SQLite session factory patched out."""

    @patch('app.services.subscription_service.SQLiteSessionLocal')
    def test_create_subscription(self, mock_session):
        """A new subscription is added and committed when none exists yet."""
        mock_db = MagicMock()
        mock_session.return_value.__enter__.return_value = mock_db

        mock_subscription = Mock()
        mock_subscription.id = 1
        mock_subscription.user_id = 1
        mock_subscription.symbol = "IF2406"
        mock_subscription.period = "5m"
        mock_subscription.is_active = True
        mock_db.add = Mock()
        mock_db.commit = Mock()
        mock_db.refresh = Mock()

        # The first lookup is the duplicate check and must find nothing. Any
        # later lookup (e.g. re-reading the inserted row) returns the created
        # subscription. Setting return_value to mock_subscription outright
        # would make the duplicate check see an active subscription, so the
        # service would skip add() and the assertions below would fail.
        lookup_results = iter([None])
        mock_db.query.return_value.filter.return_value.first.side_effect = (
            lambda *args, **kwargs: next(lookup_results, mock_subscription)
        )

        SubscriptionService.create_subscription(
            user_id=1,
            symbol="IF2406",
            period="5m",
            subscription_type="kline"
        )
        assert mock_db.add.called
        assert mock_db.commit.called

    @patch('app.services.subscription_service.SQLiteSessionLocal')
    def test_create_duplicate_subscription(self, mock_session):
        """An already-active subscription is not inserted a second time."""
        mock_db = MagicMock()
        mock_session.return_value.__enter__.return_value = mock_db
        # Simulate an existing, active subscription.
        existing_subscription = Mock()
        existing_subscription.is_active = True
        mock_db.query.return_value.filter.return_value.first.return_value = existing_subscription

        SubscriptionService.create_subscription(
            user_id=1,
            symbol="IF2406",
            period="5m",
            subscription_type="kline"
        )
        # add() must not be called when the subscription already exists.
        assert not mock_db.add.called
# ==================== API 端点测试 ====================
class TestHealthCheck:
    """Smoke tests for the health-check and root endpoints."""

    def test_health_check(self):
        """GET /health answers 200 with status "healthy" and a timestamp."""
        resp = client.get("/health")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["status"] == "healthy"
        assert "timestamp" in payload

    def test_root_endpoint(self):
        """GET / reports the configured application name and version."""
        resp = client.get("/")
        assert resp.status_code == 200

        payload = resp.json()
        assert (payload["name"], payload["version"]) == (settings.APP_NAME, settings.APP_VERSION)
class TestAuthAPI:
    """Tests for the authentication endpoints and the generated API docs."""

    _LOGIN_URL = "/api/v1/auth/login"

    def test_login_missing_credentials(self):
        """An empty login form is rejected with 401 or 422."""
        resp = client.post(self._LOGIN_URL, data={})
        assert resp.status_code in {401, 422}  # either status is acceptable

    def test_login_wrong_credentials(self):
        """Unknown credentials yield 401, or 200 per the original author's note."""
        credentials = {"username": "nonexistent", "password": "wrongpassword"}
        resp = client.post(self._LOGIN_URL, data=credentials)
        # Original note: 401 (authentication failed), or 200 if the database
        # contains a test user.
        assert resp.status_code in {200, 401}

    def test_docs_accessible(self):
        """The interactive docs page at /docs is served."""
        assert client.get("/docs").status_code == 200

    def test_openapi_schema(self):
        """The OpenAPI schema is served and lists the login path."""
        resp = client.get("/openapi.json")
        assert resp.status_code == 200
        schema = resp.json()
        assert "paths" in schema
        assert "/api/v1/auth/login" in schema["paths"]
class TestKlineAPI:
    """Tests for the K-line data endpoints without credentials."""

    def test_get_symbols_unauthorized(self):
        """Listing symbols without a token returns 200 or 401."""
        status = client.get("/api/v1/kline/symbols").status_code
        # Expected to require authentication; 200 is tolerated as well.
        assert status in (200, 401)

    def test_get_periods_unauthorized(self):
        """Listing periods without a token returns 200 or 401."""
        status = client.get("/api/v1/kline/periods").status_code
        assert status in (200, 401)

    def test_get_kline_data_missing_params(self):
        """Requesting K-line data without query parameters is rejected."""
        status = client.get("/api/v1/kline/data").status_code
        # 422 for the missing parameters; 401 is also accepted, presumably
        # because authentication may be rejected first.
        assert status in (401, 422)
class TestUserAPI:
    """Tests for the user-management endpoints."""

    def test_create_user(self):
        """Creating a user with a timestamp-suffixed username returns 200 or 400."""
        # The timestamp suffix keeps usernames distinct between runs.
        new_user = {
            "username": f"testuser_{datetime.now().timestamp()}",
            "password": "testpass123",
            "email": "test@example.com",
        }
        resp = client.post("/api/v1/user", json=new_user)
        # 400 is tolerated in case the account already exists.
        assert resp.status_code in (200, 400)
class TestAlertAPI:
    """Tests for the alert endpoints."""

    def test_create_alert_unauthorized(self):
        """Creating an alert without a token is rejected with 401."""
        alert_payload = {
            "symbol": "IF2406",
            "condition_type": "greater_than",
            "condition_value": 4000.0,
        }
        resp = client.post("/api/v1/alert", json=alert_payload)
        # Authentication is required.
        assert resp.status_code == 401
class TestSubscriptionAPI:
    """Tests for the subscription endpoints."""

    def test_create_subscription_unauthorized(self):
        """Creating a subscription without a token is rejected with 401."""
        subscription_payload = {
            "symbol": "IF2406",
            "subscription_type": "kline",
        }
        resp = client.post("/api/v1/subscription", json=subscription_payload)
        # Authentication is required.
        assert resp.status_code == 401
# ==================== 中间件测试 ====================
class TestMiddleware:
    """Tests for HTTP middleware behaviour."""

    def test_cors_headers(self):
        """A CORS preflight from the front-end dev origin is answered with CORS headers."""
        origin = "http://localhost:3000"
        response = client.options(
            "/api/v1/kline/symbols",
            headers={
                "Origin": origin,
                "Access-Control-Request-Method": "GET",
            },
        )
        # Keep the original tolerance for a 401 response.
        assert response.status_code in [200, 401]
        if response.status_code == 200:
            # A successful preflight must actually grant the origin; checking
            # only the status code would let a non-CORS 200 pass unnoticed.
            # The grant is either the echoed origin or the "*" wildcard.
            allow_origin = response.headers.get("access-control-allow-origin")
            assert allow_origin in ("*", origin)
# ==================== 集成测试 ====================
class TestIntegration:
    """End-to-end tests that exercise several endpoints together."""

    def test_full_auth_flow(self):
        """Log in as admin and use the issued token on a protected endpoint."""
        # 1. Log in. This may fail because the database may not contain the user.
        login_response = client.post(
            "/api/v1/auth/login",
            data={"username": "admin", "password": "admin123"}
        )
        if login_response.status_code != 200:
            # Report the scenario as skipped instead of silently passing
            # without having exercised anything.
            pytest.skip(
                f"admin login unavailable (HTTP {login_response.status_code})"
            )

        token = login_response.json()["data"]["access_token"]
        # 2. Use the token against a protected endpoint.
        headers = {"Authorization": f"Bearer {token}"}
        kline_response = client.get("/api/v1/kline/symbols", headers=headers)
        assert kline_response.status_code == 200
# ==================== 性能测试 ====================
class TestPerformance:
    """Coarse latency checks."""

    def test_health_check_response_time(self):
        """GET /health answers with 200 in under one second."""
        import time

        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock, which can be stepped (NTP or manual adjustment) and
        # so can skew or even negate the measured duration.
        start = time.perf_counter()
        response = client.get("/health")
        elapsed = time.perf_counter() - start
        assert response.status_code == 200
        assert elapsed < 1.0  # response time should be under 1 second
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file as a script exits
    # non-zero when tests fail. Discarding the return value would always exit 0.
    # NOTE: --cov / --cov-report require the pytest-cov plugin.
    raise SystemExit(pytest.main([__file__, "-v", "--cov=app", "--cov-report=html"]))