You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
7.1 KiB
193 lines
7.1 KiB
#!/usr/bin/env python3
|
|
# 系统测试脚本
|
|
|
|
import unittest
|
|
import pandas as pd
|
|
from qihuo_analyzer.data.data_fetcher import DataFetcher
|
|
from qihuo_analyzer.data.data_storage import DataStorage
|
|
from qihuo_analyzer.modules.trend_filter import TrendFilter
|
|
from qihuo_analyzer.modules.risk_manager import RiskManager
|
|
from qihuo_analyzer.modules.fund_flow_monitor import FundFlowMonitor
|
|
from qihuo_analyzer.modules.support_resistance import SupportResistance
|
|
from qihuo_analyzer.modules.rollover_detector import RolloverDetector
|
|
from qihuo_analyzer.modules.deepseek_agent import DeepseekAgent
|
|
|
|
|
|
class TestSystemComponents(unittest.TestCase):
    """Integration tests for the system's data and analysis components.

    Nothing here is stubbed: each test calls the real ``DataFetcher`` and the
    real analysis classes. NOTE(review): the fetcher presumably talks to an
    external market-data source, so these tests need it to be reachable —
    confirm before running in an isolated CI environment.
    """

    def setUp(self):
        """Create a fresh fetcher and storage instance for every test."""
        # NOTE(review): "CU2309" looks like a 2023-dated contract code; if the
        # data source only serves live contracts every fetch will come back
        # empty. Confirm the symbol is still valid.
        self.symbol = "CU2309"
        self.data_fetcher = DataFetcher()
        self.data_storage = DataStorage()

    def _fetch_kline(self, count, timeframe="1d"):
        """Fetch ``count`` bars of ``timeframe`` data for ``self.symbol``.

        Fails the current test with a readable message when nothing comes
        back. Without this guard an empty result only surfaces later as an
        unrelated error (e.g. ``IndexError`` from ``iloc[-1]``).

        Returns:
            The DataFrame returned by ``DataFetcher.get_kline_data``.
        """
        kline_data = self.data_fetcher.get_kline_data(self.symbol, timeframe, count)
        self.assertIsNotNone(
            kline_data, f"get_kline_data returned None for {self.symbol} ({timeframe})"
        )
        self.assertFalse(
            kline_data.empty, f"no kline data returned for {self.symbol} ({timeframe})"
        )
        return kline_data

    def test_data_fetcher(self):
        """The fetcher returns non-empty bars with 'close' and 'volume' columns."""
        print("测试数据获取器...")
        kline_data = self._fetch_kline(100)
        self.assertIn('close', kline_data.columns)
        self.assertIn('volume', kline_data.columns)
        print("✅ 数据获取器测试通过")

    def test_data_storage(self):
        """Saving bars reports success and the stored bars can be read back."""
        print("测试数据存储...")
        # Guard the fetch first so an empty download is not "successfully"
        # saved and then reported as a storage-read failure.
        kline_data = self._fetch_kline(20)
        success = self.data_storage.save_kline_data(self.symbol, "1d", kline_data)
        self.assertTrue(success)

        # Read back a subset of what was just written.
        stored_data = self.data_storage.get_kline_data(self.symbol, "1d", 10)
        self.assertFalse(stored_data.empty)
        print("✅ 数据存储测试通过")

    def test_trend_filter(self):
        """TrendFilter: trend analysis keys, win-rate range and cycle label."""
        print("测试趋势过滤器...")
        kline_data = self._fetch_kline(100)
        trend_filter = TrendFilter()

        # Trend analysis must expose ADX and a strength label.
        trend_analysis = trend_filter.analyze_trend(kline_data)
        self.assertIn('adx', trend_analysis)
        self.assertIn('trend_strength', trend_analysis)

        # Win rate is checked against a 0-100 (percentage) range.
        win_rate = trend_filter.calculate_win_rate(kline_data)
        self.assertGreaterEqual(win_rate, 0)
        self.assertLessEqual(win_rate, 100)

        # Cycle judgement must be one of the three known labels.
        cycle = trend_filter.judge_cycle(kline_data)
        self.assertIn(cycle, ['short', 'medium', 'long'])
        print("✅ 趋势过滤器测试通过")

    def test_risk_manager(self):
        """RiskManager: long stop-loss below entry and a positive position size."""
        print("测试风险管理器...")
        # _fetch_kline guarantees at least one row, so iloc[-1] below is safe.
        kline_data = self._fetch_kline(100)
        risk_manager = RiskManager()

        # For a long entry at the last close, the stop must sit below entry.
        entry_price = kline_data['close'].iloc[-1]
        stop_loss = risk_manager.calculate_stop_loss(kline_data, entry_price, "long")
        self.assertLess(stop_loss, entry_price)

        # Position sizing for a fixed account balance must suggest > 0 units.
        account_balance = 1000000
        position_info = risk_manager.calculate_position_size(
            account_balance, kline_data, "long", entry_price
        )
        self.assertIn('suggested_units', position_info)
        self.assertGreater(position_info['suggested_units'], 0)
        print("✅ 风险管理器测试通过")

    def test_fund_flow_monitor(self):
        """FundFlowMonitor: analysis exposes strength and signal keys."""
        print("测试资金流向监控器...")
        kline_data = self._fetch_kline(100)
        fund_flow_monitor = FundFlowMonitor()

        fund_flow_analysis = fund_flow_monitor.analyze_fund_flow(kline_data)
        self.assertIn('fund_flow_strength', fund_flow_analysis)
        self.assertIn('fund_signal', fund_flow_analysis)
        print("✅ 资金流向监控器测试通过")

    def test_support_resistance(self):
        """SupportResistance: support and resistance levels are returned as lists."""
        print("测试压力支撑分析器...")
        kline_data = self._fetch_kline(100)
        support_resistance = SupportResistance()

        sr_analysis = support_resistance.analyze_support_resistance(kline_data)
        self.assertIn('support_resistance_levels', sr_analysis)
        levels = sr_analysis['support_resistance_levels']
        # Assert the nested keys exist so a missing key is reported as a test
        # failure rather than a KeyError.
        self.assertIn('support_levels', levels)
        self.assertIn('resistance_levels', levels)
        self.assertIsInstance(levels['support_levels'], list)
        self.assertIsInstance(levels['resistance_levels'], list)
        print("✅ 压力支撑分析器测试通过")

    def test_rollover_detector(self):
        """RolloverDetector: analysis exposes expiry, days-to-delivery and warning."""
        print("测试换月检测器...")
        kline_data = self._fetch_kline(100)
        rollover_detector = RolloverDetector()

        rollover_analysis = rollover_detector.analyze_rollover(self.symbol, kline_data)
        self.assertIn('expire_date', rollover_analysis)
        self.assertIn('days_to_delivery', rollover_analysis)
        self.assertIn('warning_level', rollover_analysis)
        print("✅ 换月检测器测试通过")

    def test_deepseek_agent(self):
        """DeepseekAgent.analyze_market on static fixture inputs returns expected keys.

        The inputs are hand-written literals, not derived from fetched data, so
        this only checks the shape of the returned mapping.
        """
        print("测试DeepSeek代理...")
        # NOTE(review): whether analyze_market calls a remote model (and so
        # needs credentials/network) is not visible from this file — confirm.
        deepseek_agent = DeepseekAgent()

        market_data = {
            'symbol': self.symbol,
            'latest_price': 35000,
            'volume': 10000,
            'open_interest': 50000,
            'timeframe': '1d',
        }
        technical_indicators = {
            'macd': {'signal': '金叉'},
            'rsi': 55,
            'bollinger': {'position': '中轨附近'},
            'kdj': {'signal': '金叉'},
            'atr': 200,
        }
        trend_data = {
            'adx': 25,
            'trend_strength': 'medium',
            'trend_direction': 'up',
            'ma_relationship': 'bullish',
            'overall_trend': 'strong_bullish',
            'win_rate': 65,
        }
        risk_metrics = {
            'stop_loss': 34500,
            'target_price': 36000,
            'profit_loss_ratio': 1.8,
            'position_size': 2,
            'risk_ratio': 2.5,
        }

        analysis_result = deepseek_agent.analyze_market(
            market_data, technical_indicators, trend_data, risk_metrics
        )
        self.assertIn('trend_judgment', analysis_result)
        self.assertIn('win_rate_assessment', analysis_result)
        print("✅ DeepSeek代理测试通过")
|
|
|
|
|
|
def run_tests():
    """Run the ``TestSystemComponents`` suite and print a summary banner.

    Returns:
        unittest.TestResult: the outcome of the run. It is returned so callers
        (including the ``__main__`` entry point below) can act on
        ``wasSuccessful()``.
    """
    print("="*60)
    print("AI 期货分析系统 - 组件测试")
    print("="*60)

    # Collect every test_* method of TestSystemComponents into one suite.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestSystemComponents)

    # verbosity=2 prints each test's name and outcome as it runs.
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    print("\n" + "="*60)
    if result.wasSuccessful():
        print("✅ 所有测试通过!系统组件运行正常")
    else:
        print("❌ 测试失败,请检查系统组件")
    print("="*60)
    return result


if __name__ == "__main__":
    # Turn the outcome into the process exit status so shells / CI can detect
    # failures: 0 when every test passed, 1 otherwise.
    raise SystemExit(0 if run_tests().wasSuccessful() else 1)
|