You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

319 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import math
import random
import re
import time
from datetime import datetime
from typing import List, Optional, Dict

import akshare as ak
import pandas as pd
import requests

from app.services.datasource.base import DataSourceBase
logger = logging.getLogger(__name__)
class SmartRequester:
    """Anti-scraping request helper.

    Combines browser-like header spoofing, human-like random delays and a
    retry loop with exponential back-off. IP proxy support is intentionally
    left out for now and may be added later.
    """

    # Timeout (seconds) applied when the caller does not pass ``timeout``.
    DEFAULT_TIMEOUT = 10

    def __init__(self, max_retries: int = 3):
        """Create a requester that tries each request up to ``max_retries`` times."""
        self.max_retries = max_retries
        self.session = requests.Session()
        # Pool of User-Agent strings; one is picked at random per attempt.
        self.user_agents = [
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36",
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0",
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.42",
        ]

    def get_random_headers(self, referer: str = "https://finance.sina.com.cn/") -> Dict[str, str]:
        """Return browser-like request headers with a randomly chosen User-Agent."""
        return {
            "User-Agent": random.choice(self.user_agents),
            "Referer": referer,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
            "Accept-Encoding": "gzip, deflate, br",
            "Connection": "keep-alive",
            "Upgrade-Insecure-Requests": "1",
        }

    def _prepare_kwargs(self, referer: str, kwargs: dict) -> dict:
        """Build the keyword arguments for one attempt without touching the caller's data.

        The caller's ``headers`` dict is copied before the random headers are
        merged in (random headers win on conflicts, as before), and the default
        timeout is only applied when the caller did not pass ``timeout``.
        """
        prepared = dict(kwargs)
        merged_headers = dict(prepared.get("headers") or {})
        merged_headers.update(self.get_random_headers(referer=referer))
        prepared["headers"] = merged_headers
        # Passing timeout explicitly next to **kwargs used to raise TypeError
        # whenever the caller supplied their own timeout.
        prepared.setdefault("timeout", self.DEFAULT_TIMEOUT)
        return prepared

    def request(self, url, referer="https://finance.sina.com.cn/", method="GET", **kwargs) -> "Optional[requests.Response]":
        """Send a request with random delays, rotating headers and retries.

        Args:
            url: Target URL.
            referer: Value for the ``Referer`` header.
            method: HTTP method name.
            **kwargs: Extra arguments forwarded to ``Session.request``; a
                caller-supplied ``timeout`` overrides the default.

        Returns:
            The response when a 200 status is received, otherwise ``None``
            after all attempts have failed. Exceptions are logged, not raised.
        """
        for attempt in range(self.max_retries):
            try:
                # 1. Human-like delay: short random pause on the first attempt,
                #    exponential back-off plus jitter on retries.
                if attempt == 0:
                    time.sleep(random.uniform(0.5, 1.5))
                else:
                    delay = (2 ** attempt) + random.uniform(1, 3)
                    logger.warning(f"{attempt+1} 次重试,等待 {delay:.1f} 秒...")
                    time.sleep(delay)

                # 2. Rotate headers on every attempt.
                request_kwargs = self._prepare_kwargs(referer, kwargs)

                # 3. Send the request.
                logger.debug(f"请求: {method} {url}")
                response = self.session.request(method, url, **request_kwargs)

                # 4. Only a 200 counts as success.
                if response.status_code == 200:
                    logger.info("✅ 请求成功")
                    return response
                if response.status_code == 403:
                    logger.warning("⚠️ 收到 403 Forbidden将尝试重试")
                    raise requests.exceptions.HTTPError("403 Forbidden")
                response.raise_for_status()
                # raise_for_status() is silent for non-error codes such as 204;
                # raise explicitly so the retry is logged instead of silent.
                raise requests.exceptions.HTTPError(
                    f"Unexpected status code {response.status_code}"
                )
            except Exception as e:
                logger.warning(f"❌ 请求失败: {str(e)}")

        logger.error(f"🚫 所有 {self.max_retries} 次尝试均失败")
        return None
