You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

638 lines
25 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""AKShare数据源适配器 - 使用新浪接口"""
import asyncio
import time
from datetime import datetime
from typing import List, Optional, Dict, Any
import akshare as ak
import pandas as pd
from app.adapters.base import (
DataSourceAdapter, TickData, KLineData, SymbolInfo,
TradeCalData, TickCallback
)
from app.core.logger import info, error, warning
# Module-level cache containers for stock basic info, the trade calendar and
# institutional-holding data.
# NOTE(review): nothing in the visible code reads or writes these three names;
# AKShareAdapter keeps per-instance caches with the same names instead.
# Confirm whether any other module depends on them before removing.
_stock_info_cache: Dict[str, Dict[str, Any]] = {}
_trade_calendar_cache: Optional[pd.DataFrame] = None
_inst_holding_cache: Dict[str, Dict[str, Any]] = {}
class AKShareAdapter(DataSourceAdapter):
    """AKShare data-source adapter that mainly uses the Sina endpoints.

    Every AKShare call is blocking, so it is dispatched to the default
    executor through ``_fetch_with_retry``. Stock basic info, the trade
    calendar and institutional-holding ratios are cached per instance.

    Symbol conventions used by this class:
        * stocks:  ``000001.SZ`` / ``600000.SH`` / ``430047.BJ``
        * futures: ``CU2504.SHFE`` (contract code, dot, exchange)
    """

    # Suffixes that identify a stock symbol. Matching is done on the suffix
    # (not with ``in``) because ".SH" is also a substring of ".SHFE".
    _STOCK_SUFFIXES = (".SH", ".SZ", ".BJ")
    _DAILY_FREQS = ("1d", "day", "D", "")
    _MINUTE_FREQS = ("1m", "5m", "15m", "30m", "60m")
    # Lower-cased fragments of an error message that mark it as transient.
    _RETRYABLE_MARKERS = ('connection', 'timeout', 'remote', 'reset', 'closed')

    # Product code -> exchange, used by ``_get_futures_exchange``.
    # NOTE(review): LU and NR are, to my knowledge, listed on INE rather than
    # SHFE; the original mapping is kept so already-issued symbol ids do not
    # change -- confirm before correcting.
    _FUTURES_EXCHANGES = {
        **dict.fromkeys(
            ['CU', 'AL', 'ZN', 'PB', 'NI', 'SN', 'AU', 'AG', 'RB', 'HC',
             'BU', 'RU', 'FU', 'SP', 'WR', 'SS', 'LU', 'NR'],
            'SHFE',  # Shanghai Futures Exchange
        ),
        **dict.fromkeys(
            ['A', 'B', 'M', 'Y', 'P', 'C', 'CS', 'JD', 'LH', 'JM',
             'J', 'I', 'FB', 'BB', 'RR', 'PG', 'EB', 'EG', 'V', 'PP', 'L'],
            'DCE',  # Dalian Commodity Exchange
        ),
        **dict.fromkeys(
            ['WH', 'PM', 'CF', 'SR', 'TA', 'OI', 'RI', 'MA', 'FG', 'RS',
             'RM', 'JR', 'LR', 'SM', 'SF', 'CY', 'AP', 'CJ', 'UR', 'SA', 'PF', 'PK'],
            'CZCE',  # Zhengzhou Commodity Exchange
        ),
        **dict.fromkeys(
            ['IF', 'IC', 'IH', 'T', 'TF', 'TS', 'IM'],
            'CFFEX',  # China Financial Futures Exchange
        ),
        **dict.fromkeys(
            ['SC', 'BC', 'EC'],
            'INE',  # Shanghai International Energy Exchange
        ),
    }

    def __init__(self):
        """Create a disconnected adapter with empty per-instance caches."""
        self.config = {}
        self._connected = False
        self._max_retries = 3
        self._retry_delay = 2  # seconds; multiplied by the attempt number
        # Per-instance caches
        self._stock_info_cache: Dict[str, Dict[str, Any]] = {}
        self._trade_calendar_cache: Optional[pd.DataFrame] = None
        self._inst_holding_cache: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    def _get_stock_code_from_symbol(self, symbol: str) -> str:
        """Return the bare stock code of *symbol*: ``000001.SZ`` -> ``000001``."""
        if "." in symbol:
            return symbol.split(".")[0]
        return symbol

    @staticmethod
    def _to_naive_datetime(value) -> datetime:
        """Convert a date-like value (str, date, datetime, Timestamp) to ``datetime``.

        All fetchers derive the epoch via ``datetime.timestamp()`` on the
        result, so a naive value is interpreted in the host's local timezone
        everywhere. ``pandas.Timestamp.timestamp()`` must not be used directly
        on naive values because it interprets them as UTC.
        """
        return pd.Timestamp(value).to_pydatetime()

    @staticmethod
    def _day_range(start_date: str, end_date: str) -> tuple:
        """Parse two ``YYYYMMDD`` strings into a half-open ``[start, end)`` range.

        The upper bound is midnight *after* ``end_date`` so that intraday bars
        on the end day itself are included.

        Raises:
            ValueError: if either string is not in ``YYYYMMDD`` format.
        """
        lower = pd.Timestamp(datetime.strptime(start_date, "%Y%m%d"))
        upper = pd.Timestamp(datetime.strptime(end_date, "%Y%m%d")) + pd.Timedelta(days=1)
        return lower, upper

    # ------------------------------------------------------------------
    # Cached lookups
    # ------------------------------------------------------------------
    async def _get_trade_calendar(self) -> pd.DataFrame:
        """Return the trade calendar, loading and caching it on first use.

        Returns an empty DataFrame when the calendar cannot be loaded; a
        failed or empty load is not cached, so the next call tries again.
        """
        if self._trade_calendar_cache is None:
            try:
                df = await self._fetch_with_retry(ak.tool_trade_date_hist_sina)
                if df is not None and not df.empty:
                    df['trade_date'] = pd.to_datetime(df['trade_date'])
                    self._trade_calendar_cache = df
                    info(f"Loaded trade calendar with {len(df)} trading days")
            except Exception as e:
                error(f"Failed to load trade calendar: {e}")
                return pd.DataFrame()
        if self._trade_calendar_cache is None:
            return pd.DataFrame()
        return self._trade_calendar_cache

    async def _get_stock_info(self, stock_code: str) -> Dict[str, Any]:
        """Return the item -> value info mapping for *stock_code* (cached).

        Returns an empty dict when the lookup fails or yields no rows; such
        results are not cached.
        """
        if stock_code not in self._stock_info_cache:
            try:
                # Individual stock info comes from the Eastmoney endpoint.
                info_df = await self._fetch_with_retry(ak.stock_individual_info_em, symbol=stock_code)
                if info_df is not None and not info_df.empty:
                    info_dict = dict(zip(info_df['item'], info_df['value']))
                    self._stock_info_cache[stock_code] = info_dict
                    info(f"Loaded stock info for {stock_code}")
            except Exception as e:
                error(f"Failed to get stock info for {stock_code}: {e}")
                return {}
        return self._stock_info_cache.get(stock_code, {})

    async def _get_trading_days_count(self, stock_code: str, trade_date: datetime) -> int:
        """Count trading days from the listing date up to *trade_date*, inclusive.

        Returns 0 when the listing date or the calendar is unavailable, or
        when anything goes wrong (the error is logged).
        """
        try:
            stock_info = await self._get_stock_info(stock_code)
            listing_date_str = str(stock_info.get('上市时间', ''))
            if not listing_date_str or listing_date_str == 'nan':
                return 0
            listing_day = pd.Timestamp(datetime.strptime(listing_date_str, '%Y%m%d'))
            trade_calendar = await self._get_trade_calendar()
            if trade_calendar.empty:
                return 0
            # Work on a derived Series: the calendar frame is the shared cache
            # and must not be modified here.
            days = pd.to_datetime(trade_calendar['trade_date']).dt.normalize()
            last_day = pd.Timestamp(trade_date.date())
            in_range = (days >= listing_day) & (days <= last_day)
            return int(in_range.sum())
        except Exception as e:
            error(f"Failed to calculate trading days for {stock_code}: {e}")
            return 0

    async def _check_limit_up_down(self, stock_code: str, trade_date: str, close_price: float) -> tuple:
        """Return ``(is_limit_up, is_limit_down)`` for *stock_code* on *trade_date*.

        Membership is looked up in the limit-up and limit-down pools for that
        date. ``close_price`` is accepted for interface compatibility but is
        not used. Any failure yields ``(False, False)``.
        """
        try:
            zt_df = await self._fetch_with_retry(ak.stock_zt_pool_em, date=trade_date)
            dt_df = await self._fetch_with_retry(ak.stock_zt_pool_dtgc_em, date=trade_date)

            def _contains(pool_df) -> bool:
                if pool_df is None or pool_df.empty:
                    return False
                return stock_code in set(pool_df['代码'].astype(str))

            return _contains(zt_df), _contains(dt_df)
        except Exception as e:
            # The pool endpoints may not support historical dates; treat a
            # failure as "not at the limit".
            info(f"Could not get limit up/down info for {stock_code} on {trade_date}: {e}")
            return False, False

    async def _get_market_cap(self, stock_code: str) -> tuple:
        """Return ``(total_market_cap, float_market_cap)`` as floats.

        Missing, falsy or ``nan`` values become 0.0; on error both are 0.0.
        """
        try:
            stock_info = await self._get_stock_info(stock_code)

            def _as_float(raw) -> float:
                return float(raw) if raw and str(raw) != 'nan' else 0.0

            total_cap = _as_float(stock_info.get('总市值', 0))
            float_cap = _as_float(stock_info.get('流通市值', 0))
            return total_cap, float_cap
        except Exception as e:
            error(f"Failed to get market cap for {stock_code}: {e}")
            return 0.0, 0.0

    async def _get_inst_holding_ratio(self, stock_code: str) -> float:
        """Return the summed fund-holding ratio for the latest quarter (cached).

        An empty response is cached as 0.0; an exception returns 0.0 without
        caching, so the next call retries.
        """
        # The cache key is simplified to "latest"; strictly it should be keyed
        # by the reporting quarter.
        cache_key = f"{stock_code}_latest"
        if cache_key not in self._inst_holding_cache:
            try:
                fund_holder_df = await self._fetch_with_retry(ak.stock_fund_stock_holder, symbol=stock_code)
                if fund_holder_df is not None and not fund_holder_df.empty:
                    # NOTE(review): assumes a '季度' column whose first row is
                    # the most recent quarter -- confirm against AKShare.
                    latest_quarter = fund_holder_df['季度'].iloc[0]
                    latest_df = fund_holder_df[fund_holder_df['季度'] == latest_quarter]
                    total_ratio = 0.0
                    if '占总股本比例' in latest_df.columns:
                        ratios = pd.to_numeric(latest_df['占总股本比例'], errors='coerce').fillna(0)
                        total_ratio = float(ratios.sum())
                    self._inst_holding_cache[cache_key] = {
                        'quarter': latest_quarter,
                        'ratio': total_ratio
                    }
                    info(f"Loaded inst holding for {stock_code}: {total_ratio:.4f}% in {latest_quarter}")
                else:
                    self._inst_holding_cache[cache_key] = {'quarter': '', 'ratio': 0.0}
            except Exception as e:
                error(f"Failed to get inst holding for {stock_code}: {e}")
                return 0.0
        return self._inst_holding_cache.get(cache_key, {}).get('ratio', 0.0)

    # ------------------------------------------------------------------
    # Adapter interface
    # ------------------------------------------------------------------
    async def connect(self, config: dict) -> None:
        """Store *config* and mark the adapter connected (no authentication is performed)."""
        self.config = config
        self._connected = True
        info("AKShare adapter connected (Sina API)")

    async def subscribe_ticks(self, symbols: List[str], callback: TickCallback) -> None:
        """Not supported by this adapter.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("AKShare does not support real-time tick subscription")

    async def fetch_klines(
        self,
        symbol: str,
        start: str,
        end: str,
        freq: str
    ) -> List[KLineData]:
        """Fetch historical K-lines for a stock or futures symbol.

        Args:
            symbol: ``000001.SZ``-style stock symbol or ``CU2504.SHFE``-style
                futures symbol.
            start: first day, ``YYYYMMDD``.
            end: last day, ``YYYYMMDD``.
            freq: one of the daily or minute frequency labels.

        Raises:
            ValueError: if the symbol has no exchange suffix or *freq* is
                not supported.
        """
        info(f"Fetching KLines from Sina for {symbol} [{freq}] from {start} to {end}")
        # Match on the suffix: a substring test would send "CU2504.SHFE" down
        # the stock path because it contains ".SH".
        if symbol.endswith(self._STOCK_SUFFIXES):
            return await self._fetch_stock_klines(symbol, start, end, freq)
        if "." in symbol:  # futures format: CU2504.SHFE
            return await self._fetch_futures_klines(symbol, start, end, freq)
        raise ValueError(f"Unknown symbol format: {symbol}")

    async def _fetch_stock_klines(
        self,
        symbol: str,
        start: str,
        end: str,
        freq: str
    ) -> List[KLineData]:
        """Dispatch a stock K-line request to the daily or minute fetcher.

        Raises:
            ValueError: if *freq* is not a supported label.
        """
        sina_code = self._normalize_stock_symbol(symbol)
        if freq in self._DAILY_FREQS:
            return await self._fetch_stock_daily_sina(sina_code, symbol, start, end)
        if freq in self._MINUTE_FREQS:
            return await self._fetch_stock_minute_sina(sina_code, symbol, start, end, freq)
        raise ValueError(f"Unsupported frequency: {freq}")

    def _normalize_stock_symbol(self, symbol: str) -> str:
        """Convert ``000001.SZ`` to the Sina form ``sz000001``.

        A symbol without a dot is returned unchanged. An exchange suffix other
        than SH/SZ/BJ falls back to the ``sz`` prefix.
        """
        if "." not in symbol:
            return symbol
        code, exchange = symbol.rsplit(".", 1)
        prefix = {"SH": "sh", "SZ": "sz", "BJ": "bj"}.get(exchange, "sz")
        return prefix + code

    def _denormalize_stock_symbol(self, symbol: str) -> str:
        """Convert the Sina form ``sz000001`` back to ``000001.SZ``.

        Anything without an sh/sz/bj prefix is returned unchanged.
        """
        for prefix in ("sh", "sz", "bj"):
            if symbol.startswith(prefix):
                return symbol[2:] + "." + prefix.upper()
        return symbol

    async def _fetch_with_retry(self, func, *args, **kwargs):
        """Run blocking ``func(*args, **kwargs)`` in the default executor, with retries.

        Only errors whose message contains one of ``_RETRYABLE_MARKERS`` are
        retried. The wait grows linearly: ``_retry_delay * attempt_number``.
        At least one attempt is always made, even if ``_max_retries`` is
        zero or negative.

        Raises:
            Exception: the error from the last attempt, or the first
                non-retryable error.
        """
        attempts = max(1, self._max_retries)
        loop = asyncio.get_running_loop()
        for attempt in range(attempts):
            try:
                return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
            except Exception as e:
                message = str(e).lower()
                retryable = any(marker in message for marker in self._RETRYABLE_MARKERS)
                if not retryable or attempt >= attempts - 1:
                    raise
                warning(f"Sina API request failed (attempt {attempt + 1}/{attempts}): {e}")
                await asyncio.sleep(self._retry_delay * (attempt + 1))  # linear backoff

    # ------------------------------------------------------------------
    # Stock K-lines
    # ------------------------------------------------------------------
    async def _fetch_stock_daily_sina(
        self,
        ts_code: str,
        original_symbol: str,
        start_date: str,
        end_date: str
    ) -> List[KLineData]:
        """Fetch forward-adjusted daily bars for a stock, with extended fields.

        Each bar also carries limit-up/down flags, market caps, the
        institutional holding ratio and the trading-day count since listing.
        Returns an empty list when no data comes back or on any error (logged).
        """
        try:
            df = await self._fetch_with_retry(
                ak.stock_zh_a_daily,
                symbol=ts_code,
                start_date=start_date,
                end_date=end_date,
                adjust="qfq"  # forward-adjusted prices
            )
            if df is None or df.empty:
                warning(f"No data returned from Sina for {original_symbol}")
                return []
            stock_code = self._get_stock_code_from_symbol(original_symbol)
            # Market cap and holding ratio are fetched once: they are current
            # values and are applied to every bar.
            total_cap, float_cap = await self._get_market_cap(stock_code)
            inst_ratio = await self._get_inst_holding_ratio(stock_code)
            results = []
            for _, row in df.iterrows():
                # Accept date, datetime, Timestamp or string in the 'date' column.
                trade_date = self._to_naive_datetime(row['date'])
                close_price = float(row['close'])
                # Computed per bar, counting up to that bar's own date.
                trading_days = await self._get_trading_days_count(stock_code, trade_date)
                # NOTE(review): this issues two pool requests per bar, and the
                # pools may be unavailable for historical dates.
                is_limit_up, is_limit_down = await self._check_limit_up_down(
                    stock_code, trade_date.strftime("%Y%m%d"), close_price
                )
                results.append(KLineData(
                    symbol=original_symbol,
                    time=int(trade_date.timestamp()),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=close_price,
                    volume=int(row['volume']),
                    amount=float(row.get('amount', 0)),
                    trade_date=trade_date.strftime('%Y-%m-%d'),
                    is_limit_up=is_limit_up,
                    is_limit_down=is_limit_down,
                    total_market_cap=total_cap,
                    float_market_cap=float_cap,
                    inst_holding_ratio=inst_ratio,
                    trading_days=trading_days
                ))
            info(f"Fetched {len(results)} daily klines with extended fields from Sina for {original_symbol}")
            return results
        except Exception as e:
            error(f"Failed to fetch stock daily from Sina for {original_symbol}: {e}")
            # Best effort: a failed fetch yields an empty list.
            return []

    async def _fetch_stock_minute_sina(
        self,
        ts_code: str,
        original_symbol: str,
        start_date: str,
        end_date: str,
        freq: str
    ) -> List[KLineData]:
        """Fetch forward-adjusted minute bars for a stock.

        Bars are kept when they fall within ``start_date``..``end_date``, both
        days inclusive. Returns an empty list on no data or on error (logged).
        """
        try:
            df = await self._fetch_with_retry(
                ak.stock_zh_a_minute,
                symbol=ts_code,
                period=freq.replace("m", ""),  # "5m" -> "5"
                adjust="qfq"
            )
            if df is None or df.empty:
                return []
            df['date'] = pd.to_datetime(df['date'])
            # Half-open range so bars during the end day are not dropped.
            lower, upper = self._day_range(start_date, end_date)
            df = df[(df['date'] >= lower) & (df['date'] < upper)]
            results = []
            for _, row in df.iterrows():
                bar_time = self._to_naive_datetime(row['date'])
                results.append(KLineData(
                    symbol=original_symbol,
                    time=int(bar_time.timestamp()),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    volume=int(row['volume']),
                    amount=float(row.get('amount', 0))
                ))
            return results
        except Exception as e:
            error(f"Failed to fetch stock minute from Sina for {original_symbol}: {e}")
            return []

    # ------------------------------------------------------------------
    # Futures K-lines
    # ------------------------------------------------------------------
    async def _fetch_futures_klines(
        self,
        symbol: str,
        start: str,
        end: str,
        freq: str
    ) -> List[KLineData]:
        """Dispatch a futures K-line request to the daily or minute fetcher.

        Raises:
            ValueError: if *freq* is not a supported label.
        """
        if freq in self._DAILY_FREQS:
            return await self._fetch_futures_daily_sina(symbol, start, end)
        if freq in self._MINUTE_FREQS:
            return await self._fetch_futures_minute_sina(symbol, start, end, freq)
        raise ValueError(f"Unsupported frequency: {freq}")

    async def _fetch_futures_daily_sina(
        self,
        symbol: str,
        start_date: str,
        end_date: str
    ) -> List[KLineData]:
        """Fetch daily bars for a futures contract.

        Returns an empty list on no data or on any error (logged).
        """
        try:
            # Contract code is the part before the dot, lower-cased:
            # CU2504.SHFE -> cu2504
            contract_code = symbol.split(".", 1)[0].lower()
            # NOTE(review): confirm that ``ak.futures_zh_daily`` exists with
            # these keyword arguments in the installed AKShare version.
            df = await self._fetch_with_retry(
                ak.futures_zh_daily,
                symbol=contract_code,
                start_date=start_date,
                end_date=end_date
            )
            if df is None or df.empty:
                return []
            results = []
            for _, row in df.iterrows():
                trade_date = self._to_naive_datetime(row['date'])
                results.append(KLineData(
                    symbol=symbol,
                    time=int(trade_date.timestamp()),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    volume=int(row['volume']),
                    amount=float(row.get('amount', 0)),
                    open_interest=int(row.get('hold', 0))
                ))
            return results
        except Exception as e:
            error(f"Failed to fetch futures daily from Sina for {symbol}: {e}")
            return []

    async def _fetch_futures_minute_sina(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        freq: str
    ) -> List[KLineData]:
        """Fetch minute bars for a futures contract.

        Bars are kept when they fall within ``start_date``..``end_date``, both
        days inclusive. ``amount`` and ``open_interest`` are always 0 here.
        Returns an empty list on no data or on error (logged).
        """
        try:
            contract_code = symbol.split(".", 1)[0].lower()
            df = await self._fetch_with_retry(
                ak.futures_zh_minute_sina,
                symbol=contract_code,
                period=freq.replace("m", "")
            )
            if df is None or df.empty:
                return []
            df['datetime'] = pd.to_datetime(df['datetime'])
            # Half-open range so bars during the end day are not dropped.
            lower, upper = self._day_range(start_date, end_date)
            df = df[(df['datetime'] >= lower) & (df['datetime'] < upper)]
            results = []
            for _, row in df.iterrows():
                # Go through a plain datetime so the epoch is derived the same
                # way (local time) as in the other fetchers.
                bar_time = self._to_naive_datetime(row['datetime'])
                results.append(KLineData(
                    symbol=symbol,
                    time=int(bar_time.timestamp()),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    volume=int(row['volume']),
                    amount=0,
                    open_interest=0
                ))
            return results
        except Exception as e:
            error(f"Failed to fetch futures minute from Sina for {symbol}: {e}")
            return []

    # ------------------------------------------------------------------
    # Symbol lists
    # ------------------------------------------------------------------
    async def fetch_symbols(self, asset_type: str) -> List[SymbolInfo]:
        """Return the symbol list for ``"stock"`` or ``"futures"``.

        Raises:
            ValueError: for any other *asset_type*.
        """
        if asset_type == "stock":
            return await self._fetch_stock_symbols_sina()
        if asset_type == "futures":
            return await self._fetch_futures_symbols_sina()
        raise ValueError(f"Unsupported asset type: {asset_type}")

    def _spot_code_to_symbol(self, code: str) -> tuple:
        """Map a code from the spot list to ``(symbol_id, exchange)``.

        Codes carrying a Sina exchange prefix (``sh600000``) are converted with
        ``_denormalize_stock_symbol``. Bare numeric codes are classified by
        their leading digit: 6/5/9 -> SH, 8/4 -> BJ, anything else -> SZ.
        """
        lowered = code.lower()
        if lowered.startswith(("sh", "sz", "bj")):
            symbol_id = self._denormalize_stock_symbol(lowered)
            return symbol_id, symbol_id.rsplit(".", 1)[1]
        if code.startswith(('6', '5', '9')):
            exchange = "SH"
        elif code.startswith(('8', '4')):
            exchange = "BJ"
        else:
            exchange = "SZ"
        return f"{code}.{exchange}", exchange

    async def _fetch_stock_symbols_sina(self) -> List[SymbolInfo]:
        """Return all A-share symbols from the spot list.

        Returns an empty list on no data or on any error (logged).
        """
        try:
            df = await self._fetch_with_retry(ak.stock_zh_a_spot)
            if df is None or df.empty:
                return []
            results = []
            for _, row in df.iterrows():
                symbol_id, exchange = self._spot_code_to_symbol(str(row['代码']))
                results.append(SymbolInfo(
                    symbol_id=symbol_id,
                    name=str(row['名称']),
                    exchange=exchange
                ))
            info(f"Fetched {len(results)} stock symbols from Sina")
            return results
        except Exception as e:
            error(f"Failed to fetch stock symbols from Sina: {e}")
            return []

    async def _fetch_futures_symbols_sina(self) -> List[SymbolInfo]:
        """Return futures contracts from the realtime list.

        The product code is the letters of the contract symbol (upper-cased)
        and the contract month is its digits. Returns an empty list on no data
        or on any error (logged).
        """
        try:
            # NOTE(review): confirm ``futures_zh_realtime`` accepts
            # ``subscribe_list`` in the installed AKShare version.
            df = await self._fetch_with_retry(ak.futures_zh_realtime, subscribe_list=["0", "1", "2", "3"])
            if df is None or df.empty:
                return []
            results = []
            for _, row in df.iterrows():
                symbol = str(row['symbol'])
                underlying = ''.join(ch for ch in symbol if ch.isalpha()).upper()
                contract_month = ''.join(ch for ch in symbol if ch.isdigit())
                exchange = self._get_futures_exchange(underlying)
                results.append(SymbolInfo(
                    symbol_id=f"{symbol.upper()}.{exchange}",
                    name=str(row.get('name', symbol)),
                    exchange=exchange,
                    underlying=underlying,
                    contract_month=contract_month
                ))
            info(f"Fetched {len(results)} futures symbols from Sina")
            return results
        except Exception as e:
            error(f"Failed to fetch futures symbols from Sina: {e}")
            return []

    def _get_futures_exchange(self, underlying: str) -> str:
        """Return the exchange code for a product code.

        Product codes that are not in the table default to ``'SHFE'``.
        """
        return self._FUTURES_EXCHANGES.get(underlying, 'SHFE')

    # ------------------------------------------------------------------
    # Calendar, health, shutdown
    # ------------------------------------------------------------------
    async def fetch_trading_calendar(
        self,
        exchange: str,
        start: str,
        end: str
    ) -> List[TradeCalData]:
        """Return the trading days between *start* and *end* (``YYYYMMDD``, inclusive).

        ``exchange`` is accepted for interface compatibility but is not used:
        the same calendar is returned for every exchange. Only trading days
        are returned, each with ``is_trading_day=True``. Returns an empty list
        on no data or on any error (logged).
        """
        try:
            df = await self._fetch_with_retry(ak.tool_trade_date_hist_sina)
            if df is None or df.empty:
                return []
            df['trade_date'] = pd.to_datetime(df['trade_date'])
            start_dt = datetime.strptime(start, "%Y%m%d")
            end_dt = datetime.strptime(end, "%Y%m%d")
            df = df[(df['trade_date'] >= start_dt) & (df['trade_date'] <= end_dt)]
            return [
                TradeCalData(date=row['trade_date'], is_trading_day=True)
                for _, row in df.iterrows()
            ]
        except Exception as e:
            error(f"Failed to fetch trading calendar from Sina: {e}")
            return []

    async def health_check(self) -> bool:
        """Return True when connected and the stock spot list can be fetched."""
        try:
            if not self._connected:
                return False
            # Fetching the stock list doubles as the liveness probe.
            df = await self._fetch_with_retry(ak.stock_zh_a_spot)
            return df is not None and not df.empty
        except Exception as e:
            error(f"Health check failed: {e}")
            return False

    async def close(self) -> None:
        """Mark the adapter as disconnected."""
        self._connected = False
        info("AKShare adapter closed (Sina API)")