You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

347 lines
11 KiB

"""
星耀数智(AmazingData)适配器单元测试
测试覆盖:
- 适配器初始化和配置
- 连接/断开连接
- 基础数据获取
- K线数据获取
- 财务数据获取
- 错误处理
"""
import unittest
import asyncio
import sys
import os
from datetime import date, datetime
from unittest.mock import Mock, patch, MagicMock
import pandas as pd
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.adapters.amazingdata_adapter import (
AmazingDataAdapter, AmazingDataConfig, SecurityType, Market, Period
)
from app.adapters.base import KLineData, SymbolInfo, TradeCalData
class TestAmazingDataConfig(unittest.TestCase):
    """Unit tests for the ``AmazingDataConfig`` settings object."""

    def test_default_config(self):
        """Only required fields are supplied; optional fields keep their defaults."""
        cfg = AmazingDataConfig(
            host='localhost',
            port=8080,
            username='test_user',
            password='test_pass',
        )
        self.assertEqual(cfg.username, 'test_user')
        # The defaults pinned by this test: cache directory and cache flag.
        self.assertEqual(cfg.local_path, './amazing_data_cache/')
        self.assertTrue(cfg.use_local_cache)

    def test_custom_config(self):
        """Explicitly supplied optional fields override the defaults."""
        overrides = {
            'username': 'user',
            'password': 'pass',
            'host': '192.168.1.1',
            'port': 9090,
            'local_path': '/custom/path/',
            'use_local_cache': False,
        }
        cfg = AmazingDataConfig(**overrides)
        self.assertEqual(cfg.port, 9090)
        self.assertEqual(cfg.local_path, '/custom/path/')
        self.assertFalse(cfg.use_local_cache)
class TestAmazingDataAdapter(unittest.TestCase):
    """Synchronous tests for AmazingDataAdapter helpers that need no event loop."""

    def setUp(self):
        """Create a fresh, unconnected adapter and a sample connection config."""
        self.adapter = AmazingDataAdapter()
        self.test_config = {
            'username': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 8080,
            'local_path': './test_cache/',
            'use_local_cache': False
        }

    def tearDown(self):
        """Close the adapter if a test left it connected.

        ``asyncio.run`` creates and disposes of its own event loop. The previous
        ``asyncio.get_event_loop().run_until_complete(...)`` depended on an
        implicit current loop. That call is deprecated when no loop is running.
        It raises RuntimeError when no current loop is set, for example after an
        IsolatedAsyncioTestCase has reset it, or on Python 3.14+.
        """
        if self.adapter._connected:
            asyncio.run(self.adapter.close())

    def test_initial_state(self):
        """A newly constructed adapter has no config, no SDK handle and is offline."""
        self.assertIsNone(self.adapter.config)
        self.assertIsNone(self.adapter._ad)
        self.assertFalse(self.adapter._is_logged_in)
        self.assertFalse(self.adapter._connected)

    def test_format_date(self):
        """_format_date normalises every supported input form to a YYYYMMDD int."""
        # Integer input
        self.assertEqual(self.adapter._format_date(20240101), 20240101)
        # String with dashes
        self.assertEqual(self.adapter._format_date('2024-01-01'), 20240101)
        # String with slashes
        self.assertEqual(self.adapter._format_date('2024/01/01'), 20240101)
        # Plain digit string
        self.assertEqual(self.adapter._format_date('20240101'), 20240101)
        # date object
        self.assertEqual(self.adapter._format_date(date(2024, 1, 1)), 20240101)
        # datetime object
        self.assertEqual(self.adapter._format_date(datetime(2024, 1, 1)), 20240101)

    def test_format_date_invalid(self):
        """Unsupported inputs to _format_date raise ValueError."""
        with self.assertRaises(ValueError):
            self.adapter._format_date(None)
        with self.assertRaises(ValueError):
            self.adapter._format_date([])

    def test_check_login_not_logged_in(self):
        """_check_login on a fresh adapter raises RuntimeError with a clear message."""
        with self.assertRaises(RuntimeError) as context:
            self.adapter._check_login()
        # Checked outside the `with`: code after the raising call would never run.
        self.assertIn('未连接到数据源', str(context.exception))

    def test_check_login_logged_in(self):
        """The real _check_login does not raise for a logged-in adapter.

        Previously _check_login itself was patched, so this test only invoked a
        mock and could never fail.
        """
        self.adapter._is_logged_in = True
        # Mirror the logged-in state the fetch tests use, in case the check also
        # inspects the SDK handle (NOTE(review): confirm against _check_login).
        self.adapter._ad = Mock()
        # Must not raise.
        self.adapter._check_login()
