""" 后端服务单元测试套件 测试覆盖: - 认证模块(登录/令牌/API Key) - K 线数据服务 - 实时行情服务 - 告警服务 - 订阅服务 - 用户管理 """ import pytest from datetime import datetime, timedelta from fastapi.testclient import TestClient from unittest.mock import Mock, patch, MagicMock from app.main import app from app.config import settings from app.services.auth_service import ( verify_password, get_password_hash, create_access_token, create_refresh_token, decode_token, generate_api_key, hash_api_key ) from app.services.kline_service import KlineService from app.services.alert_service import AlertService from app.services.subscription_service import SubscriptionService client = TestClient(app) # ==================== 认证服务测试 ==================== class TestAuthService: """认证服务测试""" def test_password_hashing(self): """测试密码哈希""" password = "test_password_123" hashed = get_password_hash(password) assert hashed is not None assert hashed != password assert len(hashed) > 50 def test_password_verification_success(self): """测试密码验证成功""" password = "test_password_123" hashed = get_password_hash(password) assert verify_password(password, hashed) is True def test_password_verification_failure(self): """测试密码验证失败""" password = "test_password_123" wrong_password = "wrong_password" hashed = get_password_hash(password) assert verify_password(wrong_password, hashed) is False def test_create_access_token(self): """测试创建访问令牌""" data = {"sub": "testuser", "user_id": 1} token = create_access_token(data) assert token is not None assert len(token) > 50 # 解码验证 payload = decode_token(token) assert payload["sub"] == "testuser" assert payload["user_id"] == 1 assert "exp" in payload def test_create_access_token_with_expiry(self): """测试创建带过期时间的令牌""" data = {"sub": "testuser"} expires_delta = timedelta(hours=2) token = create_access_token(data, expires_delta=expires_delta) payload = decode_token(token) assert payload["exp"] is not None def test_create_refresh_token(self): """测试创建刷新令牌""" data = {"sub": "testuser", "user_id": 1} token = create_refresh_token(data) assert token is not None payload = decode_token(token) assert payload["sub"] == "testuser" assert payload["type"] == "refresh" def test_decode_invalid_token(self): """测试解码无效令牌""" invalid_token = "invalid.token.here" with pytest.raises(Exception): decode_token(invalid_token) def test_decode_expired_token(self): """测试解码过期令牌""" data = {"sub": "testuser"} expires_delta = timedelta(seconds=-1) # 已过期 token = create_access_token(data, expires_delta=expires_delta) with pytest.raises(Exception): decode_token(token) def test_generate_api_key(self): """测试生成 API Key""" api_key = generate_api_key() assert api_key is not None assert len(api_key) == 64 # SHA256 哈希长度 def test_hash_api_key(self): """测试哈希 API Key""" api_key = "test_api_key_123456" hashed = hash_api_key(api_key) assert hashed is not None assert hashed != api_key # ==================== K 线数据服务测试 ==================== class TestKlineService: """K 线数据服务测试""" @patch('app.services.kline_service.TimescaleSessionLocal') def test_get_kline_data(self, mock_session): """测试获取 K 线数据""" # 模拟数据库查询结果 mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db mock_result = [ (datetime(2024, 1, 1, 10, 0), 4000.0, 4050.0, 3980.0, 4020.0, 1000, 4000000.0, 500), (datetime(2024, 1, 1, 10, 5), 4020.0, 4080.0, 4010.0, 4060.0, 1200, 4800000.0, 520), ] mock_db.execute.return_value.fetchall.return_value = mock_result start = datetime(2024, 1, 1, 10, 0) end = datetime(2024, 1, 1, 12, 0) result = KlineService.get_kline_data("IF2406", "5m", start, end) assert result is not None assert len(result) == 2 assert result[0]["symbol"] == "IF2406" assert result[0]["open"] == 4000.0 @patch('app.services.kline_service.TimescaleSessionLocal') def test_get_latest_kline(self, mock_session): """测试获取最新 K 线""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db mock_result = [(datetime(2024, 1, 1, 12, 0), 4100.0, 4150.0, 4080.0, 4120.0, 1500, 6000000.0, 600)] mock_db.execute.return_value.fetchone.return_value = mock_result[0] result = KlineService.get_latest_kline("IF2406", "5m") assert result is not None assert result["close"] == 4120.0 @patch('app.services.kline_service.TimescaleSessionLocal') def test_get_symbols(self, mock_session): """测试获取品种列表""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db mock_result = [("IF2406",), ("IC2406",), ("IH2406",)] mock_db.execute.return_value.fetchall.return_value = mock_result result = KlineService.get_symbols() assert len(result) == 3 assert "IF2406" in result @patch('app.services.kline_service.TimescaleSessionLocal') def test_get_periods(self, mock_session): """测试获取周期列表""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db mock_result = [("1m",), ("5m",), ("1h",), ("1d",)] mock_db.execute.return_value.fetchall.return_value = mock_result result = KlineService.get_periods() assert len(result) == 4 assert "5m" in result # ==================== 告警服务测试 ==================== class TestAlertService: """告警服务测试""" @patch('app.services.alert_service.SQLiteSessionLocal') def test_create_alert(self, mock_session): """测试创建告警""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db mock_alert = Mock() mock_alert.id = 1 mock_alert.user_id = 1 mock_alert.symbol = "IF2406" mock_alert.condition_type = "greater_than" mock_alert.condition_value = 4000.0 mock_alert.status = "active" mock_db.add = Mock() mock_db.commit = Mock() mock_db.refresh = Mock() # 模拟 add 后可以通过 query 获取 mock_db.query.return_value.filter.return_value.first.return_value = mock_alert result = AlertService.create_alert( user_id=1, symbol="IF2406", condition_type="greater_than", condition_value=4000.0 ) assert mock_db.add.called assert mock_db.commit.called # ==================== 订阅服务测试 ==================== class TestSubscriptionService: """订阅服务测试""" @patch('app.services.subscription_service.SQLiteSessionLocal') def test_create_subscription(self, mock_session): """测试创建订阅""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db # 模拟不存在已有订阅 mock_db.query.return_value.filter.return_value.first.return_value = None mock_subscription = Mock() mock_subscription.id = 1 mock_subscription.user_id = 1 mock_subscription.symbol = "IF2406" mock_subscription.period = "5m" mock_subscription.is_active = True mock_db.add = Mock() mock_db.commit = Mock() mock_db.refresh = Mock() mock_db.query.return_value.filter.return_value.first.return_value = mock_subscription result = SubscriptionService.create_subscription( user_id=1, symbol="IF2406", period="5m", subscription_type="kline" ) assert mock_db.add.called assert mock_db.commit.called @patch('app.services.subscription_service.SQLiteSessionLocal') def test_create_duplicate_subscription(self, mock_session): """测试创建重复订阅""" mock_db = MagicMock() mock_session.return_value.__enter__.return_value = mock_db # 模拟已存在订阅 existing_subscription = Mock() existing_subscription.is_active = True mock_db.query.return_value.filter.return_value.first.return_value = existing_subscription result = SubscriptionService.create_subscription( user_id=1, symbol="IF2406", period="5m", subscription_type="kline" ) # 已存在时不应调用 add assert not mock_db.add.called # ==================== API 端点测试 ==================== class TestHealthCheck: """健康检查测试""" def test_health_check(self): """测试健康检查端点""" response = client.get("/health") assert response.status_code == 200 data = response.json() assert data["status"] == "healthy" assert "timestamp" in data def test_root_endpoint(self): """测试根路径端点""" response = client.get("/") assert response.status_code == 200 data = response.json() assert data["name"] == settings.APP_NAME assert data["version"] == settings.APP_VERSION class TestAuthAPI: """认证 API 测试""" def test_login_missing_credentials(self): """测试登录缺少凭证""" response = client.post("/api/v1/auth/login", data={}) assert response.status_code in [401, 422] # 401 或 422 都是合理的 def test_login_wrong_credentials(self): """测试登录错误凭证""" response = client.post( "/api/v1/auth/login", data={"username": "nonexistent", "password": "wrongpassword"} ) # 可能返回 401(认证失败)或 200(如果数据库中有测试用户) assert response.status_code in [200, 401] def test_docs_accessible(self): """测试 API 文档可访问""" response = client.get("/docs") assert response.status_code == 200 def test_openapi_schema(self): """测试 OpenAPI 模式""" response = client.get("/openapi.json") assert response.status_code == 200 data = response.json() assert "paths" in data assert "/api/v1/auth/login" in data["paths"] class TestKlineAPI: """K 线数据 API 测试""" def test_get_symbols_unauthorized(self): """测试未授权获取品种列表""" response = client.get("/api/v1/kline/symbols") # 应该需要认证 assert response.status_code in [200, 401] def test_get_periods_unauthorized(self): """测试未授权获取周期列表""" response = client.get("/api/v1/kline/periods") assert response.status_code in [200, 401] def test_get_kline_data_missing_params(self): """测试获取 K 线数据缺少参数""" response = client.get("/api/v1/kline/data") # 缺少必要参数应该返回 422 assert response.status_code in [401, 422] class TestUserAPI: """用户管理 API 测试""" def test_create_user(self): """测试创建用户""" response = client.post( "/api/v1/user", json={ "username": "testuser_" + str(datetime.now().timestamp()), "password": "testpass123", "email": "test@example.com" } ) # 可能成功或因为用户名已存在而失败 assert response.status_code in [200, 400] class TestAlertAPI: """告警 API 测试""" def test_create_alert_unauthorized(self): """测试未授权创建告警""" response = client.post( "/api/v1/alert", json={ "symbol": "IF2406", "condition_type": "greater_than", "condition_value": 4000.0 } ) # 需要认证 assert response.status_code == 401 class TestSubscriptionAPI: """订阅 API 测试""" def test_create_subscription_unauthorized(self): """测试未授权创建订阅""" response = client.post( "/api/v1/subscription", json={ "symbol": "IF2406", "subscription_type": "kline" } ) # 需要认证 assert response.status_code == 401 # ==================== 中间件测试 ==================== class TestMiddleware: """中间件测试""" def test_cors_headers(self): """测试 CORS 头""" response = client.options( "/api/v1/kline/symbols", headers={ "Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET" } ) # CORS 应该允许跨域 assert response.status_code in [200, 401] # ==================== 集成测试 ==================== class TestIntegration: """集成测试""" def test_full_auth_flow(self): """测试完整认证流程""" # 1. 尝试登录(可能失败,因为数据库可能没有测试用户) login_response = client.post( "/api/v1/auth/login", data={"username": "admin", "password": "admin123"} ) # 如果登录成功,测试令牌使用 if login_response.status_code == 200: token = login_response.json()["data"]["access_token"] # 2. 使用令牌访问受保护端点 headers = {"Authorization": f"Bearer {token}"} kline_response = client.get("/api/v1/kline/symbols", headers=headers) assert kline_response.status_code == 200 # ==================== 性能测试 ==================== class TestPerformance: """性能测试""" def test_health_check_response_time(self): """测试健康检查响应时间""" import time start = time.time() response = client.get("/health") elapsed = time.time() - start assert response.status_code == 200 assert elapsed < 1.0 # 响应时间应小于 1 秒 if __name__ == "__main__": pytest.main([__file__, "-v", "--cov=app", "--cov-report=html"])