You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

618 lines
21 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""星耀数智(AmazingData)数据源适配器
基于银河证券星耀数智量化平台 SDK 的封装
提供统一、简洁的金融数据获取接口
"""
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, date
from enum import Enum
from typing import List, Optional, Dict, Any, Union

import pandas as pd
from httpx import codes

from app.adapters.base import (
    DataSourceAdapter, TickData, KLineData, SymbolInfo,
    TradeCalData, TickCallback
)
from app.core.logger import info, error, warning
class SecurityType(Enum):
    """Security-type codes used as the SDK's ``security_type`` argument.

    Each member's value is the literal string handed to the AmazingData SDK
    (e.g. ``get_code_list(security_type=SecurityType.STOCK_A.value)``).
    """

    # --- equities and indices ---
    STOCK_A = "EXTRA_STOCK_A"  # Shanghai/Shenzhen A-shares
    STOCK_A_SH_SZ = "EXTRA_STOCK_A_SH_SZ"  # A-shares, SH/SZ variant (presumably excludes BJ — confirm)
    INDEX_A = "EXTRA_INDEX_A"  # Shanghai/Shenzhen indices
    # --- funds and derivatives ---
    ETF = "EXTRA_ETF"  # exchange-traded funds
    FUTURE = "EXTRA_FUTURE"  # futures contracts
    # --- fixed income and cross-border ---
    KZZ = "EXTRA_KZZ"  # convertible bonds
    GLRA = "EXTRA_GLRA"  # reverse repurchase agreements
    HKT = "EXTRA_HKT"  # Hong Kong Stock Connect
    # --- options ---
    ETF_OP = "EXTRA_ETF_OP"  # ETF options
class Market(Enum):
    """Stock-exchange market codes accepted by the SDK's ``market`` argument."""

    SH = "SH"  # Shanghai Stock Exchange
    SZ = "SZ"  # Shenzhen Stock Exchange
    BJ = "BJ"  # Beijing Stock Exchange
@dataclass
class AmazingDataConfig:
    """Connection and local-cache settings for the AmazingData SDK.

    Fields:
        username / password: SDK login credentials.
        host / port: SDK server address.
        local_path: directory handed to SDK calls as ``local_path``.
        use_local_cache: default for the ``is_local`` flag of cached queries.

    ``password`` is excluded from ``repr()`` so that printing or logging a
    config object does not leak the credential. It still takes part in
    equality comparison.
    """

    username: str
    password: str = field(repr=False)  # keep the secret out of repr()/logs
    host: str
    port: int
    local_path: str = "./amazing_data_cache/"
    use_local_cache: bool = True
