You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

495 lines
15 KiB

"""
后端服务单元测试套件
测试覆盖：
- 认证模块（登录/令牌/API Key）
- K 线数据服务
- 实时行情服务
- 告警服务
- 订阅服务
- 用户管理
"""
import pytest
from datetime import datetime, timedelta
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch, MagicMock
from app.main import app
from app.config import settings
from app.services.auth_service import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
generate_api_key,
hash_api_key
)
from app.services.kline_service import KlineService
from app.services.alert_service import AlertService
from app.services.subscription_service import SubscriptionService
# Shared in-process HTTP client used by every API test below. It is created
# outside a `with` block; per the Starlette/FastAPI TestClient docs that means
# lifespan/startup handlers are not run.
# NOTE(review): confirm the app does not depend on startup hooks for these tests.
client = TestClient(app)
# ==================== 认证服务测试 ====================
class TestAuthService:
    """Unit tests for the password, JWT and API-key helpers in auth_service."""

    def test_password_hashing(self):
        """Hashing yields a long value that differs from the plaintext."""
        plain = "test_password_123"
        digest = get_password_hash(plain)
        assert digest is not None
        assert digest != plain
        assert len(digest) > 50

    def test_password_verification_success(self):
        """A password verifies against its own hash."""
        plain = "test_password_123"
        assert verify_password(plain, get_password_hash(plain)) is True

    def test_password_verification_failure(self):
        """A different password does not verify against the hash."""
        digest = get_password_hash("test_password_123")
        assert verify_password("wrong_password", digest) is False

    def test_create_access_token(self):
        """An access token round-trips its claims and carries an ``exp`` claim."""
        claims = {"sub": "testuser", "user_id": 1}
        token = create_access_token(claims)
        assert token is not None
        assert len(token) > 50

        # Decode it back and make sure the claims survived the round trip.
        decoded = decode_token(token)
        assert decoded["sub"] == "testuser"
        assert decoded["user_id"] == 1
        assert "exp" in decoded

    def test_create_access_token_with_expiry(self):
        """An explicit ``expires_delta`` still produces an ``exp`` claim."""
        two_hours = timedelta(hours=2)
        token = create_access_token({"sub": "testuser"}, expires_delta=two_hours)
        assert decode_token(token)["exp"] is not None

    def test_create_refresh_token(self):
        """A refresh token keeps the subject and is tagged with type 'refresh'."""
        claims = {"sub": "testuser", "user_id": 1}
        token = create_refresh_token(claims)
        assert token is not None
        decoded = decode_token(token)
        assert decoded["sub"] == "testuser"
        assert decoded["type"] == "refresh"

    def test_decode_invalid_token(self):
        """A malformed token is rejected with an exception."""
        with pytest.raises(Exception):
            decode_token("invalid.token.here")

    def test_decode_expired_token(self):
        """A token whose expiry lies in the past is rejected on decode."""
        # A negative delta is intended to put the expiry before "now".
        already_expired = timedelta(seconds=-1)
        token = create_access_token({"sub": "testuser"}, expires_delta=already_expired)
        with pytest.raises(Exception):
            decode_token(token)

    def test_generate_api_key(self):
        """Generated API keys are 64 characters long."""
        key = generate_api_key()
        assert key is not None
        # 64 characters: the length of a hex-encoded SHA-256 digest.
        assert len(key) == 64

    def test_hash_api_key(self):
        """Hashing an API key produces something other than the raw key."""
        raw = "test_api_key_123456"
        digest = hash_api_key(raw)
        assert digest is not None
        assert digest != raw
