|
|
"""星耀数智(AmazingData)数据源适配器
|
|
|
|
|
|
基于银河证券星耀数智量化平台 SDK 的封装
|
|
|
提供统一、简洁的金融数据获取接口
|
|
|
"""
|
|
|
import asyncio
|
|
|
from datetime import datetime, date
|
|
|
from typing import List, Optional, Dict, Any, Union
|
|
|
from dataclasses import dataclass
|
|
|
from enum import Enum
|
|
|
|
|
|
from httpx import codes
|
|
|
import pandas as pd
|
|
|
|
|
|
from app.adapters.base import (
|
|
|
DataSourceAdapter, TickData, KLineData, SymbolInfo,
|
|
|
TradeCalData, TickCallback
|
|
|
)
|
|
|
from app.core.logger import info, error, warning
|
|
|
|
|
|
class SecurityType(Enum):
    """Security-type codes passed to the AmazingData SDK as ``security_type``."""

    STOCK_A = "EXTRA_STOCK_A"              # Shanghai + Shenzhen A shares
    STOCK_A_SH_SZ = "EXTRA_STOCK_A_SH_SZ"  # A shares (Shanghai/Shenzhen variant)
    INDEX_A = "EXTRA_INDEX_A"              # Shanghai + Shenzhen indices
    ETF = "EXTRA_ETF"                      # exchange-traded funds
    FUTURE = "EXTRA_FUTURE"                # futures
    KZZ = "EXTRA_KZZ"                      # convertible bonds
    GLRA = "EXTRA_GLRA"                    # reverse repos
    HKT = "EXTRA_HKT"                      # Hong Kong Stock Connect
    ETF_OP = "EXTRA_ETF_OP"                # ETF options
|
|
|
|
|
|
|
|
|
class Market(Enum):
    """Exchange market codes understood by the AmazingData SDK."""

    SH = "SH"  # Shanghai
    SZ = "SZ"  # Shenzhen
    BJ = "BJ"  # Beijing
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AmazingDataConfig:
    """Connection and cache settings for the AmazingData adapter."""

    # Credentials for the SDK login.
    username: str
    password: str

    # Server endpoint.
    host: str
    port: int

    # Directory used by the SDK's local cache, and whether cached data is
    # preferred by default when a call does not say otherwise.
    local_path: str = "./amazing_data_cache/"
    use_local_cache: bool = True
