|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
===================================
|
|
|
|
|
|
PytdxFetcher - 通达信数据源 (Priority 2)
|
|
|
|
|
|
===================================
|
|
|
|
|
|
|
|
|
|
|
|
数据来源:通达信行情服务器(pytdx 库)
|
|
|
|
|
|
特点:免费、无需 Token、直连行情服务器
|
|
|
|
|
|
优点:实时数据、稳定、无配额限制
|
|
|
|
|
|
|
|
|
|
|
|
关键策略:
|
|
|
|
|
|
1. 多服务器自动切换
|
|
|
|
|
|
2. 连接超时自动重连
|
|
|
|
|
|
3. 失败后指数退避重试
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import re
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from typing import Optional, Generator, List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
from tenacity import (
|
|
|
|
|
|
retry,
|
|
|
|
|
|
stop_after_attempt,
|
|
|
|
|
|
wait_exponential,
|
|
|
|
|
|
retry_if_exception_type,
|
|
|
|
|
|
before_sleep_log,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
from .base import BaseFetcher, DataFetchError, STANDARD_COLUMNS
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_us_code(stock_code: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
判断代码是否为美股
|
|
|
|
|
|
|
|
|
|
|
|
美股代码规则:
|
|
|
|
|
|
- 1-5个大写字母,如 'AAPL', 'TSLA'
|
|
|
|
|
|
- 可能包含 '.',如 'BRK.B'
|
|
|
|
|
|
"""
|
|
|
|
|
|
code = stock_code.strip().upper()
|
|
|
|
|
|
return bool(re.match(r'^[A-Z]{1,5}(\.[A-Z])?$', code))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PytdxFetcher(BaseFetcher):
    """
    Data source implementation backed by TongDaXin (TDX) quote servers via ``pytdx``.

    Priority: 2 by default (same tier as Tushare); overridable through the
    ``PYTDX_PRIORITY`` environment variable.

    Strategy:
    - Try the configured servers in turn, starting from the last one that worked
    - Move on to the next server when a connection attempt fails
    - Open a fresh connection per operation and always disconnect afterwards

    Pytdx characteristics (as stated by the original author):
    - Free, no registration required
    - Connects directly to the quote servers
    - Supports real-time quotes and historical data
    - Supports stock name lookup
    """

    name = "PytdxFetcher"
    priority = int(os.getenv("PYTDX_PRIORITY", "2"))

    # Default list of TDX quote servers
    DEFAULT_HOSTS = [
        ("119.147.212.81", 7709),  # Shenzhen
        ("112.74.214.43", 7727),   # Shenzhen
        ("221.231.141.60", 7709),  # Shanghai
        ("101.227.73.20", 7709),   # Shanghai
        ("101.227.77.254", 7709),  # Shanghai
        ("14.215.128.18", 7709),   # Guangzhou
        ("59.173.18.140", 7709),   # Wuhan
        ("180.153.39.51", 7709),   # Hangzhou
    ]

    # Pytdx market identifiers
    _MARKET_SZ = 0
    _MARKET_SH = 1

    # Number of bars requested per get_security_bars() call (the original
    # implementation capped a request at 800 bars).
    _MAX_BARS_PER_REQUEST = 800
    # Upper bound on the bars downloaded for one query; bounds the paging loop.
    _MAX_TOTAL_BARS = 4000
    # Safety bound on the number of get_security_list() pages read per market.
    _SECURITY_LIST_MAX_PAGES = 100

    def __init__(self, hosts: Optional[List[Tuple[str, int]]] = None):
        """
        Initialize the fetcher.

        Args:
            hosts: Server list ``[(host, port), ...]``; the built-in
                ``DEFAULT_HOSTS`` is used when omitted or empty.
        """
        self._hosts = hosts or self.DEFAULT_HOSTS
        self._api = None
        self._connected = False
        self._current_host_idx = 0
        # Security list cache, created lazily as {market: {code: name}}.
        # Keyed by market because the same code exists on both exchanges.
        self._stock_list_cache = None
        # Resolved names keyed by the code exactly as the caller passed it.
        self._stock_name_cache = {}

    def _get_pytdx(self):
        """
        Lazily import pytdx.

        The import happens on use so that the module can be loaded even when
        pytdx is not installed.

        Returns:
            The ``TdxHq_API`` class, or None when pytdx cannot be imported.
        """
        try:
            from pytdx.hq import TdxHq_API
            return TdxHq_API
        except ImportError:
            logger.warning("pytdx 未安装,请运行: pip install pytdx")
            return None

    @contextmanager
    def _pytdx_session(self) -> Generator:
        """
        Context manager yielding a connected pytdx API object.

        Guarantees:
        1. A connection is established on entry (servers are tried in turn,
           starting with the one that worked last time)
        2. The connection is closed on exit
        3. The connection is closed even when the body raises

        Usage:
            with self._pytdx_session() as api:
                # run queries here

        Raises:
            DataFetchError: pytdx is not installed, or no server accepted
                the connection.
        """
        TdxHq_API = self._get_pytdx()
        if TdxHq_API is None:
            raise DataFetchError("pytdx 库未安装")

        api = TdxHq_API()
        connected = False
        host_count = len(self._hosts)

        try:
            for offset in range(host_count):
                host_idx = (self._current_host_idx + offset) % host_count
                host, port = self._hosts[host_idx]

                try:
                    if api.connect(host, port, time_out=5):
                        connected = True
                        # Remember the working server so the next session starts there.
                        self._current_host_idx = host_idx
                        logger.debug(f"Pytdx 连接成功: {host}:{port}")
                        break
                    # connect() reported failure without raising; try the next server.
                    logger.debug(f"Pytdx 连接 {host}:{port} 失败: 服务器未响应")
                except Exception as e:
                    logger.debug(f"Pytdx 连接 {host}:{port} 失败: {e}")
                    continue

            if not connected:
                raise DataFetchError("Pytdx 无法连接任何服务器")

            yield api

        finally:
            # Always attempt to disconnect; a failure here must not mask the
            # outcome of the body, so it is only logged.
            try:
                api.disconnect()
                logger.debug("Pytdx 连接已断开")
            except Exception as e:
                logger.warning(f"Pytdx 断开连接时出错: {e}")

    def _get_market_code(self, stock_code: str) -> Tuple[int, str]:
        """
        Resolve the pytdx market id and the bare code for ``stock_code``.

        Pytdx market ids:
        - 0: Shenzhen
        - 1: Shanghai

        An explicit exchange marker wins: a ``.SH`` / ``.SZ`` suffix or an
        ``sh`` / ``sz`` prefix (any letter case) selects the market directly
        and is stripped from the returned code. Without a marker the market
        is inferred from the code prefix: ``60`` / ``68`` -> Shanghai,
        anything else -> Shenzhen.

        Args:
            stock_code: Stock code, optionally carrying an exchange marker.

        Returns:
            ``(market, code)`` tuple.
        """
        code = stock_code.strip()
        upper = code.upper()

        # Explicit markers are authoritative. Previously they were stripped and
        # then ignored, so e.g. 'sh000001' was routed to Shenzhen.
        if upper.endswith(('.SH', '.SZ')):
            market = self._MARKET_SH if upper.endswith('.SH') else self._MARKET_SZ
            return market, code[:-3]
        if upper.startswith(('SH', 'SZ')):
            market = self._MARKET_SH if upper.startswith('SH') else self._MARKET_SZ
            # Also tolerate the 'sh.600000' spelling.
            return market, code[2:].lstrip('.')

        # Infer from the code prefix.
        # Shanghai: 60xxxx, 68xxxx (STAR market)
        # Shenzhen: everything else (00xxxx, 30xxxx ChiNext, 002xxx SME board, ...)
        if code.startswith(('60', '68')):
            return self._MARKET_SH, code
        return self._MARKET_SZ, code

    def _download_daily_bars(self, api, market: int, code: str, count: int) -> list:
        """
        Download up to ``count`` of the most recent daily bars, paging as needed.

        ``start=0`` addresses the newest bar, so every further page
        (larger ``start``) is older and is placed in front of what has been
        collected so far, keeping the result in chronological order.

        Returns:
            List of bar records, oldest first; empty when nothing was returned.
        """
        bars = []
        offset = 0
        while offset < count:
            page_size = min(self._MAX_BARS_PER_REQUEST, count - offset)
            page = api.get_security_bars(
                category=9,  # daily bars (0: 5min, 1: 15min, 2: 30min, 3: 1h)
                market=market,
                code=code,
                start=offset,  # offset counted back from the newest bar
                count=page_size,
            )
            if not page:
                break
            bars = list(page) + bars
            if len(page) < page_size:
                # Short page: the history is exhausted.
                break
            offset += len(page)
        return bars

    # NOTE(review): every failure inside this method is surfaced as
    # DataFetchError, so the ConnectionError/TimeoutError retry predicate below
    # appears to trigger rarely, if ever - confirm whether retries are expected.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=30),
        retry=retry_if_exception_type((ConnectionError, TimeoutError)),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _fetch_raw_data(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Fetch raw daily bars from TDX via ``get_security_bars()``.

        Steps:
        1. Reject US tickers (not supported by this source)
        2. Resolve the market id
        3. Download enough recent bars to reach back to ``start_date``
        4. Compute ``pct_chg`` on the full series, then keep only the rows
           whose trading day lies in ``[start_date, end_date]``

        Args:
            stock_code: Stock code.
            start_date: Inclusive start day, ``YYYY-MM-DD``.
            end_date: Inclusive end day, ``YYYY-MM-DD``.

        Returns:
            DataFrame with the columns produced by pytdx plus ``pct_chg``
            (when a ``close`` column is present).

        Raises:
            DataFetchError: Unsupported code, no data, or any query failure.
        """
        # US tickers are unsupported; raising lets DataFetcherManager fall back
        # to another data source.
        if _is_us_code(stock_code):
            raise DataFetchError(f"PytdxFetcher 不支持美股 {stock_code},请使用 AkshareFetcher 或 YfinanceFetcher")

        market, code = self._get_market_code(stock_code)

        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
        end_dt = datetime.strptime(end_date, '%Y-%m-%d')

        # Bars are addressed from the newest one backwards, so the download must
        # span from today (not from end_date) back to start_date; otherwise a
        # window lying in the past would never be reached.
        newest_dt = max(end_dt, datetime.now())
        days = (newest_dt - start_dt).days
        # Rough calendar-day -> trading-day estimate with some slack.
        count = min(max(days * 5 // 7 + 10, 30), self._MAX_TOTAL_BARS)

        logger.debug(f"调用 Pytdx get_security_bars(market={market}, code={code}, count={count})")

        with self._pytdx_session() as api:
            try:
                data = self._download_daily_bars(api, market, code, count)

                if not data:
                    raise DataFetchError(f"Pytdx 未查询到 {stock_code} 的数据")

                df = api.to_df(data)
                df['datetime'] = pd.to_datetime(df['datetime'])
                # Defensive: percentage change below relies on chronological order.
                df = df.sort_values('datetime', kind='stable')

                # Compute the change before filtering so that the first row of
                # the requested window is measured against the previous bar
                # instead of being reported as 0.
                if 'pct_chg' not in df.columns and 'close' in df.columns:
                    df['pct_chg'] = df['close'].pct_change() * 100
                    df['pct_chg'] = df['pct_chg'].fillna(0).round(2)

                # Compare calendar days, not timestamps: a bar stamped later
                # than midnight on end_date must still be included.
                trade_days = df['datetime'].dt.normalize()
                df = df[(trade_days >= start_dt) & (trade_days <= end_dt)]

                return df

            except DataFetchError:
                raise
            except Exception as e:
                raise DataFetchError(f"Pytdx 获取数据失败: {e}") from e

    def _normalize_data(self, df: pd.DataFrame, stock_code: str) -> pd.DataFrame:
        """
        Map a pytdx frame onto the project's standard columns.

        Pytdx column names:
            datetime, open, high, low, close, vol, amount

        Standard column names:
            date, open, high, low, close, volume, amount, pct_chg

        Args:
            df: Raw frame as returned by ``_fetch_raw_data``.
            stock_code: Code written into the ``code`` column.

        Returns:
            A new frame holding ``code`` plus those ``STANDARD_COLUMNS`` that
            are present; the input frame is not modified.
        """
        df = df.copy()

        df = df.rename(columns={
            'datetime': 'date',
            'vol': 'volume',
        })

        # Fallback for frames that arrive without pct_chg.
        if 'pct_chg' not in df.columns and 'close' in df.columns:
            df['pct_chg'] = df['close'].pct_change() * 100
            df['pct_chg'] = df['pct_chg'].fillna(0).round(2)

        df['code'] = stock_code

        # Keep only the expected columns, in the standard order.
        keep_cols = ['code'] + STANDARD_COLUMNS
        existing_cols = [col for col in keep_cols if col in df.columns]
        return df[existing_cols]

    def _load_security_names(self, api, market: int) -> dict:
        """
        Read the security list of one market page by page.

        Paging stops at the first empty (or missing) page, or after
        ``_SECURITY_LIST_MAX_PAGES`` pages.

        Returns:
            ``{code: name}`` for every entry that was read; may be empty.
        """
        names = {}
        start = 0
        for _ in range(self._SECURITY_LIST_MAX_PAGES):
            page = api.get_security_list(market, start)
            if not page:
                break
            for stock in page:
                names[stock['code']] = stock['name']
            start += len(page)
        return names

    def get_stock_name(self, stock_code: str) -> Optional[str]:
        """
        Look up the display name of a stock.

        Lookup order: resolved-name cache, cached security list of the stock's
        market, freshly downloaded security list, ``get_finance_info``.

        Args:
            stock_code: Stock code.

        Returns:
            The name, or None when it cannot be determined (failures are
            logged, never raised).
        """
        if stock_code in self._stock_name_cache:
            return self._stock_name_cache[stock_code]

        # Consistent with _fetch_raw_data: US tickers are not served by TDX,
        # so do not spend a connection on them.
        if _is_us_code(stock_code):
            logger.debug(f"PytdxFetcher 不支持美股 {stock_code}")
            return None

        try:
            market, code = self._get_market_code(stock_code)

            if self._stock_list_cache is None:
                self._stock_list_cache = {}

            # Answer from the cached list without opening a connection.
            market_names = self._stock_list_cache.get(market)
            if market_names is not None:
                name = market_names.get(code)
                if name:
                    self._stock_name_cache[stock_code] = name
                    return name

            with self._pytdx_session() as api:
                if market_names is None:
                    market_names = self._load_security_names(api, market)
                    # Cache only a non-empty result so that a failed download
                    # is retried on the next call instead of sticking forever.
                    if market_names:
                        self._stock_list_cache[market] = market_names

                    name = market_names.get(code)
                    if name:
                        self._stock_name_cache[stock_code] = name
                        return name

                # Last resort: get_finance_info, in case it carries a name.
                finance_info = api.get_finance_info(market, code)
                if finance_info and 'name' in finance_info:
                    name = finance_info['name']
                    self._stock_name_cache[stock_code] = name
                    return name

        except Exception as e:
            logger.warning(f"Pytdx 获取股票名称失败 {stock_code}: {e}")

        return None

    def get_realtime_quote(self, stock_code: str) -> Optional[dict]:
        """
        Fetch a real-time quote snapshot.

        Args:
            stock_code: Stock code.

        Returns:
            Dict with ``code``, ``name``, ``price``, ``open``, ``high``,
            ``low``, ``pre_close``, ``volume``, ``amount``, ``bid_prices`` and
            ``ask_prices`` (five levels each; missing fields default to 0 /
            empty string), or None on failure (failures are logged, never
            raised).
        """
        # Consistent with _fetch_raw_data: US tickers are not served by TDX.
        if _is_us_code(stock_code):
            logger.debug(f"PytdxFetcher 不支持美股 {stock_code}")
            return None

        try:
            market, code = self._get_market_code(stock_code)

            with self._pytdx_session() as api:
                data = api.get_security_quotes([(market, code)])

                if data:
                    quote = data[0]
                    return {
                        'code': stock_code,
                        'name': quote.get('name', ''),
                        'price': quote.get('price', 0),
                        'open': quote.get('open', 0),
                        'high': quote.get('high', 0),
                        'low': quote.get('low', 0),
                        'pre_close': quote.get('last_close', 0),
                        'volume': quote.get('vol', 0),
                        'amount': quote.get('amount', 0),
                        'bid_prices': [quote.get(f'bid{i}', 0) for i in range(1, 6)],
                        'ask_prices': [quote.get(f'ask{i}', 0) for i in range(1, 6)],
                    }
        except Exception as e:
            logger.warning(f"Pytdx 获取实时行情失败 {stock_code}: {e}")

        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: daily history, name lookup and real-time quote
    # for a single code, with debug logging enabled.
    logging.basicConfig(level=logging.DEBUG)

    demo_fetcher = PytdxFetcher()

    try:
        # Historical daily data (600519: Kweichow Moutai)
        history = demo_fetcher.get_daily_data('600519')
        print(f"获取成功,共 {len(history)} 条数据")
        print(history.tail())

        # Stock name lookup
        stock_name = demo_fetcher.get_stock_name('600519')
        print(f"股票名称: {stock_name}")

        # Real-time quote
        snapshot = demo_fetcher.get_realtime_quote('600519')
        print(f"实时行情: {snapshot}")

    except Exception as e:
        print(f"获取失败: {e}")
|