# ==================== K 线数据服务测试 ====================
class TestKlineService:
    """Tests for KlineService with the TimescaleDB session factory patched out."""

    @staticmethod
    def _fake_db(session_factory):
        """Make ``with session_factory() as db`` yield a fresh MagicMock and return it."""
        db = MagicMock()
        session_factory.return_value.__enter__.return_value = db
        return db

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_kline_data(self, mock_session):
        """Raw rows in a time range are turned into per-bar dicts."""
        db = self._fake_db(mock_session)
        # Tuple layout presumably (time, open, high, low, close, ...);
        # the assertions below rely on index 1 mapping to "open".
        rows = [
            (datetime(2024, 1, 1, 10, 0), 4000.0, 4050.0, 3980.0, 4020.0, 1000, 4000000.0, 500),
            (datetime(2024, 1, 1, 10, 5), 4020.0, 4080.0, 4010.0, 4060.0, 1200, 4800000.0, 520),
        ]
        db.execute.return_value.fetchall.return_value = rows

        window_start = datetime(2024, 1, 1, 10, 0)
        window_end = datetime(2024, 1, 1, 12, 0)
        bars = KlineService.get_kline_data("IF2406", "5m", window_start, window_end)

        assert bars is not None
        assert len(bars) == 2
        assert bars[0]["symbol"] == "IF2406"
        assert bars[0]["open"] == 4000.0

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_latest_kline(self, mock_session):
        """The single most recent row is returned as a dict."""
        db = self._fake_db(mock_session)
        latest_row = (datetime(2024, 1, 1, 12, 0), 4100.0, 4150.0, 4080.0, 4120.0, 1500, 6000000.0, 600)
        db.execute.return_value.fetchone.return_value = latest_row

        bar = KlineService.get_latest_kline("IF2406", "5m")

        assert bar is not None
        assert bar["close"] == 4120.0

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_symbols(self, mock_session):
        """Single-column rows are flattened into a list of symbols."""
        db = self._fake_db(mock_session)
        db.execute.return_value.fetchall.return_value = [("IF2406",), ("IC2406",), ("IH2406",)]

        symbols = KlineService.get_symbols()

        assert len(symbols) == 3
        assert "IF2406" in symbols

    @patch('app.services.kline_service.TimescaleSessionLocal')
    def test_get_periods(self, mock_session):
        """Single-column rows are flattened into a list of periods."""
        db = self._fake_db(mock_session)
        db.execute.return_value.fetchall.return_value = [("1m",), ("5m",), ("1h",), ("1d",)]

        periods = KlineService.get_periods()

        assert len(periods) == 4
        assert "5m" in periods
# ==================== 告警服务测试 ====================
class TestAlertService:
    """Tests for AlertService with the SQLite session factory patched out."""

    @patch('app.services.alert_service.SQLiteSessionLocal')
    def test_create_alert(self, mock_session):
        """Creating an alert adds a row and commits the session."""
        db = MagicMock()
        mock_session.return_value.__enter__.return_value = db

        stored_alert = Mock(
            id=1,
            user_id=1,
            symbol="IF2406",
            condition_type="greater_than",
            condition_value=4000.0,
            status="active",
        )
        # Plain Mocks so the call assertions below inspect these exact objects.
        db.add, db.commit, db.refresh = Mock(), Mock(), Mock()
        # Any query().filter().first() lookup returns the stored alert.
        db.query.return_value.filter.return_value.first.return_value = stored_alert

        AlertService.create_alert(
            user_id=1,
            symbol="IF2406",
            condition_type="greater_than",
            condition_value=4000.0,
        )

        assert db.add.called
        assert db.commit.called