class AkshareSource(DataSourceBase):
    """AKShare data source adapter: futures contract lists and K-line data."""

    # Exchanges queried when the caller does not name one.
    _DEFAULT_EXCHANGES = ("CZCE", "DCE", "SHFE", "INE", "CFFEX", "GFEX")

    # Canonical field -> candidate source column names, in priority order.
    # Only the first alias found is renamed, so two aliases can never both
    # become the same column.
    _CONTRACT_COLUMN_ALIASES = (
        ("symbol", ("合约代码",)),
        ("product", ("产品代码", "品种")),
        ("multiplier", ("交易单位",)),
        ("price_tick", ("最小变动价位",)),
        ("list_date", ("上市日",)),
        ("expire_date", ("到期日",)),
    )

    # Fallback contract multiplier when the source value is missing or unparsable.
    _DEFAULT_MULTIPLIER = 10

    _DATE_FORMATS = ("%Y-%m-%d", "%Y%m%d", "%Y/%m/%d", "%Y-%m-%d %H:%M:%S")

    # Caller-facing period -> value passed as ``period`` to the AKShare minute API.
    _PERIOD_MAP = {"1m": "1", "5m": "5", "15m": "15", "30m": "30", "60m": "60"}

    _DAILY_COLUMNS = [
        "trade_date", "open", "high", "low", "close", "volume",
        "turnover", "open_interest", "settle", "pre_settle",
    ]
    _WEEKLY_COLUMNS = [
        "trade_date", "open", "high", "low", "close", "volume",
        "turnover", "open_interest",
    ]
    _INTRADAY_COLUMNS = [
        "trade_time", "open", "high", "low", "close", "volume",
        "turnover", "open_interest",
    ]

    _PRODUCT_RE = re.compile(r"([a-zA-Z]+)")
    _NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")

    def __init__(self, config: dict):
        """Store the config and create the shared ``SmartRequester``."""
        super().__init__(config)
        self.name = "akshare"
        self.requester = SmartRequester(max_retries=config.get("max_retries", 3))

    def initialize(self) -> bool:
        """Check that AKShare is usable and record the result in ``_initialized``."""
        try:
            getattr(ak, "__version__")
            self._initialized = True
            return True
        except Exception as e:
            self._initialized = False
            logger.error(f"AkshareSource 初始化失败: {e}")
            return False

    def _ensure_initialized(self) -> None:
        """Run ``initialize()`` once if it has not succeeded yet."""
        if not getattr(self, "_initialized", False):
            self.initialize()

    # ------------------------------------------------------------------
    # Scalar parsing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for ``None`` and pandas/NumPy missing scalars (NaN, NaT)."""
        if value is None:
            return True
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # Non-scalar input: pd.isna returns an array whose truth is ambiguous.
            return False

    @classmethod
    def _extract_number(cls, raw) -> Optional[float]:
        """Parse a finite number from ``raw``, tolerating thousands separators
        and surrounding text such as a unit suffix; return None on failure."""
        if cls._is_missing(raw):
            return None
        text = str(raw).replace(",", "").strip()
        try:
            value = float(text)
        except ValueError:
            found = cls._NUMBER_RE.search(text)
            if found is None:
                return None
            value = float(found.group())
        return value if math.isfinite(value) else None

    def _parse_multiplier(self, raw) -> int:
        """Return the contract multiplier as int, or the default when unusable."""
        if self._is_missing(raw) or not raw:
            return self._DEFAULT_MULTIPLIER
        number = self._extract_number(raw)
        if number is None:
            return self._DEFAULT_MULTIPLIER
        return int(number)

    def _parse_date(self, date_str) -> Optional[datetime]:
        """Convert a date-like value to ``datetime``; return None if not parsable.

        Accepts ``pd.Timestamp``, ``datetime``, date-like objects exposing
        ``year``/``month``/``day``, and strings in any of ``_DATE_FORMATS``.
        """
        if self._is_missing(date_str) or not date_str:
            return None
        if isinstance(date_str, pd.Timestamp):
            return date_str.to_pydatetime()
        if isinstance(date_str, datetime):
            return date_str
        if isinstance(date_str, str):
            text = date_str.strip()
            for fmt in self._DATE_FORMATS:
                try:
                    return datetime.strptime(text, fmt)
                except ValueError:
                    continue
            return None
        # Duck-typed support for plain date objects.
        year = getattr(date_str, "year", None)
        month = getattr(date_str, "month", None)
        day = getattr(date_str, "day", None)
        if None in (year, month, day):
            return None
        try:
            return datetime(year, month, day)
        except (TypeError, ValueError):
            return None

    # ------------------------------------------------------------------
    # Contract list
    # ------------------------------------------------------------------
    @classmethod
    def _normalize_contract_columns(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Rename source columns to canonical names, one source per target."""
        rename_map = {}
        for target, aliases in cls._CONTRACT_COLUMN_ALIASES:
            if target in df.columns:
                continue
            source = next((alias for alias in aliases if alias in df.columns), None)
            if source is not None:
                rename_map[source] = target
        return df.rename(columns=rename_map)

    def _build_contract(self, row, exchange: str) -> Optional[dict]:
        """Turn one normalized row into a contract dict; None if it has no symbol."""
        raw_symbol = row.get("symbol")
        if self._is_missing(raw_symbol):
            return None
        symbol = str(raw_symbol).strip()
        if not symbol:
            return None

        product = row.get("product", "")
        if self._is_missing(product) or not product:
            product = ""
            # Fall back to the leading letters of the symbol (e.g. rb2401 -> rb).
            if len(symbol) > 2:
                found = self._PRODUCT_RE.match(symbol)
                if found:
                    product = found.group(1).lower()

        return {
            "symbol": symbol,
            "exchange": exchange,
            "name": symbol,  # no display name is read from the source; reuse the code
            "product": product,
            "multiplier": self._parse_multiplier(row.get("multiplier")),
            "price_tick": self._extract_number(row.get("price_tick")),
            "expire_date": self._parse_date(row.get("expire_date")),
            "is_active": True,
        }

    def get_contract_list(self, exchange: Optional[str] = None) -> List[dict]:
        """Fetch futures contracts for one exchange, or for all default exchanges.

        Looks up ``ak.futures_contract_info_<exchange>`` for each exchange and
        skips exchanges for which AKShare has no such function. A failure on
        one exchange is logged and does not affect the others.

        Returns:
            A list of dicts with keys ``symbol``, ``exchange``, ``name``,
            ``product``, ``multiplier``, ``price_tick``, ``expire_date`` and
            ``is_active``. Never raises; returns what was collected.
        """
        self._ensure_initialized()
        exchanges = [exchange] if exchange else list(self._DEFAULT_EXCHANGES)
        results: List[dict] = []
        for ex in exchanges:
            try:
                fetch = getattr(ak, f"futures_contract_info_{ex.lower()}", None)
                if fetch is None:
                    continue
                df = fetch()
                if df is None or df.empty:
                    continue
                df = self._normalize_contract_columns(df)
                for _, row in df.iterrows():
                    contract = self._build_contract(row, ex)
                    if contract is not None:
                        results.append(contract)
            except Exception as e:
                logger.warning(f"获取 {ex} 合约列表失败: {e}")
        return results

    # ------------------------------------------------------------------
    # K-line helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _filter_by_range(df: pd.DataFrame, column: str, start_date, end_date) -> pd.DataFrame:
        """Return a copy of the rows whose ``column`` lies in [start_date, end_date].

        When ``end_date`` carries no time of day it is treated as covering the
        whole day, so intraday rows on the end date are kept.
        """
        times = pd.to_datetime(df[column])
        start_ts = pd.to_datetime(start_date)
        end_ts = pd.to_datetime(end_date)
        if pd.notna(end_ts) and end_ts == end_ts.normalize():
            upper = times < end_ts + pd.Timedelta(days=1)
        else:
            upper = times <= end_ts
        mask = (times >= start_ts) & upper
        selected = df.loc[mask].copy()
        selected[column] = times[mask]
        return selected

    @staticmethod
    def _fill_optional(df: pd.DataFrame, columns) -> pd.DataFrame:
        """Add each missing optional column as None so column selection cannot fail."""
        for col in columns:
            if col not in df.columns:
                df[col] = None
        return df

    def get_kline_daily(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch daily K-lines for ``symbol`` between ``start_date`` and ``end_date`` (inclusive).

        Returns:
            A DataFrame with the columns in ``_DAILY_COLUMNS``; ``turnover``
            and ``pre_settle`` are always None. An empty DataFrame is returned
            when no data is available or any error occurs (the error is logged).
        """
        logger.info(f"[AKShare-日K线] 开始获取 symbol={symbol}, start_date={start_date}, end_date={end_date}")
        if not getattr(self, "_initialized", False):
            logger.info("[AKShare-日K线] 初始化 AKShare")
            self.initialize()
        try:
            logger.info(f"[AKShare-日K线] 调用 ak.futures_zh_daily_sina(symbol='{symbol}')")
            df = ak.futures_zh_daily_sina(symbol=symbol)
            if df is None or df.empty:
                logger.warning(f"[AKShare-日K线] AKShare 返回空数据symbol={symbol}")
                return pd.DataFrame()
            logger.info(f"[AKShare-日K线] AKShare 返回 {len(df)} 条原始数据")
            logger.debug(f"[AKShare-日K线] 原始数据列: {df.columns.tolist()}")
            logger.debug(f"[AKShare-日K线] 原始数据样例:\n{df.head()}")

            df["date"] = pd.to_datetime(df["date"])
            logger.debug(f"[AKShare-日K线] 日期范围: {df['date'].min()} ~ {df['date'].max()}")
            df_filtered = self._filter_by_range(df, "date", start_date, end_date)
            logger.info(f"[AKShare-日K线] 日期过滤后剩余 {len(df_filtered)} 条记录")

            df_filtered = df_filtered.rename(columns={"date": "trade_date", "hold": "open_interest"})
            # Open interest and settle are optional: fill with None instead of
            # failing the whole request when the source omits them.
            df_filtered = self._fill_optional(df_filtered, ("open_interest", "settle"))
            # Not read from this source; always left empty.
            df_filtered["turnover"] = None
            df_filtered["pre_settle"] = None

            logger.info(f"[AKShare-日K线] 最终返回 {len(df_filtered)} 条记录")
            return df_filtered[self._DAILY_COLUMNS]
        except Exception as e:
            logger.error(f"[AKShare-日K线] 获取 {symbol} 日 K 线失败: {e}", exc_info=True)
            return pd.DataFrame()

    def get_kline_weekly(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Build weekly K-lines by aggregating daily bars into weeks ending on Friday.

        Weeks without any price data are dropped. ``turnover`` is the weekly
        sum when daily turnover exists and None otherwise, so "unknown" is not
        reported as zero.
        """
        daily_df = self.get_kline_daily(symbol, start_date, end_date)
        if daily_df.empty:
            return pd.DataFrame()

        indexed = daily_df.set_index("trade_date")
        weekly = indexed.resample("W-FRI").agg({
            "open": "first",
            "high": "max",
            "low": "min",
            "close": "last",
            "volume": "sum",
            "open_interest": "last",
        })
        # Drop empty (non-trading) weeks; a missing open interest alone must
        # not discard a week that has prices.
        weekly = weekly.dropna(subset=["open", "high", "low", "close"])

        turnover = pd.to_numeric(indexed["turnover"], errors="coerce")
        if turnover.notna().any():
            weekly_turnover = turnover.resample("W-FRI").sum(min_count=1)
            weekly["turnover"] = weekly_turnover.reindex(weekly.index)
        else:
            weekly["turnover"] = None

        weekly = weekly.reset_index()
        return weekly[self._WEEKLY_COLUMNS]

    def get_kline_intraday(self, symbol: str, period: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch minute-level K-lines for ``symbol``.

        Args:
            symbol: Contract code passed to AKShare.
            period: One of the keys of ``_PERIOD_MAP``; unknown values fall
                back to 5 minutes and a warning is logged.
            start_date: Inclusive lower bound.
            end_date: Inclusive upper bound; a date without a time of day
                covers that whole day.

        Returns:
            A DataFrame with the columns in ``_INTRADAY_COLUMNS`` (``turnover``
            is always None), or an empty DataFrame on no data / error.
        """
        self._ensure_initialized()
        try:
            freq = self._PERIOD_MAP.get(period)
            if freq is None:
                logger.warning(f"不支持的分钟周期 {period},回退为 5 分钟")
                freq = "5"
            df = ak.futures_zh_minute_sina(symbol=symbol, period=freq)
            if df is None or df.empty:
                return pd.DataFrame()

            df = self._filter_by_range(df, "datetime", start_date, end_date)
            df = df.rename(columns={"datetime": "trade_time", "hold": "open_interest"})
            df = self._fill_optional(df, ("open_interest",))
            df["turnover"] = None
            return df[self._INTRADAY_COLUMNS]
        except Exception as e:
            logger.error(f"获取 {symbol} 分钟 K 线失败: {e}", exc_info=True)
            return pd.DataFrame()