class TestAmazingDataAdapterAsync(unittest.IsolatedAsyncioTestCase):
    """Coroutine-level tests for AmazingDataAdapter.connect() and close()."""

    async def asyncSetUp(self):
        """Build an unconnected adapter and the minimal config passed to connect()."""
        self.adapter = AmazingDataAdapter()
        self.test_config = dict(
            username='test_user',
            password='test_pass',
            host='localhost',
            port=8080,
        )

    async def asyncTearDown(self):
        """Release the adapter only when a test actually connected it."""
        if not self.adapter._connected:
            return
        await self.adapter.close()

    @patch('app.adapters.amazingdata_adapter.AmazingDataAdapter._do_login')
    async def test_connect_success(self, mock_login):
        """connect() succeeds when the SDK module is importable and login is stubbed."""
        mock_login.return_value = None

        # Stand-in for the AmazingData SDK module. The data classes are the Mock
        # class itself, so instantiating them yields fresh Mock objects.
        fake_sdk = Mock()
        fake_sdk.BaseData = Mock
        fake_sdk.InfoData = Mock
        fake_sdk.MarketData = Mock
        fake_sdk.constant.Period = Mock()

        with patch.dict('sys.modules', {'AmazingData': fake_sdk}):
            await self.adapter.connect(self.test_config)

        self.assertTrue(self.adapter._connected)
        self.assertIsNotNone(self.adapter.config)

    @patch.dict('sys.modules', {'AmazingData': None})
    async def test_connect_import_error(self):
        """A None entry in sys.modules makes importing the SDK fail inside connect()."""
        with self.assertRaises(RuntimeError) as ctx:
            await self.adapter.connect(self.test_config)
        self.assertIn('AmazingData SDK 未安装', str(ctx.exception))

    async def test_close_not_connected(self):
        """close() on a never-connected adapter does not raise and stays logged out."""
        await self.adapter.close()
        self.assertFalse(self.adapter._is_logged_in)
class TestFetchKlines(unittest.IsolatedAsyncioTestCase):
    """fetch_klines() against a mocked MarketData backend."""

    async def asyncSetUp(self):
        """Fake a logged-in adapter whose SDK handle and MarketData are mocks."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True

        # SDK handle exposing constant.Period.daily.value == 'daily'.
        daily = Mock()
        daily.value = 'daily'
        sdk = Mock()
        sdk.constant.Period = Mock()
        sdk.constant.Period.daily = daily
        adapter._ad = sdk

        self.mock_market_data = Mock()
        adapter._market_data = self.mock_market_data
        self.adapter = adapter

    async def test_fetch_klines_empty_result(self):
        """An empty mapping from query_kline yields an empty list."""
        self.mock_market_data.query_kline.return_value = {}
        klines = await self.adapter.fetch_klines(
            symbol='000001.SZ', start='20240101', end='20241231', freq='1d'
        )
        self.assertEqual(klines, [])

    async def test_fetch_klines_with_data(self):
        """A per-symbol DataFrame from query_kline is converted into KLineData rows."""
        bars = pd.DataFrame(
            {
                'open': [10.0, 11.0],
                'high': [11.0, 12.0],
                'low': [9.0, 10.0],
                'close': [10.5, 11.5],
                'volume': [10000, 20000],
                'amount': [105000, 230000],
            },
            index=pd.to_datetime(['2024-01-01', '2024-01-02']),
        )
        self.mock_market_data.query_kline.return_value = {'000001.SZ': bars}

        klines = await self.adapter.fetch_klines(
            symbol='000001.SZ', start='20240101', end='20240102', freq='1d'
        )

        self.assertEqual(len(klines), 2)
        first = klines[0]
        self.assertIsInstance(first, KLineData)
        self.assertEqual(first.symbol, '000001.SZ')
        self.assertEqual(first.open, 10.0)
class TestFetchSymbols(unittest.IsolatedAsyncioTestCase):
    """fetch_symbols() for stocks and futures using a mocked BaseData."""

    async def asyncSetUp(self):
        """Fake a logged-in adapter backed by a mock BaseData instance."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True
        adapter._ad = Mock()
        self.mock_base_data = Mock()
        adapter._base_data = self.mock_base_data
        self.adapter = adapter

    async def test_fetch_stock_symbols(self):
        """Stock codes plus their names become SymbolInfo records."""
        base = self.mock_base_data
        base.get_code_list.return_value = ['000001.SZ', '600000.SH']
        base.get_code_info.return_value = pd.DataFrame(
            {'symbol': ['平安银行', '浦发银行']},
            index=['000001.SZ', '600000.SH'],
        )

        symbols = await self.adapter.fetch_symbols('stock')

        self.assertEqual(len(symbols), 2)
        head = symbols[0]
        self.assertIsInstance(head, SymbolInfo)
        self.assertEqual(head.symbol_id, '000001.SZ')
        self.assertEqual(head.exchange, 'SZ')

    async def test_fetch_futures_symbols(self):
        """Futures codes are listed, and 'cu2501' reports underlying 'CU'."""
        self.mock_base_data.get_future_code_list.return_value = ['cu2501', 'al2502']
        symbols = await self.adapter.fetch_symbols('futures')
        self.assertEqual(len(symbols), 2)
        self.assertEqual(symbols[0].underlying, 'CU')
