diff --git a/app/services/collector.py b/app/services/collector.py index 469f786..55bbacf 100644 --- a/app/services/collector.py +++ b/app/services/collector.py @@ -10,11 +10,11 @@ from typing import Dict, List, Optional logger = logging.getLogger(__name__) -# 获取原始采集脚本路径 (buffer_platform/app/services -> buffer_platform -> parent = market_data_colector_platform) -SCRIPT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) -if SCRIPT_DIR not in sys.path: - sys.path.insert(0, SCRIPT_DIR) - logger.info(f"已添加采集脚本路径到sys.path: {SCRIPT_DIR}") +# 获取项目根目录 (buffer_platform/app/services -> buffer_platform) +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + logger.info(f"已添加项目根目录到sys.path: {PROJECT_ROOT}") def fetch_symbol_data( @@ -25,18 +25,6 @@ def fetch_symbol_data( ) -> Dict: """ 获取单个品种的多周期数据。 - - 返回格式: - { - "symbol": "SN2504", - "type": "futures", - "current_price": 12345.0, - "timestamp": "2025-01-15T10:30:00+08:00", - "timeframes": { - "5min": [{"datetime": ..., "open": ..., ...}, ...], - ... - } - } """ try: from futures_data_collector import collect_futures_data, collect_stock_data diff --git a/futures_data_collector.py b/futures_data_collector.py new file mode 100644 index 0000000..95f1ebe --- /dev/null +++ b/futures_data_collector.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +期货/股票多周期数据获取与技术指标计算脚本 +""" + +import akshare as ak +import pandas as pd +import json +import argparse +import os +from datetime import datetime, timedelta +from typing import Dict, List +import warnings +warnings.filterwarnings('ignore') +ak.cache = {} + +DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data') +os.makedirs(DATA_DIR, exist_ok=True) + +def calculate_ma(df: pd.DataFrame, periods: List[int] = [10, 20]) -> pd.DataFrame: + """计算移动平均线""" + for period in periods: + df[f'MA{period}'] = df['close'].rolling(window=period, min_periods=1).mean() + return df + + +def calculate_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.DataFrame: + """计算MACD指标""" + ema_fast = df['close'].ewm(span=fast, adjust=False).mean() + ema_slow = df['close'].ewm(span=slow, adjust=False).mean() + df['macd_dif'] = ema_fast - ema_slow + df['macd_dea'] = df['macd_dif'].ewm(span=signal, adjust=False).mean() + df['macd_histogram'] = (df['macd_dif'] - df['macd_dea']) * 2 + + df['macd_signal'] = df.apply(lambda row: + 'bullish' if row['macd_dif'] > row['macd_dea'] and row['macd_histogram'] > 0 + else 'bearish' if row['macd_dif'] < row['macd_dea'] and row['macd_histogram'] < 0 + else 'neutral', axis=1) + + return df + + +def get_current_time() -> datetime: + """获取当前北京时间(去除微秒)""" + return datetime.now().replace(microsecond=0) + + +def filter_future_data(df: pd.DataFrame, current_time: datetime = None) -> pd.DataFrame: + """过滤掉未来数据""" + if current_time is None: + current_time = get_current_time() + + if 'datetime' not in df.columns: + return df + + df['datetime'] = pd.to_datetime(df['datetime']) + original_count = len(df) + df = df[df['datetime'] <= current_time].copy() + filtered_count = original_count - len(df) + + if filtered_count > 0: + print(f" 过滤了 {filtered_count} 条未来数据") + + return df + + +def extend_night_session_data(df: pd.DataFrame, symbol: str, period: str) -> pd.DataFrame: + """尝试获取完整的夜盘数据""" + if df.empty or 'datetime' not in df.columns: + return df + + df['datetime'] = pd.to_datetime(df['datetime']) + df = df.sort_values('datetime').reset_index(drop=True) + + last_time = df['datetime'].iloc[-1] + last_hour = last_time.hour + last_minute = last_time.minute + + is_night_session = ( + (last_hour >= 21) or + (last_hour < 2) or + (last_hour == 2 and last_minute <= 30) + ) + + if not is_night_session: + return df + + has_0230 = False + for dt in df['datetime']: + if dt.hour == 2 and dt.minute == 30: + has_0230 = True + break + + if has_0230: + return df + + print(f" 注意: 夜盘数据可能不完整(缺少02:30及之前的数据)") + + return df + + +def get_minute_data(symbol: str, period: str) -> pd.DataFrame: + """获取期货分钟K线数据""" + try: + current_time = get_current_time() + df = ak.futures_zh_minute_sina(symbol=symbol, period=period) + + df = df.rename(columns={ + 'day': 'datetime', + 'open': 'open', + 'high': 'high', + 'low': 'low', + 'close': 'close', + 'volume': 'volume' + }) + + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + df['datetime'] = pd.to_datetime(df['datetime']) + df = filter_future_data(df, current_time) + df = extend_night_session_data(df, symbol, period) + + if len(df) < 50: + print(f" 警告: {period}分钟只获取到{len(df)}根K线,建议检查数据源") + + return df + + except Exception as e: + print(f" 获取{period}分钟数据失败: {e}") + return pd.DataFrame() + + +def get_daily_data(symbol: str, days: int = 60) -> pd.DataFrame: + """获取期货日K线数据""" + try: + current_time = get_current_time() + df = ak.futures_zh_daily_sina(symbol=symbol) + + df = df.rename(columns={ + 'date': 'datetime', + 'open': 'open', + 'high': 'high', + 'low': 'low', + 'close': 'close', + 'volume': 'volume' + }) + + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + df['datetime'] = pd.to_datetime(df['datetime']) + df = df.sort_values('datetime').reset_index(drop=True) + df = filter_future_data(df, current_time) + df = df.tail(days).reset_index(drop=True) + + return df + + except Exception as e: + print(f" 获取日K数据失败: {e}") + return pd.DataFrame() + + +def get_stock_minute_data(symbol: str, period: str) -> pd.DataFrame: + """获取股票分钟K线数据""" + try: + current_time = get_current_time() + + if symbol.startswith('6'): + full_symbol = f"sh{symbol}" + else: + full_symbol = f"sz{symbol}" + + df = ak.stock_zh_a_minute(symbol=full_symbol, period=period) + + df = df.rename(columns={ + 'day': 'datetime', + 'open': 'open', + 'high': 'high', + 'low': 'low', + 'close': 'close', + 'volume': 'volume' + }) + + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + df['datetime'] = pd.to_datetime(df['datetime']) + df = filter_future_data(df, current_time) + + if len(df) < 50: + print(f" 警告: {period}分钟只获取到{len(df)}根K线,建议检查数据源") + + return df + + except Exception as e: + print(f" 获取{period}分钟数据失败: {e}") + return pd.DataFrame() + + +def get_stock_daily_data(symbol: str, days: int = 60) -> pd.DataFrame: + """获取股票日K线数据""" + try: + current_time = get_current_time() + end_date = current_time.strftime('%Y%m%d') + start_date = (current_time - timedelta(days=days*2)).strftime('%Y%m%d') + + df = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date) + + df = df.rename(columns={ + '日期': 'datetime', + '开盘': 'open', + '最高': 'high', + '最低': 'low', + '收盘': 'close', + '成交量': 'volume' + }) + + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + df['datetime'] = pd.to_datetime(df['datetime']) + df = df.sort_values('datetime').reset_index(drop=True) + df = filter_future_data(df, current_time) + df = df.tail(days).reset_index(drop=True) + + return df + + except Exception as e: + print(f" 获取日K数据失败: {e}") + return pd.DataFrame() + + +def process_data(df: pd.DataFrame, timeframe: str) -> List[Dict]: + """处理数据,计算指标并格式化输出""" + if df.empty or len(df) < 10: + return [] + + df = calculate_ma(df) + df = calculate_macd(df) + + candles = [] + df_tail = df.tail(50) if len(df) > 50 else df + + for _, row in df_tail.iterrows(): + candle = { + "time": str(row['datetime']), + "open": round(float(row['open']), 2), + "high": round(float(row['high']), 2), + "low": round(float(row['low']), 2), + "close": round(float(row['close']), 2), + "volume": int(row['volume']) if not pd.isna(row['volume']) else 0, + "ma10": round(float(row['MA10']), 2) if not pd.isna(row.get('MA10')) else None, + "ma20": round(float(row['MA20']), 2) if not pd.isna(row.get('MA20')) else None, + "macd_dif": round(float(row['macd_dif']), 4) if not pd.isna(row.get('macd_dif')) else 0, + "macd_dea": round(float(row['macd_dea']), 4) if not pd.isna(row.get('macd_dea')) else 0, + "macd_histogram": round(float(row['macd_histogram']), 4) if not pd.isna(row.get('macd_histogram')) else 0 + } + candles.append(candle) + + return candles + + +def collect_futures_data(symbol: str) -> Dict: + """收集期货多周期完整数据""" + print(f"\n正在获取期货 {symbol} 的多周期数据...") + print(f"当前时间: {get_current_time().strftime('%Y-%m-%d %H:%M:%S')}") + print("-" * 50) + + result = { + "symbol": symbol, + "type": "futures", + "current_price": None, + "timestamp": datetime.now().strftime("%Y-%m-%dT%H:%M:%S+08:00"), + "timeframes": {} + } + + periods = [ + ("60min", "60"), + ("30min", "30"), + ("15min", "15"), + ("5min", "5") + ] + + for tf_name, tf_period in periods: + print(f"获取 {tf_name} 数据...") + try: + df = get_minute_data(symbol, tf_period) + if not df.empty and len(df) >= 50: + candles = process_data(df, tf_name) + if candles: + result["timeframes"][tf_name] = candles + if result["current_price"] is None: + result["current_price"] = candles[-1]["close"] + print(f" [OK] 成功获取 {len(candles)} 根K线") + else: + print(f" [FAIL] 数据不足或获取失败 (获取到{len(df)}根)") + except Exception as e: + print(f" [ERROR] 错误: {e}") + + print("获取 daily 数据...") + try: + df_daily = get_daily_data(symbol, days=60) + if not df_daily.empty and len(df_daily) >= 50: + candles = process_data(df_daily, "daily") + if candles: + result["timeframes"]["daily"] = candles + print(f" [OK] 成功获取 {len(candles)} 根K线") + else: + print(f" [FAIL] 数据不足或获取失败 (获取到{len(df_daily)}根)") + except Exception as e: + print(f" [ERROR] 错误: {e}") + + print("-" * 50) + return result + + +def collect_stock_data(symbol: str) -> Dict: + """收集股票多周期完整数据""" + print(f"\n正在获取股票 {symbol} 的多周期数据...") + print(f"当前时间: {get_current_time().strftime('%Y-%m-%d %H:%M:%S')}") + print("-" * 50) + + result = { + "symbol": symbol, + "type": "stock", + "current_price": None, + "timestamp": datetime.now().strftime("%Y-%m-%dT%H:%M:%S+08:00"), + "timeframes": {} + } + + periods = [ + ("60min", "60"), + ("30min", "30"), + ("15min", "15"), + ("5min", "5") + ] + + for tf_name, tf_period in periods: + print(f"获取 {tf_name} 数据...") + try: + df = get_stock_minute_data(symbol, tf_period) + if not df.empty and len(df) >= 50: + candles = process_data(df, tf_name) + if candles: + result["timeframes"][tf_name] = candles + if result["current_price"] is None: + result["current_price"] = candles[-1]["close"] + print(f" [OK] 成功获取 {len(candles)} 根K线") + else: + print(f" [FAIL] 数据不足或获取失败 (获取到{len(df)}根)") + except Exception as e: + print(f" [ERROR] 错误: {e}") + + print("获取 daily 数据...") + try: + df_daily = get_stock_daily_data(symbol, days=60) + if not df_daily.empty and len(df_daily) >= 50: + candles = process_data(df_daily, "daily") + if candles: + result["timeframes"]["daily"] = candles + print(f" [OK] 成功获取 {len(candles)} 根K线") + else: + print(f" [FAIL] 数据不足或获取失败 (获取到{len(df_daily)}根)") + except Exception as e: + print(f" [ERROR] 错误: {e}") + + print("-" * 50) + return result + + +def main(): + parser = argparse.ArgumentParser(description='期货/股票多周期数据获取与技术指标计算') + parser.add_argument('--symbol', type=str, required=True, + help='代码,期货如 SN2504(沪锡), 股票如 000001(平安银行)') + parser.add_argument('--type', type=str, default='futures', choices=['futures', 'stock'], + help='数据类型:futures(期货)、stock(股票),默认为 futures') + parser.add_argument('--output', type=str, default=None, + help='输出JSON文件名,默认为 代码_时间戳.json') + + args = parser.parse_args() + + if args.type == 'stock': + data = collect_stock_data(args.symbol) + else: + data = collect_futures_data(args.symbol) + + if not data["timeframes"]: + print("\n错误: 未能获取到任何数据,请检查代码是否正确") + if args.type == 'stock': + print("常见股票代码示例:") + print(" 000001 - 平安银行") + print(" 600000 - 浦发银行") + print(" 000858 - 五粮液") + print(" 600519 - 贵州茅台") + else: + print("常见期货合约代码示例:") + print(" SN2504 - 沪锡2504") + print(" AG2506 - 沪银2506") + print(" LC2505 - 碳酸锂2505") + print(" NI2505 - 沪镍2505") + return + + print("\n" + "="*60) + print("JSON 输出:") + print("="*60) + json_output = json.dumps(data, ensure_ascii=False, indent=2) + print(json_output) + + if args.output: + filename = os.path.join(DATA_DIR, args.output) + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = os.path.join(DATA_DIR, f"{data['symbol']}_{timestamp}.json") + + with open(filename, 'w', encoding='utf-8') as f: + f.write(json_output) + + print("\n" + "="*60) + print(f"[OK] 数据已保存到: {filename}") + print(f"[OK] 共获取 {len(data['timeframes'])} 个周期数据") + print("="*60) + + +if __name__ == "__main__": + main()