|
|
|
|
|
|
|
|
|
class AmazingDataAdapter(DataSourceAdapter):
    """Data-source adapter for the Galaxy Securities AmazingData (星耀数智) SDK.

    Wraps the synchronous SDK objects (``BaseData``, ``InfoData``,
    ``MarketData``) behind the async ``DataSourceAdapter`` interface.  Every
    SDK call blocks, so each one is dispatched to the running event loop's
    default executor instead of being executed on the loop thread.
    """

    # Accepted ``freq`` aliases -> attribute name on the SDK's
    # ``constant.Period``.  Stored as names (not enum members) because the SDK
    # module is only imported inside ``connect``.
    _PERIOD_ALIASES: Dict[str, str] = {
        "1m": "min1",
        "5m": "min5",
        "15m": "min15",
        "30m": "min30",
        "60m": "min60",
        "1d": "day",
        "day": "day",
        "D": "day",
        "1w": "week",
        "week": "week",
        "W": "week",
        "1M": "month",
        "month": "month",
        "M": "month",
    }

    # (exchange, product codes) pairs consulted by ``_get_futures_exchange``.
    # Each product code appears under exactly one exchange.  Note that LU
    # (low-sulphur fuel oil) and NR (TSR 20 rubber) are INE products, not SHFE.
    _FUTURES_PRODUCTS_BY_EXCHANGE = (
        # Shanghai Futures Exchange
        ("SHFE", ("CU", "AL", "ZN", "PB", "NI", "SN", "AU", "AG", "RB", "HC",
                  "BU", "RU", "FU", "SP", "WR", "SS", "AO", "BR")),
        # Dalian Commodity Exchange
        ("DCE", ("A", "B", "M", "Y", "P", "C", "CS", "JD", "LH", "JM",
                 "J", "I", "FB", "BB", "RR", "PG", "EB", "EG", "V", "PP", "L",
                 "LG")),
        # Zhengzhou Commodity Exchange
        ("CZCE", ("WH", "PM", "CF", "SR", "TA", "OI", "RI", "MA", "FG", "RS",
                  "RM", "JR", "LR", "SM", "SF", "CY", "AP", "CJ", "UR", "SA",
                  "PF", "PK", "ZC", "SH", "PX", "PR")),
        # China Financial Futures Exchange
        ("CFFEX", ("IF", "IC", "IH", "IM", "T", "TF", "TS", "TL")),
        # Shanghai International Energy Exchange
        ("INE", ("SC", "BC", "EC", "LU", "NR")),
        # Guangzhou Futures Exchange
        ("GFEX", ("SI", "LC", "PS")),
    )

    def __init__(self):
        self.config: Optional[AmazingDataConfig] = None
        self._ad = None            # the imported AmazingData module
        self._base_data = None     # ad.BaseData()
        self._market_data = None   # ad.MarketData(calendar)
        self._info_data = None     # ad.InfoData()
        self._calendar = None      # calendar fetched at login, fed to MarketData
        self._is_logged_in = False
        self._connected = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    async def _run_blocking(func: Any) -> Any:
        """Run the zero-argument callable *func* in the default executor.

        Uses ``get_running_loop`` (always valid inside a coroutine) rather than
        the deprecated-in-coroutines ``get_event_loop``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    def _check_login(self):
        """Raise ``RuntimeError`` unless ``connect()`` has logged in successfully."""
        if not self._is_logged_in:
            raise RuntimeError("未连接到数据源,请先调用 connect()")

    def _resolve_is_local(self, is_local: Optional[bool]) -> bool:
        """Return *is_local*, or the configured cache preference when it is None."""
        return self.config.use_local_cache if is_local is None else is_local

    def _format_date(self, d: Union[str, int, date]) -> int:
        """Normalise *d* to an ``int`` in ``YYYYMMDD`` form.

        Accepts a Python/numpy integer (returned as-is), a string such as
        ``"2024-01-02"``, ``"2024/01/02"`` or ``"20240102"``, or a
        ``date``/``datetime`` (including ``pd.Timestamp``).

        Raises:
            ValueError: for unsupported types or strings that are not numeric
                once the separators are removed.
        """
        if isinstance(d, int):
            return d
        if pd.api.types.is_integer(d):
            # numpy integer types (e.g. from an SDK calendar array)
            return int(d)
        if isinstance(d, str):
            return int(d.replace("-", "").replace("/", ""))
        if isinstance(d, date):
            return int(d.strftime("%Y%m%d"))
        raise ValueError(f"不支持的日期格式: {d}")

    @staticmethod
    def _parse_port(value: Any, default: int = 8080) -> int:
        """Parse a port given as int or string; ``None`` or blank -> *default*."""
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            return int(value) if value else default
        return int(value)

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert *value* to float; a missing (``None``) value becomes 0.0.

        NaN is passed through unchanged, as before.
        """
        return 0.0 if value is None else float(value)

    @staticmethod
    def _to_int(value: Any) -> int:
        """Convert *value* to int; ``None`` and NaN become 0 instead of raising."""
        if value is None or pd.isna(value):
            return 0
        return int(value)

    # ------------------------------------------------------------ connection

    async def connect(self, config: dict) -> None:
        """Import the SDK, build the configuration from *config* and log in.

        Args:
            config: mapping with ``username``, ``password``, ``host`` and
                ``port`` (int or numeric string; missing/blank/None -> 8080),
                plus optional ``local_path`` and ``use_local_cache``.

        Raises:
            RuntimeError: if the AmazingData SDK package is not installed.
            Exception: any error while parsing the config or logging in is
                logged and re-raised unchanged.
        """
        # Import separately so an ImportError raised inside the SDK's own login
        # is not misreported as "SDK not installed".
        try:
            import AmazingData as ad
        except ImportError as e:
            error("AmazingData SDK 未安装,请先安装 tgw 和 AmazingData 包")
            raise RuntimeError("AmazingData SDK not installed") from e
        self._ad = ad

        try:
            self.config = AmazingDataConfig(
                username=config.get("username", ""),
                password=config.get("password", ""),
                host=config.get("host", ""),
                port=self._parse_port(config.get("port")),
                local_path=config.get("local_path", "./amazing_data_cache/"),
                use_local_cache=config.get("use_local_cache", True),
            )

            # The SDK login is blocking; run it off the event loop.
            await self._run_blocking(self._do_login)

            self._connected = True
            info("成功连接到 AmazingData 星耀数智数据源")
        except Exception as e:
            error(f"连接 AmazingData 失败: {e}")
            raise

    def _do_login(self):
        """Log in with ``self.config`` and create the SDK data objects (blocking)."""
        info("[amazingdata_adapter]正在登录 AmazingData...")
        info(f"[amazingdata_adapter]登录用户: {self.config.username}")
        info(f"[amazingdata_adapter]登录地址: {self.config.host}:{self.config.port}")

        self._ad.login(
            username=self.config.username,
            password=self.config.password,
            host=self.config.host,
            port=self.config.port,
        )

        self._base_data = self._ad.BaseData()
        self._info_data = self._ad.InfoData()
        # MarketData is constructed from the trading calendar.
        self._calendar = self._base_data.get_calendar()
        self._market_data = self._ad.MarketData(self._calendar)

        self._is_logged_in = True
        info("[amazingdata_adapter]登录成功")

    async def close(self) -> None:
        """Log out (best effort) and mark the adapter as disconnected.

        Logout errors are logged as warnings and not raised; the connection
        flags are reset regardless.
        """
        if self._is_logged_in and self._ad:
            try:
                await self._run_blocking(lambda: self._ad.logout(self.config.username))
                info("已断开与 AmazingData 的连接")
            except Exception as e:
                warning(f"断开连接时出错: {e}")
        self._is_logged_in = False
        self._connected = False

    async def subscribe_ticks(self, symbols: List[str], callback: TickCallback) -> None:
        """Not supported: AmazingData offers no callback-based real-time tick push."""
        raise NotImplementedError("AmazingData does not support real-time tick subscription via callback")

    # ---------------------------------------------------------------- klines

    def _resolve_period(self, freq: str) -> Any:
        """Map *freq* to the SDK's ``constant.Period`` value.

        Unknown frequencies fall back to daily bars (as before) but now log a
        warning instead of doing so silently.
        """
        attr = self._PERIOD_ALIASES.get(freq)
        if attr is None:
            warning(f"[amazingdata_adapter] unknown kline freq '{freq}', falling back to daily")
            attr = "day"
        return getattr(self._ad.constant.Period, attr).value

    async def fetch_klines(
        self,
        symbol: str,
        start: str,
        end: str,
        freq: str
    ) -> List[KLineData]:
        """Fetch historical K-lines for *symbol* between *start* and *end*.

        Args:
            symbol: security code, e.g. ``"600000.SH"``.
            start: start date (any form accepted by ``_format_date``).
            end: end date (any form accepted by ``_format_date``).
            freq: bar frequency alias, e.g. ``"1m"``, ``"1d"``, ``"W"``, ``"1M"``.

        Returns:
            One ``KLineData`` per row whose ``kline_time`` could be parsed.

        Raises:
            RuntimeError: if not connected.
        """
        self._check_login()
        info(f"[amazingdata_adapter fetch_klines]正在拉取 {symbol} 的 {freq} 周期数据...")

        period_value = self._resolve_period(freq)
        return await self._run_blocking(
            lambda: self._fetch_klines_sync(symbol, start, end, period_value)
        )

    @staticmethod
    def _parse_kline_time(kline_time: Any) -> Optional[tuple]:
        """Return ``(epoch_seconds, "YYYY-MM-DD")`` for a ``kline_time`` value.

        Handles ``pd.Timestamp``, ``datetime``/``date`` and integer-like
        ``YYYYMMDD`` values.  Returns ``None`` (after logging a warning) for
        missing or unparsable values so the caller can skip the row.
        """
        if kline_time is None or pd.isna(kline_time):
            warning("[amazingdata_adapter _fetch_klines_sync]跳过无效日期: kline_time 为空")
            return None

        try:
            if isinstance(kline_time, pd.Timestamp):
                # NOTE(review): for a *naive* Timestamp, pandas' .timestamp()
                # treats the value as UTC, whereas the datetime/YYYYMMDD paths
                # below use the host's local timezone.  Kept unchanged for
                # compatibility with stored data — confirm which convention
                # KLineData.time is meant to follow.
                return int(kline_time.timestamp()), kline_time.strftime("%Y-%m-%d")

            if isinstance(kline_time, datetime):
                return int(kline_time.timestamp()), kline_time.strftime("%Y-%m-%d")

            if isinstance(kline_time, date):
                dt = datetime(kline_time.year, kline_time.month, kline_time.day)
            else:
                # Integer-like YYYYMMDD (int, float or numeric string).
                date_str = str(int(kline_time))
                if len(date_str) != 8:
                    warning(f"[amazingdata_adapter _fetch_klines_sync]跳过无效日期: {date_str}")
                    return None
                dt = datetime.strptime(date_str, "%Y%m%d")
            return int(dt.timestamp()), dt.strftime("%Y-%m-%d")
        except (ValueError, TypeError, OverflowError) as e:
            warning(f"[amazingdata_adapter _fetch_klines_sync]日期解析错误 '{kline_time}' (type: {type(kline_time)}): {e}")
            return None

    def _fetch_klines_sync(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        period_value: Any
    ) -> List[KLineData]:
        """Blocking K-line query; converts the SDK's DataFrame rows into KLineData.

        Rows without a usable ``kline_time`` are skipped.  Missing price
        columns default to 0.0 and a missing or NaN ``volume`` to 0.
        """
        code_list = [symbol]
        start_int = self._format_date(start_date)
        end_int = self._format_date(end_date)

        info(f"[amazingdata_adapter _fetch_klines_sync]正在获取K线数据: 代码={code_list}, 日期范围={start_date}~{end_date}, 周期={period_value}")
        kline_dict = self._market_data.query_kline(
            code_list=code_list,
            begin_date=start_int,
            end_date=end_int,
            period=period_value,
        )

        # The SDK returns {code: DataFrame}; guard against a None result.
        df = kline_dict.get(symbol) if kline_dict else None

        results: List[KLineData] = []
        if df is not None:
            info(f"[amazingdata_adapter _fetch_klines_sync]DataFrame columns: {df.columns.tolist()}")
            # to_dict("records") keeps per-column dtypes and is much faster
            # than iterrows() for large (e.g. minute-level) frames.
            for row in df.to_dict("records"):
                parsed = self._parse_kline_time(row.get("kline_time"))
                if parsed is None:
                    continue
                ts, trade_date = parsed

                results.append(KLineData(
                    symbol=symbol,
                    time=ts,
                    open=self._to_float(row.get("open")),
                    high=self._to_float(row.get("high")),
                    low=self._to_float(row.get("low")),
                    close=self._to_float(row.get("close")),
                    volume=self._to_int(row.get("volume")),
                    amount=self._to_float(row.get("amount")),
                    trade_date=trade_date,
                ))

        info(f"Fetched {len(results)} klines from AmazingData for {symbol}")
        return results

    # --------------------------------------------------------------- symbols

    async def fetch_symbols(self, asset_type: str) -> List[SymbolInfo]:
        """Fetch the symbol list for *asset_type* (``"stock"`` or ``"futures"``).

        Raises:
            RuntimeError: if not connected.
        """
        self._check_login()
        return await self._run_blocking(lambda: self._fetch_symbols_sync(asset_type))

    def _fetch_symbols_sync(self, asset_type: str) -> List[SymbolInfo]:
        """Blocking symbol fetch; unsupported asset types yield an empty list."""
        if asset_type == "stock":
            results = self._fetch_stock_symbols()
        elif asset_type == "futures":
            results = self._fetch_futures_symbols()
        else:
            warning(f"Unsupported asset_type for AmazingData: {asset_type!r}")
            results = []

        info(f"Fetched {len(results)} {asset_type} symbols from AmazingData")
        return results

    @staticmethod
    def _stock_exchange(code: str) -> str:
        """Return ``"SH"``/``"SZ"``/``"BJ"`` from a code's suffix, else ``""``."""
        _, sep, suffix = code.rpartition(".")
        suffix = suffix.upper()
        return suffix if sep and suffix in ("SH", "SZ", "BJ") else ""

    def _fetch_stock_symbols(self) -> List[SymbolInfo]:
        """Build a SymbolInfo for every A-share code, named via ``get_code_info``."""
        security_type = SecurityType.STOCK_A.value
        codes = self._base_data.get_code_list(security_type=security_type)
        info_df = self._base_data.get_code_info(security_type=security_type)

        # code -> display name; codes missing from info_df fall back to the code.
        name_map = {}
        if info_df is not None and "symbol" in info_df.columns:
            for code in codes:
                if code in info_df.index:
                    name_map[code] = info_df.loc[code, "symbol"]

        return [
            SymbolInfo(
                symbol_id=code,
                name=name_map.get(code, code),
                exchange=self._stock_exchange(code),
            )
            for code in codes
        ]

    def _fetch_futures_symbols(self) -> List[SymbolInfo]:
        """Build a SymbolInfo for every futures contract code."""
        codes = self._base_data.get_future_code_list(
            security_type=SecurityType.FUTURE.value
        )

        results = []
        for code in codes:
            # Presumably bare contract ids such as "rb2405" (letters = product,
            # digits = contract month) — TODO confirm against the SDK output.
            underlying = "".join(c for c in code if c.isalpha()).upper()
            contract_month = "".join(c for c in code if c.isdigit())
            exchange = self._get_futures_exchange(underlying)

            results.append(SymbolInfo(
                symbol_id=f"{code.upper()}.{exchange}",
                name=code,
                exchange=exchange,
                underlying=underlying,
                contract_month=contract_month,
            ))
        return results

    def _get_futures_exchange(self, underlying: str) -> str:
        """Map a futures product code (e.g. ``"RB"``) to its exchange code.

        Returns one of SHFE, DCE, CZCE, CFFEX, INE or GFEX; unknown products
        fall back to ``"SHFE"``.
        """
        for exchange, products in self._FUTURES_PRODUCTS_BY_EXCHANGE:
            if underlying in products:
                return exchange
        return 'SHFE'  # default: Shanghai

    # -------------------------------------------------------------- calendar

    async def fetch_trading_calendar(
        self,
        exchange: str,
        start: str,
        end: str
    ) -> List[TradeCalData]:
        """Fetch the trading days of *exchange* between *start* and *end* inclusive.

        Raises:
            RuntimeError: if not connected.
        """
        self._check_login()
        return await self._run_blocking(
            lambda: self._fetch_calendar_sync(exchange, start, end)
        )

    def _fetch_calendar_sync(
        self,
        exchange: str,
        start: str,
        end: str
    ) -> List[TradeCalData]:
        """Blocking calendar fetch.

        ``"SH"``/``"SSE"`` use the Shanghai calendar; anything else uses
        Shenzhen.  Every returned entry is a trading day.
        """
        market = Market.SH if exchange in ["SH", "SSE"] else Market.SZ
        calendar = self._base_data.get_calendar(market=market.value)
        if calendar is None:
            calendar = []

        start_int = self._format_date(start)
        end_int = self._format_date(end)

        results = []
        for raw_day in calendar:
            # Normalise each entry so ints, numpy ints, strings and dates all
            # compare correctly against the integer bounds.
            day_int = self._format_date(raw_day)
            if start_int <= day_int <= end_int:
                results.append(TradeCalData(
                    date=datetime.strptime(str(day_int), "%Y%m%d"),
                    is_trading_day=True,
                ))
        return results

    async def health_check(self) -> bool:
        """Return True if connected and a lightweight SDK query succeeds."""
        if not self._connected or not self._is_logged_in:
            return False

        try:
            # Fetching the A-share code list serves as the liveness probe.
            await self._run_blocking(
                lambda: self._base_data.get_code_list(
                    security_type=SecurityType.STOCK_A.value
                )
            )
            return True
        except Exception as e:
            error(f"AmazingData health check failed: {e}")
            return False

    # ==================== AmazingData-specific interfaces ====================

    async def get_adj_factor(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> pd.DataFrame:
        """Fetch single-step adjustment factors.

        Args:
            codes: stock codes.
            is_local: read from the local cache; ``None`` uses the config default.

        Returns:
            DataFrame (index: date, columns: stock code).
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._base_data.get_adj_factor(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_backward_factor(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> pd.DataFrame:
        """Fetch backward-adjustment factors for *codes* (see ``get_adj_factor``)."""
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._base_data.get_backward_factor(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_index_constituents(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch index constituents.

        Args:
            codes: index codes, e.g. ``['000300.SH', '000905.SH']``.
            is_local: read from the local cache; ``None`` uses the config default.

        Returns:
            Dict mapping index code to a DataFrame.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._info_data.get_index_constituent(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_index_weights(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch index constituent weights.

        Supported indices:
            - 000016.SH: SSE 50
            - 000300.SH: CSI 300
            - 000905.SH: CSI 500
            - 000906.SH: CSI 800
            - 000852.SH: CSI 1000

        Omitted ``start_date``/``end_date`` are passed to the SDK as ``None``.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        begin_date = self._format_date(start_date) if start_date else None
        end_date_int = self._format_date(end_date) if end_date else None

        return await self._run_blocking(
            lambda: self._info_data.get_index_weight(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
                begin_date=begin_date,
                end_date=end_date_int,
            )
        )

    async def get_snapshot(
        self,
        codes: List[str],
        start_date: str,
        end_date: str
    ) -> Dict[str, pd.DataFrame]:
        """Fetch historical snapshot (tick-level) data for *codes*."""
        self._check_login()
        start_int = self._format_date(start_date)
        end_int = self._format_date(end_date)

        return await self._run_blocking(
            lambda: self._market_data.query_snapshot(
                code_list=codes,
                begin_date=start_int,
                end_date=end_int,
            )
        )

    async def get_balance_sheet(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch balance sheets via ``InfoData.get_balance_sheet``."""
        return await self._get_financial_data(
            'get_balance_sheet', codes, start_date, end_date, is_local
        )

    async def get_cash_flow(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch cash-flow statements via ``InfoData.get_cash_flow``."""
        return await self._get_financial_data(
            'get_cash_flow', codes, start_date, end_date, is_local
        )

    async def get_income_statement(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch income statements via ``InfoData.get_income``."""
        return await self._get_financial_data(
            'get_income', codes, start_date, end_date, is_local
        )

    async def _get_financial_data(
        self,
        method: str,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Call the ``InfoData`` method named *method* with the shared arguments.

        Raises:
            RuntimeError: if not connected.
            AttributeError: if ``InfoData`` has no method named *method*.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        begin_date = self._format_date(start_date) if start_date else None
        end_date_int = self._format_date(end_date) if end_date else None

        def fetch():
            fn = getattr(self._info_data, method)
            return fn(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
                begin_date=begin_date,
                end_date=end_date_int,
            )

        return await self._run_blocking(fetch)
|