You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

183 lines
6.9 KiB

# Service API tests
import unittest
import json
import sys
import os
from unittest.mock import patch, MagicMock
# Add this file's directory (expected to be the project root) to the Python import path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Import the app object directly from the service.app module
from service.app import app
class ServiceAPITest(unittest.TestCase):
    """HTTP-level tests for the endpoints served by ``service.app``.

    Every test drives the application through its test client. Where a test
    needs it, the names ``DataFetcher``, ``DataStorage`` and ``DeepseekAgent``
    looked up on ``service.app`` are replaced with mocks via ``patch``.
    """

    def setUp(self):
        """Switch the application to testing mode and create a test client."""
        # NOTE(review): assumes `app` is a Flask application. Flask reads the
        # testing flag from the application object; setting a `testing`
        # attribute on the client (as this method used to) is not consulted.
        app.testing = True
        self.client = app.test_client()

    def _json_from(self, response, expected_status=200):
        """Assert the HTTP status of *response* and return its decoded JSON body.

        Args:
            response: Response object returned by the test client.
            expected_status: HTTP status code the response must carry.

        Returns:
            The response body parsed with ``json.loads``.
        """
        self.assertEqual(response.status_code, expected_status)
        return json.loads(response.data)

    @patch('service.app.DataFetcher')
    def test_health_check(self, mock_data_fetcher):
        """GET /health answers 200 with status 'ok' and the running message."""
        payload = self._json_from(self.client.get('/health'))

        self.assertEqual(payload['status'], 'ok')
        self.assertEqual(payload['message'], 'Service is running')

    @patch('service.app.DataFetcher')
    def test_get_contracts(self, mock_data_fetcher):
        """GET /api/contracts answers 200 with a non-empty list of contracts."""
        # Instances created from the patched DataFetcher return two contracts.
        fetcher = MagicMock()
        mock_data_fetcher.return_value = fetcher
        fetcher.get_contracts.return_value = [
            {'symbol': 'CU2603', 'product': 'CU', 'product_name': '', 'exchange': 'SHFE', 'month': '2603'},
            {'symbol': 'AL2603', 'product': 'AL', 'product_name': '', 'exchange': 'SHFE', 'month': '2603'},
        ]

        # Request all contracts.
        payload = self._json_from(self.client.get('/api/contracts'))

        self.assertEqual(payload['status'], 'success')
        self.assertIsInstance(payload['data'], list)
        self.assertGreater(len(payload['data']), 0)

    @patch('service.app.DataStorage')
    @patch('service.app.DataFetcher')
    def test_get_kline(self, mock_data_fetcher, mock_data_storage):
        """GET /api/kline answers 200 with a non-empty list of K-line rows."""
        fetcher = MagicMock()
        mock_data_fetcher.return_value = fetcher
        storage = MagicMock()
        mock_data_storage.return_value = storage

        # DataFrame stand-in: non-empty, yielding one (index, row) pair whose
        # index exposes isoformat() like a timestamp would.
        timestamp = MagicMock(isoformat=lambda: '2026-02-22T00:00:00')
        row = {'open': 35000, 'high': 35100, 'low': 34900, 'close': 35050, 'volume': 1000, 'open_interest': 10000}
        frame = MagicMock()
        frame.empty = False
        frame.iterrows.return_value = [(timestamp, row)]

        # Both the storage and the fetcher hand back the same frame, so the
        # test does not depend on which of the two the endpoint consults.
        storage.get_kline_data.return_value = frame
        fetcher.get_kline_data.return_value = frame
        storage.save_kline_data.return_value = True

        payload = self._json_from(
            self.client.get('/api/kline?symbol=CU2603&duration=1m&limit=10')
        )

        self.assertEqual(payload['status'], 'success')
        self.assertIsInstance(payload['data'], list)
        self.assertGreater(len(payload['data']), 0)

    @patch('service.app.DataStorage')
    @patch('service.app.DeepseekAgent')
    @patch('service.app.DataFetcher')
    def test_analyze(self, mock_data_fetcher, mock_deepseek_agent, mock_data_storage):
        """POST /api/analyze answers 200 with status 'success' and a 'data' key."""
        fetcher = MagicMock()
        mock_data_fetcher.return_value = fetcher
        agent = MagicMock()
        mock_deepseek_agent.return_value = agent
        storage = MagicMock()
        mock_data_storage.return_value = storage

        # Non-empty DataFrame stand-in returned by the fetcher.
        frame = MagicMock()
        frame.empty = False
        fetcher.get_kline_data.return_value = frame

        # Canned analysis result returned by the agent.
        agent.analyze_market.return_value = {
            'symbol': 'CU2603',
            'timestamp': '2026-02-22T00:00:00',
            'trend': 'up',
            'probability': 0.8,
            'direction': 'buy',
        }
        storage.save_kline_data.return_value = True
        storage.save_analysis_result.return_value = True

        request_body = {
            'symbol': 'CU2603',
            'duration': '1m',
            'analysis_type': 'technical',
        }
        payload = self._json_from(self.client.post('/api/analyze', json=request_body))

        self.assertEqual(payload['status'], 'success')
        self.assertIn('data', payload)

    @patch('service.app.DataStorage')
    def test_get_recommendations(self, mock_data_storage):
        """GET /api/recommendations answers 200 with a list under 'data'."""
        storage = MagicMock()
        mock_data_storage.return_value = storage

        # A bare MagicMock stands in for the stored recommendations; how the
        # endpoint turns it into a list is not visible from this test.
        storage.get_trade_recommendations.return_value = MagicMock()

        payload = self._json_from(self.client.get('/api/recommendations?symbol=CU2603'))

        self.assertEqual(payload['status'], 'success')
        self.assertIsInstance(payload['data'], list)

    @patch('service.app.DataStorage')
    def test_monitor_risk(self, mock_data_storage):
        """POST /api/risk answers 200 and echoes the requested symbol in 'data'."""
        storage = MagicMock()
        mock_data_storage.return_value = storage
        storage.save_risk_monitoring.return_value = True

        request_body = {
            'symbol': 'CU2603',
            'current_price': 36000,
            'entry_price': 35000,
            'stop_loss': 34500,
            'target_price': 37000,
        }
        payload = self._json_from(self.client.post('/api/risk', json=request_body))

        self.assertEqual(payload['status'], 'success')
        self.assertIn('data', payload)
        self.assertEqual(payload['data']['symbol'], 'CU2603')

    @patch('service.app.DataStorage')
    def test_get_analysis_history(self, mock_data_storage):
        """GET /api/analysis/history answers 200 with a list under 'data'."""
        storage = MagicMock()
        mock_data_storage.return_value = storage

        # A bare MagicMock stands in for the stored analysis results.
        storage.get_analysis_results.return_value = MagicMock()

        payload = self._json_from(
            self.client.get('/api/analysis/history?symbol=CU2603&limit=10')
        )

        self.assertEqual(payload['status'], 'success')
        self.assertIsInstance(payload['data'], list)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()