"""Unit tests for the 星耀数智 (AmazingData) adapter.

Coverage:
- Configuration (``AmazingDataConfig``) defaults and overrides
- Adapter initial state, date normalisation helper and login guard
- Connect / close lifecycle, including the "SDK not installed" error path
- Basic data: symbol lists and trading calendar
- K-line (candlestick) data
- Health check and public enum values

TODO: financial-data retrieval was listed as planned coverage but has no
tests in this module yet.
"""

import unittest
import asyncio
import sys
import os
from datetime import date, datetime
from unittest.mock import Mock, patch, MagicMock

import pandas as pd

# Put the directory two levels above this file (the project root) on sys.path
# so the ``app`` package is importable when the tests are run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.adapters.amazingdata_adapter import (
    AmazingDataAdapter,
    AmazingDataConfig,
    SecurityType,
    Market,
    Period,
)
from app.adapters.base import KLineData, SymbolInfo, TradeCalData


class TestAmazingDataConfig(unittest.TestCase):
    """Tests for the AmazingDataConfig configuration class."""

    def test_default_config(self):
        """Optional fields fall back to their defaults."""
        config = AmazingDataConfig(
            username='test_user',
            password='test_pass',
            host='localhost',
            port=8080,
        )
        self.assertEqual(config.username, 'test_user')
        self.assertEqual(config.local_path, './amazing_data_cache/')
        self.assertTrue(config.use_local_cache)

    def test_custom_config(self):
        """Explicitly passed values override the defaults."""
        config = AmazingDataConfig(
            username='user',
            password='pass',
            host='192.168.1.1',
            port=9090,
            local_path='/custom/path/',
            use_local_cache=False,
        )
        self.assertEqual(config.port, 9090)
        self.assertEqual(config.local_path, '/custom/path/')
        self.assertFalse(config.use_local_cache)


class TestAmazingDataAdapter(unittest.TestCase):
    """Tests for the adapter's synchronous helpers and initial state."""

    def setUp(self):
        """Create a fresh, unconnected adapter for every test."""
        self.adapter = AmazingDataAdapter()
        self.test_config = {
            'username': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 8080,
            'local_path': './test_cache/',
            'use_local_cache': False,
        }

    def tearDown(self):
        """Close the adapter if a test left it connected.

        ``asyncio.run`` is used instead of
        ``asyncio.get_event_loop().run_until_complete``: calling
        ``get_event_loop()`` with no running loop is deprecated since
        Python 3.10 and raises RuntimeError in Python 3.14+.
        """
        if self.adapter._connected:
            asyncio.run(self.adapter.close())

    def test_initial_state(self):
        """A new adapter has no config, no SDK handle and is not connected."""
        self.assertIsNone(self.adapter.config)
        self.assertIsNone(self.adapter._ad)
        self.assertFalse(self.adapter._is_logged_in)
        self.assertFalse(self.adapter._connected)

    def test_format_date(self):
        """_format_date() normalises all supported inputs to a YYYYMMDD int."""
        # Integer input is returned unchanged.
        self.assertEqual(self.adapter._format_date(20240101), 20240101)
        # Dash-separated string.
        self.assertEqual(self.adapter._format_date('2024-01-01'), 20240101)
        # Slash-separated string.
        self.assertEqual(self.adapter._format_date('2024/01/01'), 20240101)
        # Plain digit string.
        self.assertEqual(self.adapter._format_date('20240101'), 20240101)
        # date object.
        self.assertEqual(self.adapter._format_date(date(2024, 1, 1)), 20240101)
        # datetime object.
        self.assertEqual(
            self.adapter._format_date(datetime(2024, 1, 1)), 20240101
        )

    def test_format_date_invalid(self):
        """_format_date() rejects unsupported types with ValueError."""
        with self.assertRaises(ValueError):
            self.adapter._format_date(None)
        with self.assertRaises(ValueError):
            self.adapter._format_date([])

    def test_check_login_not_logged_in(self):
        """_check_login() raises RuntimeError before the adapter logs in."""
        with self.assertRaises(RuntimeError) as context:
            self.adapter._check_login()
        self.assertIn('未连接到数据源', str(context.exception))

    def test_check_login_logged_in(self):
        """_check_login() does not raise once the adapter is logged in.

        The previous version patched ``_check_login`` itself, so it only
        exercised the mock and could never fail. Here the real method runs
        against the same logged-in state the fetch tests below set up.
        """
        self.adapter._is_logged_in = True
        self.adapter._ad = Mock()
        # Must not raise.
        self.adapter._check_login()