# ==================== 订阅服务测试 ====================
class TestSubscriptionService:
    """Tests for SubscriptionService with the SQLite session factory patched out."""

    @patch('app.services.subscription_service.SQLiteSessionLocal')
    def test_create_subscription(self, mock_session):
        """A subscription that does not exist yet is added and committed."""
        mock_db = MagicMock()
        mock_session.return_value.__enter__.return_value = mock_db

        mock_subscription = Mock()
        mock_subscription.id = 1
        mock_subscription.user_id = 1
        mock_subscription.symbol = "IF2406"
        mock_subscription.period = "5m"
        mock_subscription.is_active = True

        mock_db.add = Mock()
        mock_db.commit = Mock()
        mock_db.refresh = Mock()

        # The first .first() call (presumably the duplicate check) must find
        # nothing; any later lookup returns the newly created subscription.
        # Overwriting return_value with the active mock_subscription would make
        # the duplicate check see an existing active row -- the exact setup of
        # test_create_duplicate_subscription, where add() is NOT called.
        lookup_results = iter([None])
        mock_db.query.return_value.filter.return_value.first.side_effect = (
            lambda *args, **kwargs: next(lookup_results, mock_subscription)
        )

        SubscriptionService.create_subscription(
            user_id=1,
            symbol="IF2406",
            period="5m",
            subscription_type="kline",
        )

        assert mock_db.add.called
        assert mock_db.commit.called

    @patch('app.services.subscription_service.SQLiteSessionLocal')
    def test_create_duplicate_subscription(self, mock_session):
        """An already-active subscription is not inserted a second time."""
        mock_db = MagicMock()
        mock_session.return_value.__enter__.return_value = mock_db

        # Simulate an existing, active subscription for the same key.
        existing_subscription = Mock()
        existing_subscription.is_active = True
        mock_db.query.return_value.filter.return_value.first.return_value = existing_subscription

        SubscriptionService.create_subscription(
            user_id=1,
            symbol="IF2406",
            period="5m",
            subscription_type="kline",
        )

        # An existing subscription must not trigger another add().
        assert not mock_db.add.called
# ==================== API 端点测试 ====================
class TestHealthCheck:
    """Smoke tests for the unauthenticated service-info endpoints."""

    def test_health_check(self):
        """GET /health reports a healthy status together with a timestamp."""
        resp = client.get("/health")
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "healthy"
        assert "timestamp" in body

    def test_root_endpoint(self):
        """GET / echoes the configured application name and version."""
        resp = client.get("/")
        assert resp.status_code == 200

        body = resp.json()
        assert body["name"] == settings.APP_NAME
        assert body["version"] == settings.APP_VERSION
class TestAuthAPI:
    """HTTP-level tests for the authentication endpoints."""

    def test_login_missing_credentials(self):
        """Posting an empty login form is rejected."""
        response = client.post("/api/v1/auth/login", data={})
        # Either 401 or 422 is a reasonable rejection for an empty form.
        assert response.status_code in [401, 422]

    def test_login_wrong_credentials(self):
        """An unknown user with a wrong password must not be issued a token."""
        response = client.post(
            "/api/v1/auth/login",
            data={"username": "nonexistent", "password": "wrongpassword"},
        )
        # A 200 is still tolerated here, as before. But it must not carry an
        # access token. Successful logins put the token at body["data"]["access_token"],
        # which is the shape used by TestIntegration.
        assert response.status_code in [200, 401]
        if response.status_code == 200:
            body = response.json()
            data = body.get("data") if isinstance(body, dict) else None
            assert not (isinstance(data, dict) and data.get("access_token"))

    def test_docs_accessible(self):
        """The interactive API docs page is served."""
        response = client.get("/docs")
        assert response.status_code == 200

    def test_openapi_schema(self):
        """The OpenAPI schema is served and lists the login route."""
        response = client.get("/openapi.json")
        assert response.status_code == 200
        schema = response.json()
        assert "paths" in schema
        assert "/api/v1/auth/login" in schema["paths"]
class TestKlineAPI:
    """HTTP-level tests for the K-line endpoints."""

    def test_get_symbols_unauthorized(self):
        """Anonymous request for the symbol list."""
        status = client.get("/api/v1/kline/symbols").status_code
        # Intended to require auth, but 200 is tolerated in case the route is public.
        assert status in [200, 401]

    def test_get_periods_unauthorized(self):
        """Anonymous request for the period list."""
        status = client.get("/api/v1/kline/periods").status_code
        assert status in [200, 401]

    def test_get_kline_data_missing_params(self):
        """Requesting K-line data without query parameters is rejected."""
        status = client.get("/api/v1/kline/data").status_code
        # 422 for missing required params, or presumably 401 if auth is checked first.
        assert status in [401, 422]