class TestTradingCalendar(unittest.IsolatedAsyncioTestCase):
    """fetch_trading_calendar() using a mocked BaseData calendar."""

    async def asyncSetUp(self):
        """Fake a logged-in adapter backed by a mock BaseData instance."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True
        adapter._ad = Mock()
        self.mock_base_data = Mock()
        adapter._base_data = self.mock_base_data
        self.adapter = adapter

    async def test_fetch_calendar(self):
        """Each trading day from get_calendar becomes one TradeCalData entry."""
        self.mock_base_data.get_calendar.return_value = [20240101, 20240102, 20240103]

        calendar = await self.adapter.fetch_trading_calendar(
            exchange='SH', start='20240101', end='20240103'
        )

        self.assertEqual(len(calendar), 3)
        self.assertIsInstance(calendar[0], TradeCalData)
class TestHealthCheck(unittest.IsolatedAsyncioTestCase):
    """health_check() in the disconnected and connected states."""

    async def asyncSetUp(self):
        """Start each test from a brand-new, unconnected adapter."""
        self.adapter = AmazingDataAdapter()

    async def test_health_check_not_connected(self):
        """A fresh adapter reports unhealthy."""
        self.assertFalse(await self.adapter.health_check())

    async def test_health_check_connected(self):
        """A connected, logged-in adapter with a mocked get_code_list is healthy."""
        base = Mock()
        base.get_code_list.return_value = ['000001.SZ']
        self.adapter._connected = True
        self.adapter._is_logged_in = True
        self.adapter._base_data = base
        self.assertTrue(await self.adapter.health_check())
class TestEnums(unittest.TestCase):
    """Pin the raw string values of the adapter's public enums."""

    def _assert_values(self, enum_cls, expectations):
        """Assert ``enum_cls.<name>.value == expected`` for each pair, in order."""
        for member_name, expected in expectations:
            # getattr is resolved lazily, one member at a time.
            self.assertEqual(getattr(enum_cls, member_name).value, expected)

    def test_security_type_values(self):
        """SecurityType members map to their EXTRA_* identifiers."""
        self._assert_values(SecurityType, (
            ('STOCK_A', 'EXTRA_STOCK_A'),
            ('ETF', 'EXTRA_ETF'),
            ('FUTURE', 'EXTRA_FUTURE'),
        ))

    def test_market_values(self):
        """Market members carry their exchange codes."""
        self._assert_values(Market, (
            ('SH', 'SH'),
            ('SZ', 'SZ'),
            ('BJ', 'BJ'),
        ))

    def test_period_values(self):
        """Period members carry lower-case period names."""
        self._assert_values(Period, (
            ('MIN1', 'min1'),
            ('DAILY', 'daily'),
            ('WEEKLY', 'weekly'),
        ))
if __name__ == '__main__':
    # Execute every TestCase in this module, printing each test's name/result.
    unittest.main(verbosity=2)