class TestAmazingDataAdapterAsync(unittest.IsolatedAsyncioTestCase):
    """Tests for the adapter's async connect/close lifecycle."""

    async def asyncSetUp(self):
        """Create a fresh adapter and a minimal connection config."""
        self.adapter = AmazingDataAdapter()
        self.test_config = {
            'username': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 8080,
        }

    async def asyncTearDown(self):
        """Close the adapter if the test connected it."""
        if self.adapter._connected:
            await self.adapter.close()

    @patch('app.adapters.amazingdata_adapter.AmazingDataAdapter._do_login')
    async def test_connect_success(self, mock_login):
        """connect() succeeds when the SDK imports and login is stubbed."""
        mock_login.return_value = None

        # Stand-in for the AmazingData SDK module. The data classes are set to
        # the Mock *class*, so instantiating them yields fresh Mock objects.
        mock_ad = Mock()
        mock_ad.BaseData = Mock
        mock_ad.InfoData = Mock
        mock_ad.MarketData = Mock
        mock_ad.constant.Period = Mock()

        with patch.dict('sys.modules', {'AmazingData': mock_ad}):
            await self.adapter.connect(self.test_config)

        self.assertTrue(self.adapter._connected)
        self.assertIsNotNone(self.adapter.config)

    async def test_connect_import_error(self):
        """connect() raises RuntimeError when the SDK cannot be imported.

        Mapping a module name to ``None`` in ``sys.modules`` makes ``import``
        of that name raise ImportError. ``patch.dict`` is used as a context
        manager inside the coroutine (not as a decorator) so the patch is
        guaranteed to be active while the coroutine body runs; the decorator
        form is only coroutine-aware on newer Python versions.
        """
        with patch.dict('sys.modules', {'AmazingData': None}):
            with self.assertRaises(RuntimeError) as context:
                await self.adapter.connect(self.test_config)
        self.assertIn('AmazingData SDK 未安装', str(context.exception))

    async def test_close_not_connected(self):
        """close() on a never-connected adapter is a safe no-op."""
        # Must not raise.
        await self.adapter.close()
        self.assertFalse(self.adapter._is_logged_in)


class TestFetchKlines(unittest.IsolatedAsyncioTestCase):
    """Tests for K-line (candlestick) retrieval."""

    async def asyncSetUp(self):
        """Build a logged-in adapter backed by a mocked MarketData client."""
        self.adapter = AmazingDataAdapter()
        self.adapter._is_logged_in = True
        self.adapter._ad = Mock()
        self.adapter._ad.constant.Period = Mock()
        self.adapter._ad.constant.Period.daily = Mock()
        self.adapter._ad.constant.Period.daily.value = 'daily'

        self.mock_market_data = Mock()
        self.adapter._market_data = self.mock_market_data

    async def test_fetch_klines_empty_result(self):
        """An empty SDK response yields an empty list."""
        self.mock_market_data.query_kline.return_value = {}

        result = await self.adapter.fetch_klines(
            symbol='000001.SZ',
            start='20240101',
            end='20241231',
            freq='1d',
        )

        self.assertEqual(result, [])

    async def test_fetch_klines_with_data(self):
        """Each DataFrame row is converted into a KLineData record."""
        df = pd.DataFrame(
            {
                'open': [10.0, 11.0],
                'high': [11.0, 12.0],
                'low': [9.0, 10.0],
                'close': [10.5, 11.5],
                'volume': [10000, 20000],
                'amount': [105000, 230000],
            },
            index=pd.to_datetime(['2024-01-01', '2024-01-02']),
        )
        self.mock_market_data.query_kline.return_value = {'000001.SZ': df}

        result = await self.adapter.fetch_klines(
            symbol='000001.SZ',
            start='20240101',
            end='20240102',
            freq='1d',
        )

        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], KLineData)
        self.assertEqual(result[0].symbol, '000001.SZ')
        self.assertEqual(result[0].open, 10.0)