class TestUserAPI:
    """HTTP-level tests for user management."""

    def test_create_user(self):
        """Creating a user succeeds, or is rejected cleanly with 400."""
        import uuid

        # A uuid4 hex suffix is unique per run and uses only [0-9a-f]. The old
        # datetime.now().timestamp() suffix contained a '.', which a username
        # validator could reject, and it could repeat across parallel runs.
        username = f"testuser_{uuid.uuid4().hex[:12]}"
        response = client.post(
            "/api/v1/user",
            json={
                "username": username,
                "password": "testpass123",
                "email": "test@example.com",
            },
        )
        # 400 is accepted because creation may fail on an existing user/email.
        assert response.status_code in [200, 400]
class TestAlertAPI:
    """HTTP-level tests for the alert endpoints."""

    def test_create_alert_unauthorized(self):
        """Creating an alert without credentials is rejected with 401."""
        new_alert = {
            "symbol": "IF2406",
            "condition_type": "greater_than",
            "condition_value": 4000.0,
        }
        response = client.post("/api/v1/alert", json=new_alert)
        assert response.status_code == 401
class TestSubscriptionAPI:
    """HTTP-level tests for the subscription endpoints."""

    def test_create_subscription_unauthorized(self):
        """Creating a subscription without credentials is rejected with 401."""
        new_subscription = {
            "symbol": "IF2406",
            "subscription_type": "kline",
        }
        response = client.post("/api/v1/subscription", json=new_subscription)
        assert response.status_code == 401
# ==================== 中间件测试 ====================
class TestMiddleware:
    """Middleware tests."""

    def test_cors_headers(self):
        """A CORS preflight from the dev front-end origin gets CORS headers back."""
        response = client.options(
            "/api/v1/kline/symbols",
            headers={
                "Origin": "http://localhost:3000",
                "Access-Control-Request-Method": "GET",
            },
        )
        assert response.status_code in [200, 401]
        # Checking only the status never verified CORS at all. A successful
        # preflight without Access-Control-Allow-Origin would still be blocked
        # by the browser. Header lookup on the response is case-insensitive.
        if response.status_code == 200:
            assert "access-control-allow-origin" in response.headers
# ==================== 集成测试 ====================
class TestIntegration:
    """End-to-end flows that depend on seeded data."""

    def test_full_auth_flow(self):
        """Log in as admin, then use the bearer token on a protected endpoint.

        Skipped (instead of silently passing) when the admin login does not
        return 200, e.g. because the database has no seeded admin user, so the
        report shows the flow was not exercised.
        """
        login_response = client.post(
            "/api/v1/auth/login",
            data={"username": "admin", "password": "admin123"},
        )
        if login_response.status_code != 200:
            pytest.skip(
                f"admin login returned HTTP {login_response.status_code}; "
                "auth flow not exercised"
            )

        token = login_response.json()["data"]["access_token"]

        # Use the token against an endpoint that should accept it.
        headers = {"Authorization": f"Bearer {token}"}
        kline_response = client.get("/api/v1/kline/symbols", headers=headers)
        assert kline_response.status_code == 200
# ==================== 性能测试 ====================
class TestPerformance:
    """Coarse latency checks."""

    def test_health_check_response_time(self):
        """GET /health completes in under one second."""
        import time

        # perf_counter is monotonic and high-resolution. time.time() follows the
        # wall clock and can jump if the system clock is adjusted mid-measurement.
        start = time.perf_counter()
        response = client.get("/health")
        elapsed = time.perf_counter() - start

        assert response.status_code == 200
        assert elapsed < 1.0  # seconds
if __name__ == "__main__":
    # Propagate pytest's exit status so a scripted run fails when tests fail;
    # a bare pytest.main(...) call discarded it and always exited 0.
    # Note: --cov/--cov-report require the pytest-cov plugin.
    raise SystemExit(pytest.main([__file__, "-v", "--cov=app", "--cov-report=html"]))