You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
183 lines
6.9 KiB
183 lines
6.9 KiB
# Service API tests
|
|
import unittest
|
|
import json
|
|
import sys
|
|
import os
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
# Make the directory that contains this test file importable so that the
# ``service`` package imported below can be resolved when the tests run.
# NOTE(review): this directory is the project root only if this test file
# lives at the repository root — confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
# 直接导入 app 模块
|
|
from service.app import app
|
|
|
|
class ServiceAPITest(unittest.TestCase):
    """HTTP-level tests for the Flask ``app`` imported from ``service.app``.

    Each test patches the collaborator classes it needs (``DataFetcher``,
    ``DataStorage``, ``DeepseekAgent``) as attributes of ``service.app`` and
    configures the instance each patched class returns.

    NOTE(review): patching the classes only takes effect if the app builds
    these objects while the patch is active (e.g. inside the request handler);
    instances created when ``service.app`` was imported would bypass the
    mocks — confirm against the app.
    """

    def setUp(self):
        """Put the app into testing mode and create a test client."""
        # ``testing`` has to be set on the app (config key TESTING), not on the
        # test client: the app-level flag is what makes exceptions raised in a
        # view propagate to the test instead of becoming a generic 500. An
        # attribute assigned on the client object has no effect.
        previous_testing = app.testing
        app.testing = True
        # ``app`` is a module-level global; restore its flag after each test.
        self.addCleanup(setattr, app, 'testing', previous_testing)
        self.client = app.test_client()

    def _json_body(self, response, status_code=200):
        """Assert ``response`` has ``status_code`` and return its decoded JSON.

        The raw body is used as the assertion message so that an unexpected
        status code shows what the endpoint actually returned.
        """
        self.assertEqual(response.status_code, status_code, response.data)
        return json.loads(response.data)

    @patch('service.app.DataFetcher')
    def test_health_check(self, mock_data_fetcher):
        """Test the health check endpoint (GET /health)."""
        # DataFetcher is patched defensively; the mock is neither configured
        # nor inspected in this test.
        data = self._json_body(self.client.get('/health'))
        self.assertEqual(data['status'], 'ok')
        self.assertEqual(data['message'], 'Service is running')

    @patch('service.app.DataFetcher')
    def test_get_contracts(self, mock_data_fetcher):
        """Test the contract list endpoint (GET /api/contracts)."""
        mock_fetcher_instance = MagicMock()
        mock_data_fetcher.return_value = mock_fetcher_instance

        # Two canned contracts returned by DataFetcher.get_contracts().
        mock_fetcher_instance.get_contracts.return_value = [
            {'symbol': 'CU2603', 'product': 'CU', 'product_name': '铜',
             'exchange': 'SHFE', 'month': '2603'},
            {'symbol': 'AL2603', 'product': 'AL', 'product_name': '铝',
             'exchange': 'SHFE', 'month': '2603'},
        ]

        # Request all contracts (no filter parameters).
        data = self._json_body(self.client.get('/api/contracts'))
        self.assertEqual(data['status'], 'success')
        self.assertIsInstance(data['data'], list)
        self.assertGreater(len(data['data']), 0)

    @patch('service.app.DataStorage')
    @patch('service.app.DataFetcher')
    def test_get_kline(self, mock_data_fetcher, mock_data_storage):
        """Test the K-line (candlestick) endpoint (GET /api/kline)."""
        mock_fetcher_instance = MagicMock()
        mock_data_fetcher.return_value = mock_fetcher_instance

        mock_storage_instance = MagicMock()
        mock_data_storage.return_value = mock_storage_instance

        # Stand-in for a non-empty frame holding one bar. iterrows() yields
        # (index, row) pairs; the index mock provides ``isoformat()``,
        # presumably because the endpoint serialises the timestamp with it.
        bar_time = MagicMock(isoformat=lambda: '2026-02-22T00:00:00')
        bar = {'open': 35000, 'high': 35100, 'low': 34900, 'close': 35050,
               'volume': 1000, 'open_interest': 10000}
        mock_df = MagicMock()
        mock_df.empty = False
        mock_df.iterrows.return_value = [(bar_time, bar)]

        # Storage and fetcher return the same frame, so either lookup path
        # yields the bar above.
        mock_storage_instance.get_kline_data.return_value = mock_df
        mock_fetcher_instance.get_kline_data.return_value = mock_df
        mock_storage_instance.save_kline_data.return_value = True

        response = self.client.get('/api/kline?symbol=CU2603&duration=1m&limit=10')
        data = self._json_body(response)
        self.assertEqual(data['status'], 'success')
        self.assertIsInstance(data['data'], list)
        self.assertGreater(len(data['data']), 0)

    @patch('service.app.DataStorage')
    @patch('service.app.DeepseekAgent')
    @patch('service.app.DataFetcher')
    def test_analyze(self, mock_data_fetcher, mock_deepseek_agent, mock_data_storage):
        """Test the DeepSeek analysis endpoint (POST /api/analyze)."""
        mock_fetcher_instance = MagicMock()
        mock_data_fetcher.return_value = mock_fetcher_instance

        mock_agent_instance = MagicMock()
        mock_deepseek_agent.return_value = mock_agent_instance

        mock_storage_instance = MagicMock()
        mock_data_storage.return_value = mock_storage_instance

        # Non-empty market data from the fetcher.
        mock_df = MagicMock()
        mock_df.empty = False
        mock_fetcher_instance.get_kline_data.return_value = mock_df

        # Canned analysis result returned by the agent.
        mock_agent_instance.analyze_market.return_value = {
            'symbol': 'CU2603',
            'timestamp': '2026-02-22T00:00:00',
            'trend': 'up',
            'probability': 0.8,
            'direction': 'buy',
        }

        mock_storage_instance.save_kline_data.return_value = True
        mock_storage_instance.save_analysis_result.return_value = True

        test_data = {
            'symbol': 'CU2603',
            'duration': '1m',
            'analysis_type': 'technical',
        }
        response = self.client.post('/api/analyze', json=test_data)
        data = self._json_body(response)
        self.assertEqual(data['status'], 'success')
        self.assertIn('data', data)

    @patch('service.app.DataStorage')
    def test_get_recommendations(self, mock_data_storage):
        """Test the trade recommendations endpoint (GET /api/recommendations)."""
        mock_storage_instance = MagicMock()
        mock_data_storage.return_value = mock_storage_instance

        # A bare MagicMock: every attribute (``empty`` included) is itself a
        # truthy MagicMock. NOTE(review): presumably the endpoint treats this
        # as "no recommendations"; the test only checks for a list payload.
        mock_df = MagicMock()
        mock_storage_instance.get_trade_recommendations.return_value = mock_df

        response = self.client.get('/api/recommendations?symbol=CU2603')
        data = self._json_body(response)
        self.assertEqual(data['status'], 'success')
        self.assertIsInstance(data['data'], list)

    @patch('service.app.DataStorage')
    def test_monitor_risk(self, mock_data_storage):
        """Test the risk monitoring endpoint (POST /api/risk)."""
        mock_storage_instance = MagicMock()
        mock_data_storage.return_value = mock_storage_instance
        mock_storage_instance.save_risk_monitoring.return_value = True

        # Current price sits between the stop loss and the target price.
        test_data = {
            'symbol': 'CU2603',
            'current_price': 36000,
            'entry_price': 35000,
            'stop_loss': 34500,
            'target_price': 37000,
        }
        response = self.client.post('/api/risk', json=test_data)
        data = self._json_body(response)
        self.assertEqual(data['status'], 'success')
        self.assertIn('data', data)
        self.assertEqual(data['data']['symbol'], 'CU2603')

    @patch('service.app.DataStorage')
    def test_get_analysis_history(self, mock_data_storage):
        """Test the analysis history endpoint (GET /api/analysis/history)."""
        mock_storage_instance = MagicMock()
        mock_data_storage.return_value = mock_storage_instance

        # A bare MagicMock stands in for the stored results. NOTE(review):
        # presumably the endpoint treats it as an empty result set; the test
        # only checks for a list payload.
        mock_df = MagicMock()
        mock_storage_instance.get_analysis_results.return_value = mock_df

        response = self.client.get('/api/analysis/history?symbol=CU2603&limit=10')
        data = self._json_body(response)
        self.assertEqual(data['status'], 'success')
        self.assertIsInstance(data['data'], list)
|
|
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()