class TestFetchSymbols(unittest.IsolatedAsyncioTestCase):
    """Tests for instrument-list retrieval."""

    async def asyncSetUp(self):
        """Build a logged-in adapter backed by a mocked BaseData client."""
        self.adapter = AmazingDataAdapter()
        self.adapter._is_logged_in = True
        self.adapter._ad = Mock()

        self.mock_base_data = Mock()
        self.adapter._base_data = self.mock_base_data

    async def test_fetch_stock_symbols(self):
        """Stock codes are combined with their names into SymbolInfo records."""
        self.mock_base_data.get_code_list.return_value = [
            '000001.SZ',
            '600000.SH',
        ]
        self.mock_base_data.get_code_info.return_value = pd.DataFrame(
            {'symbol': ['平安银行', '浦发银行']},
            index=['000001.SZ', '600000.SH'],
        )

        result = await self.adapter.fetch_symbols('stock')

        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], SymbolInfo)
        self.assertEqual(result[0].symbol_id, '000001.SZ')
        self.assertEqual(result[0].exchange, 'SZ')

    async def test_fetch_futures_symbols(self):
        """Futures contracts expose an upper-cased underlying code."""
        self.mock_base_data.get_future_code_list.return_value = [
            'cu2501',
            'al2502',
        ]

        result = await self.adapter.fetch_symbols('futures')

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].underlying, 'CU')


class TestTradingCalendar(unittest.IsolatedAsyncioTestCase):
    """Tests for trading-calendar retrieval."""

    async def asyncSetUp(self):
        """Build a logged-in adapter backed by a mocked BaseData client."""
        self.adapter = AmazingDataAdapter()
        self.adapter._is_logged_in = True
        self.adapter._ad = Mock()

        self.mock_base_data = Mock()
        self.adapter._base_data = self.mock_base_data

    async def test_fetch_calendar(self):
        """Each calendar date becomes a TradeCalData record."""
        self.mock_base_data.get_calendar.return_value = [
            20240101,
            20240102,
            20240103,
        ]

        result = await self.adapter.fetch_trading_calendar(
            exchange='SH',
            start='20240101',
            end='20240103',
        )

        self.assertEqual(len(result), 3)
        self.assertIsInstance(result[0], TradeCalData)


class TestHealthCheck(unittest.IsolatedAsyncioTestCase):
    """Tests for health_check()."""

    async def asyncSetUp(self):
        """Create a fresh, unconnected adapter."""
        self.adapter = AmazingDataAdapter()

    async def test_health_check_not_connected(self):
        """An unconnected adapter reports unhealthy."""
        result = await self.adapter.health_check()
        self.assertFalse(result)

    async def test_health_check_connected(self):
        """A connected adapter whose probe call succeeds reports healthy."""
        self.adapter._connected = True
        self.adapter._is_logged_in = True
        self.adapter._base_data = Mock()
        self.adapter._base_data.get_code_list.return_value = ['000001.SZ']

        result = await self.adapter.health_check()

        self.assertTrue(result)


class TestEnums(unittest.TestCase):
    """Tests for the adapter's public enum values."""

    def test_security_type_values(self):
        """SecurityType members map to the SDK's EXTRA_* identifiers."""
        self.assertEqual(SecurityType.STOCK_A.value, 'EXTRA_STOCK_A')
        self.assertEqual(SecurityType.ETF.value, 'EXTRA_ETF')
        self.assertEqual(SecurityType.FUTURE.value, 'EXTRA_FUTURE')

    def test_market_values(self):
        """Market members use the exchange suffix as their value."""
        self.assertEqual(Market.SH.value, 'SH')
        self.assertEqual(Market.SZ.value, 'SZ')
        self.assertEqual(Market.BJ.value, 'BJ')

    def test_period_values(self):
        """Period members use lower-case period names as their value."""
        self.assertEqual(Period.MIN1.value, 'min1')
        self.assertEqual(Period.DAILY.value, 'daily')
        self.assertEqual(Period.WEEKLY.value, 'weekly')


if __name__ == '__main__':
    # Run the tests with per-test output.
    unittest.main(verbosity=2)