You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

202 lines
6.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
from datetime import datetime
from typing import List, Optional

import pandas as pd
import tushare as ts

from app.config import settings
from app.services.datasource.base import DataSourceBase


class TushareSource(DataSourceBase):
    """Futures data-source adapter backed by the Tushare Pro API.

    Exchange codes are normalized to the project's convention (``CZCE`` ->
    ``ZCE``) and bar columns are renamed to ``volume`` / ``turnover`` /
    ``open_interest``.
    """

    # Tushare exchange code -> exchange code used throughout this project.
    _EXCHANGE_MAP = {
        "CFFEX": "CFFEX",
        "SHFE": "SHFE",
        "DCE": "DCE",
        "CZCE": "ZCE",  # Tushare uses CZCE; the project normalizes it to ZCE.
        "INE": "INE",
        "GFEX": "GFEX",
    }
    # Unified minute-period label -> Tushare ``ft_mins`` ``freq`` value.
    _INTRADAY_FREQ = {
        "1m": "1min",
        "5m": "5min",
        "15m": "15min",
        "30m": "30min",
        "60m": "60min",
    }
    # Tushare bar columns that get a unified name; all other columns keep theirs.
    _BAR_RENAMES = {"vol": "volume", "amount": "turnover", "oi": "open_interest"}
    # Must stay lists: DataFrame indexing with a tuple means a single key.
    _DAILY_COLUMNS = [
        "trade_date", "open", "high", "low", "close",
        "volume", "turnover", "open_interest", "settle", "pre_settle",
    ]
    _INTRADAY_COLUMNS = [
        "trade_time", "open", "high", "low", "close",
        "volume", "turnover", "open_interest",
    ]

    def __init__(self, config: dict):
        """Store connection settings; no network call is made here.

        Args:
            config: Source configuration. A non-empty ``config["token"]``
                overrides ``settings.TUSHARE_TOKEN``.
        """
        super().__init__(config)
        self.name = "tushare"
        self.pro = None
        # ``or`` rather than a ``.get`` default so that an explicit empty or
        # None token in the config still falls back to the global setting.
        self._token = config.get("token") or settings.TUSHARE_TOKEN

    def initialize(self) -> bool:
        """Create the Tushare Pro client and verify it with a cheap query.

        Returns:
            True once the verification query succeeds.

        Raises:
            Exception: whatever Tushare raises for a bad token or network
                failure; ``_initialized`` is reset to False before re-raising.
        """
        try:
            ts.set_token(self._token)
            self.pro = ts.pro_api()
            # Tiny trading-calendar request used purely as a connectivity check.
            self.pro.trade_cal(exchange="DCE", start_date="20240101", end_date="20240105")
        except Exception:
            self._initialized = False
            raise  # bare raise keeps the original traceback intact
        self._initialized = True
        return True

    def _ensure_initialized(self) -> None:
        """Connect lazily on first use so callers need not call initialize()."""
        if not self._initialized:
            self.initialize()

    def _format_date(self, date_str: str) -> str:
        """Convert ``YYYY-MM-DD`` to the ``YYYYMMDD`` form Tushare expects."""
        return date_str.replace("-", "")

    @staticmethod
    def _format_datetime(date_str: str, end_of_day: bool = False) -> str:
        """Render a date or datetime as ``YYYY-MM-DD HH:MM:SS``.

        A bare date (no ``:`` in the text) maps to 00:00:00, or to 23:59:59
        when ``end_of_day`` is True, so that an end date covers the whole day.

        Raises:
            ValueError: if ``date_str`` cannot be parsed as a date.
        """
        text = str(date_str).strip()
        stamp = pd.Timestamp(text)
        if end_of_day and ":" not in text:
            stamp = stamp + pd.Timedelta(hours=23, minutes=59, seconds=59)
        return stamp.strftime("%Y-%m-%d %H:%M:%S")

    def get_contract_list(self, exchange: Optional[str] = None) -> List[dict]:
        """List standard futures contracts from Tushare ``fut_basic``.

        Args:
            exchange: Exchange to query, in either Tushare form ("CZCE") or
                the project's unified form ("ZCE"). When empty or None, every
                exchange in ``_EXCHANGE_MAP`` is queried in turn, because
                ``fut_basic`` takes one exchange per call.

        Returns:
            One dict per contract with keys ``symbol``, ``exchange`` (unified
            code), ``name``, ``product``, ``multiplier``, ``price_tick``,
            ``expire_date`` and ``is_active``.
        """
        self._ensure_initialized()
        if exchange:
            exchanges = [self._to_tushare_exchange(exchange)]
        else:
            exchanges = list(self._EXCHANGE_MAP)
        # NOTE(review): uses the local clock; exchange dates are China-local.
        today = datetime.now().date()
        results: List[dict] = []
        for ts_exchange in exchanges:
            df = self.pro.fut_basic(
                exchange=ts_exchange,
                fut_type="1",  # 1 = standard contracts
                fut_series="",
            )
            if df is None or df.empty:
                continue
            for _, row in df.iterrows():
                results.append(self._row_to_contract(row, ts_exchange, today))
        return results

    def _row_to_contract(self, row: pd.Series, ts_exchange: str, today) -> dict:
        """Convert one ``fut_basic`` row into the project's contract dict.

        Documented Tushare column names are tried first; the names used by
        earlier versions of this adapter are kept as fallbacks.
        """
        # ``multiplier`` may be empty for commodity futures; ``per_unit``
        # (units per lot) is the next candidate.
        multiplier = self._row_value(row, "multiplier", "per_unit", "contract_multiplier")
        price_tick = self._row_value(row, "price_tick")
        expire_date = self._parse_date(self._row_value(row, "delist_date") or "")
        list_status = self._row_value(row, "list_status")
        if list_status is not None:
            is_active = list_status == "L"
        else:
            # No listing status available: treat a contract as active until
            # its last trading day has passed.
            is_active = expire_date is not None and expire_date.date() >= today
        return {
            "symbol": self._row_value(row, "symbol") or "",
            "exchange": self._map_exchange(self._row_value(row, "exchange") or ts_exchange),
            "name": self._row_value(row, "name") or "",
            "product": self._row_value(row, "fut_code", "underlying_symbol") or "",
            "multiplier": int(float(multiplier)) if multiplier else 10,
            "price_tick": float(price_tick) if price_tick else None,
            "expire_date": expire_date,
            "is_active": is_active,
        }

    @staticmethod
    def _row_value(row: pd.Series, *keys: str):
        """Return the first non-missing value among ``keys``, else None.

        Missing covers absent columns as well as None/NaN cells. NaN matters
        because it is truthy, so a plain ``if value`` test would let it through.
        """
        for key in keys:
            value = row.get(key)
            if value is not None and not pd.isna(value):
                return value
        return None

    def _map_exchange(self, exchange: str) -> str:
        """Map a Tushare exchange code to the project's code (unknown: as-is)."""
        return self._EXCHANGE_MAP.get(exchange, exchange)

    def _to_tushare_exchange(self, exchange: str) -> str:
        """Map a project or Tushare exchange code to Tushare's code."""
        code = exchange.strip().upper()
        for ts_code, unified in self._EXCHANGE_MAP.items():
            if unified == code:
                return ts_code
        return code

    def _parse_date(self, date_str: str) -> Optional[datetime]:
        """Parse a ``YYYYMMDD`` value, returning None when empty or invalid."""
        if not date_str:
            return None
        try:
            return datetime.strptime(str(date_str), "%Y%m%d")
        except (TypeError, ValueError):
            return None

    def get_kline_daily(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Fetch daily bars via Tushare ``fut_daily``.

        Args:
            symbol: Tushare ``ts_code`` of the contract.
            start_date: First day, ``YYYY-MM-DD`` (or already ``YYYYMMDD``).
            end_date: Last day, same format.

        Returns:
            Bars in ascending ``trade_date`` order with columns
            ``_DAILY_COLUMNS`` and a fresh RangeIndex, or an empty DataFrame
            when Tushare returns nothing.
        """
        self._ensure_initialized()
        df = self.pro.fut_daily(
            ts_code=symbol,
            start_date=self._format_date(start_date),
            end_date=self._format_date(end_date),
        )
        if df is None or df.empty:
            return pd.DataFrame()
        df = df.rename(columns=self._BAR_RENAMES)
        df["trade_date"] = pd.to_datetime(df["trade_date"])
        # Tushare may return rows newest-first; callers get chronological bars.
        df = df.sort_values("trade_date")
        return df[self._DAILY_COLUMNS].reset_index(drop=True)

    def get_kline_weekly(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Build weekly bars by resampling ``get_kline_daily`` output.

        Weeks end on Friday (``W-FRI``) and are labeled with that Friday.
        ``settle`` / ``pre_settle`` are not carried over.

        Returns:
            Columns ``trade_date``, OHLC, ``volume``, ``turnover`` and
            ``open_interest``, or an empty DataFrame when there is no data.
        """
        daily_df = self.get_kline_daily(symbol, start_date, end_date)
        if daily_df.empty:
            return pd.DataFrame()
        weekly = (
            daily_df.set_index("trade_date")
            .resample("W-FRI")
            .agg({
                "open": "first",
                "high": "max",
                "low": "min",
                "close": "last",
                "volume": "sum",
                "turnover": "sum",
                "open_interest": "last",
            })
            # Only weeks with no trades lack prices; a NaN in another column
            # (e.g. open_interest) must not discard a real trading week.
            .dropna(subset=["open", "high", "low", "close"])
        )
        return weekly.reset_index()

    def get_kline_intraday(
        self,
        symbol: str,
        period: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Fetch minute bars via Tushare's ``ft_mins`` endpoint.

        Args:
            symbol: Tushare ``ts_code`` of the contract.
            period: A key of ``_INTRADAY_FREQ`` ("1m", "5m", "15m", "30m",
                "60m"); unknown values fall back to 5-minute bars.
            start_date: Range start, a date (``YYYY-MM-DD`` / ``YYYYMMDD``)
                or a full datetime string.
            end_date: Range end; a bare date covers that entire day.

        Returns:
            Bars in ascending ``trade_time`` order with columns
            ``_INTRADAY_COLUMNS``, or an empty DataFrame when Tushare returns
            nothing or the request fails. Failures are logged rather than
            raised because some exchanges may not provide minute data.

        Raises:
            ValueError: if a date argument cannot be parsed.
        """
        self._ensure_initialized()
        freq = self._INTRADAY_FREQ.get(period, "5min")
        start = self._format_datetime(start_date)
        end = self._format_datetime(end_date, end_of_day=True)
        try:
            df = self.pro.ft_mins(
                ts_code=symbol,
                freq=freq,
                start_date=start,
                end_date=end,
            )
        except Exception as exc:  # Tushare signals API errors with plain Exception
            logging.getLogger(__name__).warning(
                "Tushare ft_mins failed for %s (freq=%s): %s", symbol, freq, exc
            )
            return pd.DataFrame()
        if df is None or df.empty:
            return pd.DataFrame()
        df = df.rename(columns=self._BAR_RENAMES)
        df["trade_time"] = pd.to_datetime(df["trade_time"])
        df = df.sort_values("trade_time")
        return df[self._INTRADAY_COLUMNS].reset_index(drop=True)