You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

462 lines
17 KiB

# 数据获取模块
import os
import time
import pandas as pd
from typing import Dict, Optional, List
from qihuo_analyzer.utils.config_manager import config_manager
# 尝试导入tqsdk如果失败则使用模拟数据
try:
from tqsdk import TqApi, TqAuth
TQSDK_AVAILABLE = True
except Exception as e:
print(f"tqsdk导入失败{e},将使用模拟数据")
TQSDK_AVAILABLE = False
class DataFetcher:
"""数据获取器"""
def __init__(self):
self.api = None
def connect(self) -> bool:
"""连接API"""
try:
if TQSDK_AVAILABLE:
# 使用天勤TQSDK连接
from qihuo_analyzer.utils.config_manager import config_manager
username = os.getenv('TQSDK_USERNAME', '')
password = os.getenv('TQSDK_PASSWORD', '')
if username and password:
self.api = TqApi(auth=TqAuth(username, password))
print("API连接成功")
return True
else:
print("TQSDK账号密码未配置将使用模拟数据")
self.api = None
return False
else:
# 模拟API用于测试
print("使用模拟API")
self.api = None
return False
except Exception as e:
print(f"API连接失败{e}")
# 模拟API用于测试
self.api = None
return False
def disconnect(self):
"""断开连接"""
if self.api:
try:
self.api.close()
print("API连接已断开")
except:
pass
def _convert_duration(self, duration: str) -> int:
"""将时间周期字符串转换为分钟数
Args:
duration: 时间周期 '1m', '5m', '15m', '1h', '1d'
Returns:
分钟数
"""
duration_map = {
'1m': 1,
'5m': 5,
'15m': 15,
'30m': 30,
'1h': 60,
'2h': 120,
'4h': 240,
'6h': 360,
'12h': 720,
'1d': 1440,
'1w': 10080
}
return duration_map.get(duration, 60) # 默认60分钟
def _convert_symbol(self, symbol: str) -> str:
"""将合约代码转换为TQSDK格式
Args:
symbol: 合约代码 'CU2603'
Returns:
TQSDK格式的合约代码 'SHFE.cu2603'
"""
# 交易所映射
exchange_map = {
'CU': 'SHFE', # 铜 - 上海期货交易所
'AL': 'SHFE', # 铝 - 上海期货交易所
'ZN': 'SHFE', # 锌 - 上海期货交易所
'PB': 'SHFE', # 铅 - 上海期货交易所
'NI': 'SHFE', # 镍 - 上海期货交易所
'SN': 'SHFE', # 锡 - 上海期货交易所
'AU': 'SHFE', # 黄金 - 上海期货交易所
'AG': 'SHFE', # 白银 - 上海期货交易所
'RB': 'SHFE', # 螺纹钢 - 上海期货交易所
'HC': 'SHFE', # 热轧卷板 - 上海期货交易所
'BU': 'SHFE', # 沥青 - 上海期货交易所
'RU': 'SHFE', # 橡胶 - 上海期货交易所
'FU': 'SHFE', # 燃油 - 上海期货交易所
'SC': 'INE', # 原油 - 上海国际能源交易中心
'I': 'DCE', # 铁矿石 - 大连商品交易所
'J': 'DCE', # 焦炭 - 大连商品交易所
'JM': 'DCE', # 焦煤 - 大连商品交易所
'A': 'DCE', # 大豆 - 大连商品交易所
'B': 'DCE', # 豆粕 - 大连商品交易所
'M': 'DCE', # 豆粕 - 大连商品交易所
'Y': 'DCE', # 豆油 - 大连商品交易所
'P': 'DCE', # 棕榈油 - 大连商品交易所
'C': 'DCE', # 玉米 - 大连商品交易所
'CS': 'DCE', # 玉米淀粉 - 大连商品交易所
'L': 'DCE', # 聚乙烯 - 大连商品交易所
'V': 'DCE', # 聚氯乙烯 - 大连商品交易所
'PP': 'DCE', # 聚丙烯 - 大连商品交易所
'TA': 'CZCE', # PTA - 郑州商品交易所
'CF': 'CZCE', # 棉花 - 郑州商品交易所
'SR': 'CZCE', # 白糖 - 郑州商品交易所
'MA': 'CZCE', # 甲醇 - 郑州商品交易所
'ZC': 'CZCE', # 动力煤 - 郑州商品交易所
'FG': 'CZCE', # 玻璃 - 郑州商品交易所
'RM': 'CZCE', # 菜籽粕 - 郑州商品交易所
'OI': 'CZCE', # 菜籽油 - 郑州商品交易所
'RS': 'CZCE', # 菜籽 - 郑州商品交易所
'WH': 'CZCE', # 强麦 - 郑州商品交易所
'JR': 'CZCE', # 粳稻 - 郑州商品交易所
'LR': 'CZCE', # 晚籼稻 - 郑州商品交易所
}
# 提取品种代码和合约月份
if len(symbol) >= 4:
product_code = symbol[:2].upper()
contract_month = symbol[2:].lower()
# 获取交易所代码
exchange = exchange_map.get(product_code, 'SHFE')
# 构建TQSDK格式的合约代码
tq_symbol = f"{exchange}.{product_code.lower()}{contract_month}"
return tq_symbol
else:
return symbol
def get_product_name_cn(self, symbol: str) -> str:
"""获取合约的中文名称
Args:
symbol: 合约代码 'CU2603'
Returns:
合约的中文名称 ''
"""
# 品种中文名称映射
product_name_map = {
'CU': '',
'AL': '',
'ZN': '',
'PB': '',
'NI': '',
'SN': '',
'AU': '黄金',
'AG': '白银',
'RB': '螺纹钢',
'HC': '热轧卷板',
'BU': '沥青',
'RU': '橡胶',
'FU': '燃油',
'SC': '原油',
'I': '铁矿石',
'J': '焦炭',
'JM': '焦煤',
'A': '大豆',
'B': '豆粕',
'M': '豆粕',
'Y': '豆油',
'P': '棕榈油',
'C': '玉米',
'CS': '玉米淀粉',
'L': '聚乙烯',
'V': '聚氯乙烯',
'PP': '聚丙烯',
'TA': 'PTA',
'CF': '棉花',
'SR': '白糖',
'MA': '甲醇',
'ZC': '动力煤',
'FG': '玻璃',
'RM': '菜籽粕',
'OI': '菜籽油',
'RS': '菜籽',
'WH': '强麦',
'JR': '粳稻',
'LR': '晚籼稻',
}
if len(symbol) >= 2:
product_code = symbol[:2].upper()
return product_name_map.get(product_code, product_code)
else:
return symbol
def get_kline_data(self, symbol: str, duration: str, count: int = 200) -> Optional[pd.DataFrame]:
"""获取K线数据
Args:
symbol: 合约代码
duration: 时间周期 '1m', '5m', '15m', '1h', '1d'
count: 数据数量
Returns:
K线数据DataFrame如果无法获取真实数据则返回None
"""
try:
if TQSDK_AVAILABLE and self.api:
# 转换合约代码为TQSDK格式
tq_symbol = self._convert_symbol(symbol)
print(f"使用TQSDK格式合约代码: {tq_symbol}")
# 转换时间周期为分钟数
duration_minutes = self._convert_duration(duration)
# 使用真实API获取数据
klines = self.api.get_kline_serial(tq_symbol, duration_minutes, data_length=count)
# 等待数据准备就绪
import time
start_time = time.time()
timeout = 5 # 5秒超时
while True:
if hasattr(klines, 'datetime') and len(klines.datetime) > 0:
break
if time.time() - start_time > timeout:
print("获取K线数据超时")
return None
time.sleep(0.1)
# 转换为DataFrame
data = {
'datetime': klines.datetime,
'open': klines.open,
'high': klines.high,
'low': klines.low,
'close': klines.close,
'volume': klines.volume,
'open_interest': klines.open_oi
}
df = pd.DataFrame(data)
df['datetime'] = pd.to_datetime(df['datetime'], unit='ns')
df.set_index('datetime', inplace=True)
print(f"成功获取K线数据数据长度: {len(df)}")
return df
else:
# 不再自动返回模拟数据返回None
print(f"无法获取真实数据:{'API未连接' if not self.api else 'TQSDK不可用'}")
return None
except Exception as e:
print(f"获取K线数据失败{e}")
# 不再自动返回模拟数据返回None
return None
def get_tick_data(self, symbol: str, count: int = 1000) -> Optional[pd.DataFrame]:
"""获取Tick数据"""
try:
if TQSDK_AVAILABLE and self.api:
# 使用真实API获取数据
ticks = self.api.get_tick_serial(symbol, data_length=count)
self.api.wait_update()
# 转换为DataFrame
data = {
'datetime': ticks.datetime,
'last_price': ticks.last_price,
'volume': ticks.volume,
'open_interest': ticks.open_interest,
'bid_price1': ticks.bid_price1,
'bid_volume1': ticks.bid_volume1,
'ask_price1': ticks.ask_price1,
'ask_volume1': ticks.ask_volume1
}
df = pd.DataFrame(data)
df['datetime'] = pd.to_datetime(df['datetime'], unit='ns')
df.set_index('datetime', inplace=True)
return df
else:
# 返回模拟数据
return self._get_mock_tick_data(symbol, count)
except Exception as e:
print(f"获取Tick数据失败{e}")
return self._get_mock_tick_data(symbol, count)
def get_contract_info(self, symbol: str) -> Optional[Dict]:
"""获取合约信息"""
try:
if TQSDK_AVAILABLE and self.api:
# 使用真实API获取数据
quote = self.api.get_quote(symbol)
self.api.wait_update()
return {
'symbol': symbol,
'name': quote.instrument_name,
'exchange': quote.exchange_id,
'product': quote.product_id,
'price_tick': quote.price_tick,
'volume_multiple': quote.volume_multiple,
'margin_rate': quote.margin_rate,
'expire_datetime': quote.expire_datetime,
'create_datetime': quote.create_datetime
}
else:
# 返回模拟数据
return self._get_mock_contract_info(symbol)
except Exception as e:
print(f"获取合约信息失败:{e}")
return self._get_mock_contract_info(symbol)
def get_market_data(self, symbols: List[str]) -> Dict[str, Dict]:
"""批量获取市场数据"""
market_data = {}
for symbol in symbols:
try:
if TQSDK_AVAILABLE and self.api:
quote = self.api.get_quote(symbol)
self.api.wait_update()
market_data[symbol] = {
'latest_price': quote.last_price,
'open': quote.open,
'high': quote.high,
'low': quote.low,
'pre_close': quote.pre_close,
'volume': quote.volume,
'open_interest': quote.open_interest,
'bid_price1': quote.bid_price1,
'ask_price1': quote.ask_price1
}
else:
# 模拟数据
market_data[symbol] = self._get_mock_market_data(symbol)
except Exception as e:
print(f"获取{symbol}市场数据失败:{e}")
market_data[symbol] = self._get_mock_market_data(symbol)
return market_data
def _get_mock_kline_data(self, symbol: str, duration: str, count: int) -> pd.DataFrame:
"""获取模拟K线数据"""
# 生成时间序列
end_time = pd.Timestamp.now()
if duration == '1m':
freq = '1T'
elif duration == '5m':
freq = '5T'
elif duration == '15m':
freq = '15T'
elif duration == '1h':
freq = '1H'
elif duration == '1d':
freq = '1D'
else:
freq = '1H'
datetime_index = pd.date_range(end=end_time, periods=count, freq=freq)
# 生成随机价格数据
base_price = 3500
price_changes = np.random.normal(0, 5, count)
prices = base_price + np.cumsum(price_changes)
# 生成其他数据
opens = prices * (1 + np.random.normal(0, 0.001, count))
highs = np.maximum(prices, opens) * (1 + np.random.normal(0, 0.002, count))
lows = np.minimum(prices, opens) * (1 - np.random.normal(0, 0.002, count))
volumes = np.random.randint(1000, 10000, count)
open_interests = np.random.randint(10000, 100000, count)
# 创建DataFrame
df = pd.DataFrame({
'open': opens,
'high': highs,
'low': lows,
'close': prices,
'volume': volumes,
'open_interest': open_interests
}, index=datetime_index)
return df
def _get_mock_tick_data(self, symbol: str, count: int) -> pd.DataFrame:
"""获取模拟Tick数据"""
# 生成时间序列
end_time = pd.Timestamp.now()
datetime_index = pd.date_range(end=end_time, periods=count, freq='1S')
# 生成随机价格数据
base_price = 3500
price_changes = np.random.normal(0, 0.5, count)
last_prices = base_price + np.cumsum(price_changes)
# 生成其他数据
volumes = np.random.randint(10, 100, count)
open_interests = np.random.randint(10000, 100000, count)
bid_prices = last_prices * (1 - np.random.normal(0, 0.0005, count))
ask_prices = last_prices * (1 + np.random.normal(0, 0.0005, count))
bid_volumes = np.random.randint(10, 50, count)
ask_volumes = np.random.randint(10, 50, count)
# 创建DataFrame
df = pd.DataFrame({
'last_price': last_prices,
'volume': volumes,
'open_interest': open_interests,
'bid_price1': bid_prices,
'bid_volume1': bid_volumes,
'ask_price1': ask_prices,
'ask_volume1': ask_volumes
}, index=datetime_index)
return df
def _get_mock_contract_info(self, symbol: str) -> Dict:
"""获取模拟合约信息"""
return {
'symbol': symbol,
'name': symbol,
'exchange': 'SHFE',
'product': symbol[:2],
'price_tick': 1,
'volume_multiple': 10,
'margin_rate': 0.1,
'expire_datetime': int(time.time() * 1e9) + 90 * 24 * 3600 * 1e9,
'create_datetime': int(time.time() * 1e9) - 180 * 24 * 3600 * 1e9
}
def _get_mock_market_data(self, symbol: str) -> Dict:
"""获取模拟市场数据"""
base_price = 3500
return {
'latest_price': base_price + np.random.normal(0, 10),
'open': base_price,
'high': base_price + 20,
'low': base_price - 20,
'pre_close': base_price,
'volume': np.random.randint(10000, 100000),
'open_interest': np.random.randint(100000, 1000000),
'bid_price1': base_price - 1,
'ask_price1': base_price + 1
}
# 导入numpy
import numpy as np