""" 数据同步模块 从 AKShare 获取数据并同步到 MySQL 数据库 """ import os import json import time from datetime import datetime, timedelta from typing import List, Dict, Optional from concurrent.futures import ThreadPoolExecutor, as_completed import akshare as ak import pandas as pd from sqlalchemy import func from database import ( get_db, init_db, check_connection, Stock, StockQuote, StockKLine, Sector, SectorQuote, MarketIndex, HighLowStock, MomentumStock, upsert_stock, upsert_stock_quote, upsert_stock_kline, upsert_sector, upsert_market_index ) # 配置 BATCH_SIZE = 100 MAX_WORKERS = 5 def dataframe_to_records(df: pd.DataFrame) -> List[Dict]: """将 DataFrame 转换为可 JSON 序列化的记录列表""" if df is None or df.empty: return [] df = df.replace({pd.NaT: None}) df = df.where(pd.notnull(df), None) return df.to_dict('records') def safe_float(value, default=0.0): """安全转换为浮点数""" try: if pd.isna(value) or value is None: return default return float(value) except: return default def safe_int(value, default=0): """安全转换为整数""" try: if pd.isna(value) or value is None: return default return int(value) except: return default # ==================== 股票数据同步 ==================== def sync_all_stocks(): """同步所有股票基础信息""" print("开始同步股票列表...") try: df = ak.stock_zh_a_spot_em() records = dataframe_to_records(df) with get_db() as db: count = 0 for record in records: try: code = record.get("代码") name = record.get("名称") industry = record.get("行业", "") if not code or not name: continue upsert_stock( db, code=code, name=name, pe=safe_float(record.get("市盈率-动态")), pb=safe_float(record.get("市净率")), ) count += 1 if count % BATCH_SIZE == 0: print(f"已处理 {count} 只股票") except Exception as e: print(f"处理股票 {record.get('代码')} 失败: {e}") continue print(f"股票列表同步完成,共 {count} 只") return count except Exception as e: print(f"同步股票列表失败: {e}") return 0 def sync_realtime_quotes(): """同步实时行情""" print("开始同步实时行情...") try: df = ak.stock_zh_a_spot_em() records = dataframe_to_records(df) now = datetime.now() with get_db() as db: count = 0 for record in records: try: code = record.get("代码") if not code: continue upsert_stock_quote( db, stock_code=code, price=safe_float(record.get("最新价")), open=safe_float(record.get("开盘价")), high=safe_float(record.get("最高价")), low=safe_float(record.get("最低价")), preClose=safe_float(record.get("昨收")), volume=safe_int(record.get("成交量")), turnover=safe_int(record.get("成交额")), changePercent=safe_float(record.get("涨跌幅")), turnoverRate=safe_float(record.get("换手率")), amplitude=safe_float(record.get("振幅")), quoteTime=now ) count += 1 except Exception as e: print(f"处理行情 {record.get('代码')} 失败: {e}") continue print(f"实时行情同步完成,共 {count} 条") return count except Exception as e: print(f"同步实时行情失败: {e}") return 0 def sync_stock_kline(symbol: str, period: str = "daily", days: int = 365): """同步单只股票K线数据""" try: end_date = datetime.now() start_date = end_date - timedelta(days=days) df = ak.stock_zh_a_hist( symbol=symbol, period=period, start_date=start_date.strftime("%Y%m%d"), end_date=end_date.strftime("%Y%m%d"), adjust="qfq" ) if df is None or df.empty: return 0 records = dataframe_to_records(df) period_map = {"daily": "day", "weekly": "week", "monthly": "month"} db_period = period_map.get(period, "day") with get_db() as db: count = 0 for record in records: try: date_str = record.get("日期") if not date_str: continue date = datetime.strptime(str(date_str), "%Y-%m-%d") upsert_stock_kline( db, stock_code=symbol, period=db_period, date=date, open=safe_float(record.get("开盘")), high=safe_float(record.get("最高")), low=safe_float(record.get("最低")), close=safe_float(record.get("收盘")), volume=safe_int(record.get("成交量")) ) count += 1 except Exception as e: continue return count except Exception as e: print(f"同步 {symbol} K线失败: {e}") return 0 def sync_all_klines(period: str = "daily", days: int = 365, max_stocks: Optional[int] = None): """批量同步所有股票K线数据""" print(f"开始同步K线数据 (周期: {period}, 天数: {days})...") # 获取股票列表 with get_db() as db: stocks = db.query(Stock).limit(max_stocks).all() if max_stocks else db.query(Stock).all() total = len(stocks) print(f"共 {total} 只股票需要同步") success_count = 0 fail_count = 0 for i, stock in enumerate(stocks): try: count = sync_stock_kline(stock.code, period, days) if count > 0: success_count += 1 print(f"[{i+1}/{total}] {stock.code} {stock.name} 同步 {count} 条K线") else: fail_count += 1 # 避免请求过快 time.sleep(0.5) except Exception as e: print(f"[{i+1}/{total}] {stock.code} 同步失败: {e}") fail_count += 1 continue print(f"K线同步完成,成功: {success_count}, 失败: {fail_count}") return success_count # ==================== 板块数据同步 ==================== def sync_sectors(): """同步板块信息""" print("开始同步板块信息...") try: df = ak.stock_board_industry_name_em() records = dataframe_to_records(df) with get_db() as db: count = 0 for record in records: try: code = record.get("代码") name = record.get("名称") if not code or not name: continue upsert_sector(db, code=code, name=name) count += 1 except Exception as e: print(f"处理板块 {record.get('名称')} 失败: {e}") continue print(f"板块信息同步完成,共 {count} 个") return count except Exception as e: print(f"同步板块信息失败: {e}") return 0 def sync_sector_quotes(): """同步板块行情""" print("开始同步板块行情...") try: df = ak.stock_board_industry_name_em() records = dataframe_to_records(df) now = datetime.now() with get_db() as db: count = 0 for record in records: try: code = record.get("代码") if not code: continue quote = SectorQuote( sectorCode=code, current=0, change=0, changePercent=safe_float(record.get("涨跌幅")), volume=0, turnover=0, quoteTime=now ) db.add(quote) count += 1 except Exception as e: continue print(f"板块行情同步完成,共 {count} 条") return count except Exception as e: print(f"同步板块行情失败: {e}") return 0 # ==================== 指数数据同步 ==================== def sync_market_indices(): """同步市场指数""" print("开始同步市场指数...") try: df = ak.index_zh_a_spot_em() records = dataframe_to_records(df) with get_db() as db: count = 0 for record in records: try: code = record.get("代码") name = record.get("名称") if not code or not name: continue upsert_market_index( db, code=code, name=name, current=safe_float(record.get("最新价")), change=safe_float(record.get("涨跌额")), changePercent=safe_float(record.get("涨跌幅")), volume=safe_int(record.get("成交量")), turnover=safe_int(record.get("成交额")) ) count += 1 except Exception as e: print(f"处理指数 {record.get('名称')} 失败: {e}") continue print(f"市场指数同步完成,共 {count} 个") return count except Exception as e: print(f"同步市场指数失败: {e}") return 0 # ==================== 批量同步工具 ==================== def sync_all(quick: bool = False): """执行全部同步 Args: quick: 如果为True,只同步少量数据用于测试 """ print("=" * 60) print("开始全量数据同步") print("=" * 60) # 1. 同步市场指数 sync_market_indices() # 2. 同步板块 sync_sectors() sync_sector_quotes() # 3. 同步股票列表 sync_all_stocks() # 4. 同步实时行情 sync_realtime_quotes() # 5. 同步K线(如果quick=True,只同步少量股票) if quick: print("快速模式:只同步前10只股票的K线") sync_all_klines(days=30, max_stocks=10) else: sync_all_klines(days=365) print("=" * 60) print("全量数据同步完成") print("=" * 60) def sync_daily(): """每日增量同步""" print("=" * 60) print(f"开始每日增量同步 - {datetime.now()}") print("=" * 60) # 1. 同步实时行情 sync_realtime_quotes() # 2. 同步板块行情 sync_sector_quotes() # 3. 同步市场指数 sync_market_indices() # 4. 同步今日K线 sync_all_klines(days=1) print("=" * 60) print("每日增量同步完成") print("=" * 60) if __name__ == "__main__": # 测试数据库连接 if not check_connection(): print("数据库连接失败") exit(1) # 初始化数据库表 init_db() # 执行全量同步(快速模式) sync_all(quick=True)