class AmazingDataAdapter(DataSourceAdapter):
    """Data-source adapter for the AmazingData SDK.

    Wraps the synchronous SDK objects (``BaseData``, ``InfoData`` and
    ``MarketData``) behind the async ``DataSourceAdapter`` interface. Every
    blocking SDK call runs in the event loop's default executor.
    """

    # Futures product code -> exchange, scanned in order by
    # _get_futures_exchange. LU (low-sulphur fuel oil) and NR (No.20 rubber)
    # are INE products, not SHFE ones.
    _FUTURES_EXCHANGE_PRODUCTS = (
        # Shanghai Futures Exchange
        ("SHFE", frozenset({
            "CU", "AL", "ZN", "PB", "NI", "SN", "AU", "AG", "RB", "HC",
            "BU", "RU", "FU", "SP", "WR", "SS", "AO", "BR",
        })),
        # Dalian Commodity Exchange
        ("DCE", frozenset({
            "A", "B", "M", "Y", "P", "C", "CS", "JD", "LH", "JM",
            "J", "I", "FB", "BB", "RR", "PG", "EB", "EG", "V", "PP", "L",
            "LG",
        })),
        # Zhengzhou Commodity Exchange
        ("CZCE", frozenset({
            "WH", "PM", "CF", "SR", "TA", "OI", "RI", "MA", "FG", "RS",
            "RM", "JR", "LR", "SM", "SF", "CY", "AP", "CJ", "UR", "SA",
            "PF", "PK", "ZC", "SH", "PX", "PR",
        })),
        # China Financial Futures Exchange
        ("CFFEX", frozenset({"IF", "IC", "IH", "IM", "T", "TF", "TS", "TL"})),
        # Shanghai International Energy Exchange
        ("INE", frozenset({"SC", "BC", "EC", "LU", "NR"})),
        # Guangzhou Futures Exchange
        ("GFEX", frozenset({"SI", "LC", "PS"})),
    )

    def __init__(self):
        self.config: Optional[AmazingDataConfig] = None
        self._ad = None             # imported AmazingData module, set in connect()
        self._base_data = None      # ad.BaseData() instance, set in _do_login()
        self._market_data = None    # ad.MarketData(calendar) instance
        self._info_data = None      # ad.InfoData() instance
        self._calendar = None       # BaseData.get_calendar() result at login
        self._is_logged_in = False  # True once _do_login() has completed
        self._connected = False     # True once connect() has completed

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _check_login(self):
        """Raise ``RuntimeError`` unless ``connect()`` has logged in successfully."""
        if not self._is_logged_in:
            raise RuntimeError("未连接到数据源,请先调用 connect()")

    def _format_date(self, d: Union[str, int, date]) -> int:
        """Normalise a date to an ``int`` of the form ``YYYYMMDD``.

        Accepts integers (including NumPy integer scalars), strings such as
        ``"2024-01-02"``, ``"2024/01/02"`` or ``"20240102"``, and
        ``date``/``datetime`` objects.

        Raises:
            ValueError: for unsupported types, or for strings that are not
                numeric once ``-`` and ``/`` are removed.
        """
        import numbers  # only needed here; also matches NumPy integer scalars

        if isinstance(d, numbers.Integral):
            return int(d)
        if isinstance(d, str):
            return int(d.replace("-", "").replace("/", ""))
        if isinstance(d, date):  # datetime is a subclass of date
            return int(d.strftime("%Y%m%d"))
        raise ValueError(f"不支持的日期格式: {d}")

    async def _run_blocking(self, func: Any) -> Any:
        """Run the zero-argument callable *func* in the default executor and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    def _resolve_is_local(self, is_local: Optional[bool]) -> bool:
        """Return *is_local*, or ``config.use_local_cache`` when it is None."""
        return self.config.use_local_cache if is_local is None else is_local

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a cell to float. ``None`` becomes 0.0; NaN is kept as NaN."""
        return 0.0 if value is None else float(value)

    @staticmethod
    def _to_int(value: Any) -> int:
        """Convert a cell to int, mapping ``None``/NaN to 0 (``int(nan)`` would raise)."""
        if value is None or pd.isna(value):
            return 0
        return int(value)

    @staticmethod
    def _parse_kline_time(kline_time: Any) -> Optional[tuple]:
        """Parse a ``kline_time`` cell into ``(epoch_seconds, 'YYYY-MM-DD')``.

        Accepts ``pd.Timestamp``, ``datetime`` and ``date`` objects, or an
        integer-like ``YYYYMMDD`` value. Returns None (after logging a
        warning) for empty or unparseable values.

        All naive values are converted to epoch seconds the same way, via
        ``datetime.timestamp()``, which interprets them in the host's local
        timezone. Previously, naive ``pd.Timestamp`` values went through
        pandas' ``Timestamp.timestamp()``, which assumes UTC, so the same
        bar got a different epoch depending on its type.
        """
        if kline_time is None or pd.isna(kline_time):
            warning("[amazingdata_adapter _fetch_klines_sync]跳过无效日期: kline_time 为空")
            return None
        try:
            if isinstance(kline_time, pd.Timestamp):
                dt = kline_time.to_pydatetime()
            elif isinstance(kline_time, datetime):
                dt = kline_time
            elif isinstance(kline_time, date):
                dt = datetime(kline_time.year, kline_time.month, kline_time.day)
            else:
                # Integer-like YYYYMMDD (may arrive as a float after iterrows upcasting).
                date_str = str(int(kline_time))
                if len(date_str) != 8:
                    warning(f"[amazingdata_adapter _fetch_klines_sync]跳过无效日期: {date_str}")
                    return None
                dt = datetime.strptime(date_str, "%Y%m%d")
            return int(dt.timestamp()), dt.strftime('%Y-%m-%d')
        except (ValueError, TypeError, OverflowError, OSError) as e:
            warning(
                f"[amazingdata_adapter _fetch_klines_sync]日期解析错误 '{kline_time}' "
                f"(type: {type(kline_time)}): {e}"
            )
            return None

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, config: dict) -> None:
        """Import the SDK, build the configuration and log in.

        Args:
            config: mapping with ``username``, ``password``, ``host``,
                ``port`` and optional ``local_path`` / ``use_local_cache``.
                ``port`` may be an int or a numeric string. A missing,
                ``None`` or blank value falls back to 8080.

        Raises:
            RuntimeError: if the ``AmazingData`` package cannot be imported.
            Exception: any other configuration or login error is logged and
                re-raised unchanged.
        """
        try:
            import AmazingData as ad
            self._ad = ad
            port_val = config.get("port", 8080)
            if port_val is None:
                port_val = 8080
            elif isinstance(port_val, str):
                port_val = int(port_val) if port_val.strip() else 8080
            self.config = AmazingDataConfig(
                username=config.get("username", ""),
                password=config.get("password", ""),
                host=config.get("host", ""),
                port=port_val,
                local_path=config.get("local_path", "./amazing_data_cache/"),
                use_local_cache=config.get("use_local_cache", True),
            )
            # The SDK login is blocking, so keep it off the event loop.
            await self._run_blocking(self._do_login)
            self._connected = True
            info("成功连接到 AmazingData 星耀数智数据源")
        except ImportError as e:
            error("AmazingData SDK 未安装,请先安装 tgw 和 AmazingData 包")
            raise RuntimeError("AmazingData SDK not installed") from e
        except Exception as e:
            error(f"连接 AmazingData 失败: {e}")
            raise

    def _do_login(self):
        """Log in to the SDK and create its data-access objects (blocking)."""
        info("[amazingdata_adapter]正在登录 AmazingData...")
        info(f"[amazingdata_adapter]登录用户: {self.config.username}")
        info(f"[amazingdata_adapter]登录地址: {self.config.host}:{self.config.port}")
        self._ad.login(
            username=self.config.username,
            password=self.config.password,
            host=self.config.host,
            port=self.config.port,
        )
        self._base_data = self._ad.BaseData()
        self._info_data = self._ad.InfoData()
        # MarketData is constructed from the trading calendar fetched here.
        self._calendar = self._base_data.get_calendar()
        self._market_data = self._ad.MarketData(self._calendar)
        self._is_logged_in = True
        info("[amazingdata_adapter]登录成功")

    async def close(self) -> None:
        """Log out (best effort) and mark the adapter as disconnected.

        Logout errors are logged as warnings rather than raised. The
        connection flags are reset in either case.
        """
        if self._is_logged_in and self._ad:
            try:
                await self._run_blocking(lambda: self._ad.logout(self.config.username))
                info("已断开与 AmazingData 的连接")
            except Exception as e:
                warning(f"断开连接时出错: {e}")
        self._is_logged_in = False
        self._connected = False

    async def subscribe_ticks(self, symbols: List[str], callback: TickCallback) -> None:
        """Not supported: AmazingData offers no real-time callback push mode."""
        raise NotImplementedError("AmazingData does not support real-time tick subscription via callback")

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    async def fetch_klines(
        self,
        symbol: str,
        start: str,
        end: str,
        freq: str
    ) -> List[KLineData]:
        """Fetch historical K-line bars for one symbol.

        Args:
            symbol: security code, passed unchanged to the SDK.
            start: start date in any form accepted by ``_format_date``.
            end: end date in any form accepted by ``_format_date``.
            freq: bar period. Recognised values are ``1m 5m 15m 30m 60m``
                (minutes), ``1d/day/D``, ``1w/week/W`` and ``1M/month/M``.
                Matching is case-sensitive: ``"1m"`` is one minute and
                ``"1M"`` is one month. Unrecognised values fall back to
                daily bars and log a warning.

        Returns:
            One ``KLineData`` per returned row whose ``kline_time`` parses.
        """
        self._check_login()
        info(f"[amazingdata_adapter fetch_klines]正在拉取 {symbol}{freq} 周期数据...")
        period_enum = self._ad.constant.Period
        period_map = {
            "1m": period_enum.min1,
            "5m": period_enum.min5,
            "15m": period_enum.min15,
            "30m": period_enum.min30,
            "60m": period_enum.min60,
            "1d": period_enum.day,
            "day": period_enum.day,
            "D": period_enum.day,
            "1w": period_enum.week,
            "week": period_enum.week,
            "W": period_enum.week,
            "1M": period_enum.month,
            "month": period_enum.month,
            "M": period_enum.month,
        }
        period = period_map.get(freq)
        if period is None:
            warning(f"Unknown kline freq {freq!r} for AmazingData; falling back to daily bars")
            period = period_enum.day
        period_value = period.value
        return await self._run_blocking(
            lambda: self._fetch_klines_sync(symbol, start, end, period_value)
        )

    def _fetch_klines_sync(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        period_value: Any
    ) -> List[KLineData]:
        """Blocking part of ``fetch_klines``: query the SDK and convert rows.

        ``period_value`` is the ``.value`` of an SDK ``Period`` member. Rows
        with an empty or unparseable ``kline_time`` are skipped with a
        warning instead of failing the whole request.
        """
        codes = [symbol]
        start_int = self._format_date(start_date)
        end_int = self._format_date(end_date)
        info(
            f"[amazingdata_adapter _fetch_klines_sync]正在获取K线数据: 代码={codes}, "
            f"日期范围={start_date}~{end_date}, 周期={period_value}"
        )
        kline_dict = self._market_data.query_kline(
            code_list=codes,
            begin_date=start_int,
            end_date=end_int,
            period=period_value,
        )
        if kline_dict is None:  # tolerate a None result as "no data"
            kline_dict = {}
        df = kline_dict.get(symbol)

        results: List[KLineData] = []
        if df is not None:
            for _, row in df.iterrows():
                parsed = self._parse_kline_time(row.get('kline_time'))
                if parsed is None:
                    continue
                ts, trade_date = parsed
                results.append(KLineData(
                    symbol=symbol,
                    time=ts,
                    open=self._to_float(row.get('open', 0)),
                    high=self._to_float(row.get('high', 0)),
                    low=self._to_float(row.get('low', 0)),
                    close=self._to_float(row.get('close', 0)),
                    volume=self._to_int(row.get('volume', 0)),
                    amount=self._to_float(row.get('amount', 0)),
                    trade_date=trade_date,
                ))
        info(f"Fetched {len(results)} klines from AmazingData for {symbol}")
        return results

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    async def fetch_symbols(self, asset_type: str) -> List[SymbolInfo]:
        """List tradable symbols. ``asset_type`` is ``"stock"`` or ``"futures"``."""
        self._check_login()
        return await self._run_blocking(lambda: self._fetch_symbols_sync(asset_type))

    def _fetch_symbols_sync(self, asset_type: str) -> List[SymbolInfo]:
        """Blocking part of ``fetch_symbols``. Unknown types yield an empty list and a warning."""
        if asset_type == "stock":
            results = self._fetch_stock_symbols()
        elif asset_type == "futures":
            results = self._fetch_futures_symbols()
        else:
            warning(f"Unsupported asset_type for AmazingData: {asset_type!r}")
            results = []
        info(f"Fetched {len(results)} {asset_type} symbols from AmazingData")
        return results

    def _fetch_stock_symbols(self) -> List[SymbolInfo]:
        """List A-share symbols, naming them from the ``symbol`` column of ``get_code_info``."""
        codes = self._base_data.get_code_list(
            security_type=SecurityType.STOCK_A.value
        )
        info_df = self._base_data.get_code_info(
            security_type=SecurityType.STOCK_A.value
        )
        # Keyed by info_df's index. Codes with no entry fall back to the code itself.
        name_map = info_df['symbol'].to_dict() if 'symbol' in info_df.columns else {}
        return [
            SymbolInfo(
                symbol_id=code,
                name=name_map.get(code, code),
                exchange=self._stock_exchange(code),
            )
            for code in codes
        ]

    @staticmethod
    def _stock_exchange(code: str) -> str:
        """Return ``'SH'``/``'SZ'``/``'BJ'`` from a code suffix such as ``'600000.SH'``, else ``''``."""
        _, sep, suffix = code.rpartition(".")
        return suffix if sep and suffix in ("SH", "SZ", "BJ") else ""

    def _fetch_futures_symbols(self) -> List[SymbolInfo]:
        """List futures contracts and derive the product, contract month and exchange of each."""
        codes = self._base_data.get_future_code_list(
            security_type=SecurityType.FUTURE.value
        )
        results = []
        for code in codes:
            # Split letters from digits, e.g. 'rb2405' -> 'RB' / '2405'.
            # NOTE(review): assumes codes carry no exchange suffix — confirm.
            underlying = "".join(ch for ch in code if ch.isalpha()).upper()
            contract_month = "".join(ch for ch in code if ch.isdigit())
            exchange = self._get_futures_exchange(underlying)
            results.append(SymbolInfo(
                symbol_id=f"{code.upper()}.{exchange}",
                name=code,
                exchange=exchange,
                underlying=underlying,
                contract_month=contract_month,
            ))
        return results

    def _get_futures_exchange(self, underlying: str) -> str:
        """Map an upper-case futures product code to its exchange (default ``'SHFE'``)."""
        for exchange, products in self._FUTURES_EXCHANGE_PRODUCTS:
            if underlying in products:
                return exchange
        return 'SHFE'

    # ------------------------------------------------------------------
    # Trading calendar / health
    # ------------------------------------------------------------------

    async def fetch_trading_calendar(
        self,
        exchange: str,
        start: str,
        end: str
    ) -> List[TradeCalData]:
        """Return the trading days of *exchange* within ``[start, end]`` (inclusive)."""
        self._check_login()
        return await self._run_blocking(
            lambda: self._fetch_calendar_sync(exchange, start, end)
        )

    def _fetch_calendar_sync(
        self,
        exchange: str,
        start: str,
        end: str
    ) -> List[TradeCalData]:
        """Blocking part of ``fetch_trading_calendar``.

        Only ``"SH"``/``"SSE"`` select the Shanghai calendar; every other
        exchange string uses Shenzhen's. Calendar entries are normalised
        with ``_format_date``, so ints, numeric strings and dates all work.
        """
        market = Market.SH if exchange in ("SH", "SSE") else Market.SZ
        calendar = self._base_data.get_calendar(market=market.value)
        start_int = self._format_date(start)
        end_int = self._format_date(end)
        results = []
        for raw_day in calendar:
            day_int = self._format_date(raw_day)
            if start_int <= day_int <= end_int:
                results.append(TradeCalData(
                    date=datetime.strptime(str(day_int), "%Y%m%d"),
                    is_trading_day=True,
                ))
        return results

    async def health_check(self) -> bool:
        """Return True if connected and a lightweight SDK call (A-share code list) succeeds."""
        if not self._connected or not self._is_logged_in:
            return False
        try:
            await self._run_blocking(
                lambda: self._base_data.get_code_list(
                    security_type=SecurityType.STOCK_A.value
                )
            )
            return True
        except Exception as e:
            error(f"AmazingData health check failed: {e}")
            return False

    # ==================== AmazingData-specific interfaces ====================

    async def get_adj_factor(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> pd.DataFrame:
        """Fetch single-step adjustment factors.

        Args:
            codes: security codes.
            is_local: whether to use the local cache. Defaults to
                ``config.use_local_cache``.

        Returns:
            The SDK's DataFrame (index: dates, columns: security codes).
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._base_data.get_adj_factor(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_backward_factor(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> pd.DataFrame:
        """Fetch backward-adjustment factors. ``is_local`` defaults to ``config.use_local_cache``."""
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._base_data.get_backward_factor(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_index_constituents(
        self,
        codes: List[str],
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch index constituents.

        Args:
            codes: index codes, e.g. ``['000300.SH', '000905.SH']``.
            is_local: whether to use the local cache. Defaults to
                ``config.use_local_cache``.

        Returns:
            Mapping of index code to DataFrame, as returned by the SDK.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        return await self._run_blocking(
            lambda: self._info_data.get_index_constituent(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
            )
        )

    async def get_index_weights(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch index constituent weights.

        Supported indices:
            - 000016.SH: SSE 50
            - 000300.SH: CSI 300
            - 000905.SH: CSI 500
            - 000906.SH: CSI 800
            - 000852.SH: CSI 1000

        Omitted dates are passed to the SDK as ``None``.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        begin_date = self._format_date(start_date) if start_date else None
        end_date_int = self._format_date(end_date) if end_date else None
        return await self._run_blocking(
            lambda: self._info_data.get_index_weight(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
                begin_date=begin_date,
                end_date=end_date_int,
            )
        )

    async def get_snapshot(
        self,
        codes: List[str],
        start_date: str,
        end_date: str
    ) -> Dict[str, pd.DataFrame]:
        """Fetch historical tick-level snapshot data for *codes*."""
        self._check_login()
        start_int = self._format_date(start_date)
        end_int = self._format_date(end_date)
        return await self._run_blocking(
            lambda: self._market_data.query_snapshot(
                code_list=codes,
                begin_date=start_int,
                end_date=end_int,
            )
        )

    async def get_balance_sheet(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch balance sheets via ``InfoData.get_balance_sheet``."""
        return await self._get_financial_data(
            'get_balance_sheet', codes, start_date, end_date, is_local
        )

    async def get_cash_flow(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch cash-flow statements via ``InfoData.get_cash_flow``."""
        return await self._get_financial_data(
            'get_cash_flow', codes, start_date, end_date, is_local
        )

    async def get_income_statement(
        self,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch income statements via ``InfoData.get_income``."""
        return await self._get_financial_data(
            'get_income', codes, start_date, end_date, is_local
        )

    async def _get_financial_data(
        self,
        method: str,
        codes: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        is_local: Optional[bool] = None
    ) -> Dict[str, pd.DataFrame]:
        """Call the ``InfoData`` method named *method* with the shared financial-query arguments.

        Omitted dates are passed as ``None``. ``is_local`` defaults to
        ``config.use_local_cache``.
        """
        self._check_login()
        is_local = self._resolve_is_local(is_local)
        begin_date = self._format_date(start_date) if start_date else None
        end_date_int = self._format_date(end_date) if end_date else None

        def fetch():
            fn = getattr(self._info_data, method)
            return fn(
                code_list=codes,
                local_path=self.config.local_path,
                is_local=is_local,
                begin_date=begin_date,
                end_date=end_date_int,
            )

        return await self._run_blocking(fetch)