""" 数据导入路由 """ import pandas as pd import logging from fastapi import APIRouter, Depends, UploadFile, File, HTTPException from sqlalchemy.orm import Session from datetime import datetime, date from app.db.session import get_db from app.schemas.base import ResponseModel from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade from app.core.security import get_current_user from app.models.user import User router = APIRouter() logger = logging.getLogger(__name__) INDEX_TRADE_COLUMN_MAP = { '证券代码': 'index_code', '证券名称': 'name', '成分个数 [交易日期]最新': 'component_count', '开盘价 [交易日期]最新': 'open', '收盘价 [交易日期]最新': 'close', '成交量 [交易日期]最新 [单位]股': 'volume', '成交额 [交易日期]最新 [单位]百万元': 'amount', '总市值 [截止日期]最新 [单位]百万元': 'total_market_value', '自由流通市值 [交易日期]最新 [单位]百万元': 'float_market_value', '涨跌幅 [交易日期]最新 [单位]%': 'change_pct', '最高价 [交易日期]最新': 'high', '最低价 [交易日期]最新': 'low', '上涨家数 [交易日期]最新': 'up_count', '下跌家数 [交易日期]最新': 'down_count', '平盘家数 [交易日期]最新': 'flat_count', '涨停家数 [交易日期]最新': 'limit_up_count', '跌停家数 [交易日期]最新': 'limit_down_count', '停牌家数 [交易日期]最新': 'suspend_count', '近期创历史新高 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_high', '近期创历史新低 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_low', '市盈率PE(TTM) [交易日期]最新 [剔除规则]不调整': 'pe_ratio', '市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median' } STOCK_BASIC_COLUMN_MAP = { '证券代码': 'code', '证券名称': 'name', '首发上市日': 'list_date', '所属东财行业指数名称\n[行业类别]2级': 'industry_index_name', '所属东财行业指数代码\n[行业类别]2级': 'industry_index_code', '机构持股比例合计\n[报告期]最新一期\n[单位]%\n[比例类型]占总股本比例': 'institution_hold_ratio', '所属东财行业名称\n[行业类别]3级': 'industry_level3' } @router.post("/index-data", response_model=ResponseModel) async def import_index_data( file: UploadFile = File(...), trade_date: str = None, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """导入指数数据(同时更新指数基础表和指数交易表)""" if not file.filename.endswith(('.xls', '.xlsx')): raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") if not trade_date: raise HTTPException(status_code=400, detail="请提供交易日期参数(YYYY-MM-DD格式)") try: trade_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date() except: raise HTTPException(status_code=400, detail="交易日期格式错误,请使用YYYY-MM-DD格式") try: df = pd.read_excel(file.file) df.columns = df.columns.str.strip() renamed_df = df.rename(columns=INDEX_TRADE_COLUMN_MAP) if 'index_code' not in renamed_df.columns: raise HTTPException(status_code=400, detail="缺少必要列:证券代码") success_count = 0 error_count = 0 index_basic_updated = 0 index_basic_added = 0 for _, row in renamed_df.iterrows(): try: index_code = str(row['index_code']).strip() if not index_code: continue name = str(row.get('name', '')) if pd.notna(row.get('name')) else None component_count = int(row.get('component_count')) if pd.notna(row.get('component_count')) else None index_basic = db.query(IndexBasic).filter(IndexBasic.code == index_code).first() if index_basic: if component_count and index_basic.component_count != component_count: index_basic.component_count = component_count index_basic.name = name or index_basic.name index_basic.updated_at = datetime.utcnow() index_basic_updated += 1 else: index_basic = IndexBasic( code=index_code, name=name, component_count=component_count ) db.add(index_basic) db.flush() index_basic_added += 1 existing_trade = db.query(IndexTrade).filter( IndexTrade.index_code == index_code, IndexTrade.trade_date == trade_date_obj ).first() def get_float_val(col_name): val = row.get(col_name) if pd.notna(val): try: return float(val) except: return None return None def get_int_val(col_name): val = row.get(col_name) if pd.notna(val): try: return int(float(val)) except: return None return None def get_bool_val(col_name): val = row.get(col_name) if pd.notna(val): if isinstance(val, bool): return val if isinstance(val, str): return val.lower() in ['true', '1', 'yes', '是'] return bool(val) return False open_price = get_float_val('open') close_price = get_float_val('close') high_price = get_float_val('high') low_price = get_float_val('low') change_pct = get_float_val('change_pct') volume = get_int_val('volume') amount = get_float_val('amount') total_market_value = get_float_val('total_market_value') float_market_value = get_float_val('float_market_value') up_count = get_int_val('up_count') down_count = get_int_val('down_count') flat_count = get_int_val('flat_count') limit_up_count = get_int_val('limit_up_count') limit_down_count = get_int_val('limit_down_count') suspend_count = get_int_val('suspend_count') pe_ratio = get_float_val('pe_ratio') pe_median = get_float_val('pe_median') is_new_high = get_bool_val('is_new_high') is_new_low = get_bool_val('is_new_low') if existing_trade: existing_trade.open = open_price existing_trade.close = close_price existing_trade.high = high_price existing_trade.low = low_price existing_trade.change_pct = change_pct existing_trade.volume = volume existing_trade.amount = amount existing_trade.total_market_value = total_market_value existing_trade.float_market_value = float_market_value existing_trade.up_count = up_count existing_trade.down_count = down_count existing_trade.flat_count = flat_count existing_trade.limit_up_count = limit_up_count existing_trade.limit_down_count = limit_down_count existing_trade.suspend_count = suspend_count existing_trade.pe_ratio = pe_ratio existing_trade.pe_median = pe_median existing_trade.is_new_high = is_new_high existing_trade.is_new_low = is_new_low existing_trade.updated_at = datetime.utcnow() else: trade = IndexTrade( index_code=index_code, trade_date=trade_date_obj, open=open_price, close=close_price, high=high_price, low=low_price, change_pct=change_pct, volume=volume, amount=amount, total_market_value=total_market_value, float_market_value=float_market_value, up_count=up_count, down_count=down_count, flat_count=flat_count, limit_up_count=limit_up_count, limit_down_count=limit_down_count, suspend_count=suspend_count, pe_ratio=pe_ratio, pe_median=pe_median, is_new_high=is_new_high, is_new_low=is_new_low ) db.add(trade) success_count += 1 except Exception as e: logger.error(f"导入指数{row.get('index_code')}失败: {str(e)}") error_count += 1 db.commit() return ResponseModel(data={ "success_count": success_count, "error_count": error_count, "total_count": len(df), "index_basic_added": index_basic_added, "index_basic_updated": index_basic_updated, "trade_date": trade_date }) except Exception as e: logger.error(f"导入指数数据失败: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/stock-basic", response_model=ResponseModel) async def import_stock_basic( file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """导入股票基础数据(支持模板格式)""" if not file.filename.endswith(('.xls', '.xlsx')): raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") try: df = pd.read_excel(file.file) df.columns = df.columns.str.strip() renamed_df = df.rename(columns=STOCK_BASIC_COLUMN_MAP) if 'code' not in renamed_df.columns: raise HTTPException(status_code=400, detail="缺少必要列:证券代码") if 'name' not in renamed_df.columns: raise HTTPException(status_code=400, detail="缺少必要列:证券名称") success_count = 0 error_count = 0 added_count = 0 updated_count = 0 skipped_count = 0 skipped_details = [] error_details = [] for _, row in renamed_df.iterrows(): try: code_val = row.get('code') if pd.isna(code_val): continue code = str(code_val).strip() if not code or code.lower() == 'nan': continue existing = db.query(StockBasic).filter(StockBasic.code == code).first() list_date = None list_date_val = row.get('list_date') if pd.notna(list_date_val): if isinstance(list_date_val, datetime): list_date = list_date_val.date() elif isinstance(list_date_val, str): try: list_date = datetime.strptime(list_date_val, '%Y-%m-%d').date() except: pass elif hasattr(list_date_val, 'date'): list_date = list_date_val.date() name = str(row.get('name', '')) if pd.notna(row.get('name')) else None industry_index_name = str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None industry_index_code = str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None institution_hold_ratio = None ratio_val = row.get('institution_hold_ratio') if pd.notna(ratio_val): try: if str(ratio_val).strip() == '--': institution_hold_ratio = 0.0 else: institution_hold_ratio = float(ratio_val) except: institution_hold_ratio = 0.0 industry_level3 = str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None if industry_index_code: index_basic = db.query(IndexBasic).filter(IndexBasic.code == industry_index_code).first() if not index_basic: index_basic = IndexBasic( code=industry_index_code, name=industry_index_name or industry_index_code ) db.add(index_basic) db.flush() if existing: def is_same_data(): def compare_ratio(): if existing.institution_hold_ratio is None and institution_hold_ratio is None: return True if existing.institution_hold_ratio is None or institution_hold_ratio is None: return False return abs(float(existing.institution_hold_ratio) - institution_hold_ratio) < 0.0001 return ( (existing.name == name or (existing.name is None and name is None)) and (existing.industry_index_name == industry_index_name or (existing.industry_index_name is None and industry_index_name is None)) and (existing.industry_index_code == industry_index_code or (existing.industry_index_code is None and industry_index_code is None)) and compare_ratio() and (existing.industry_level3 == industry_level3 or (existing.industry_level3 is None and industry_level3 is None)) and (existing.list_date == list_date or (existing.list_date is None and list_date is None)) ) if is_same_data(): skipped_count += 1 skipped_details.append({ "code": code, "name": name, "reason": "数据相同,无需更新" }) else: existing.name = name existing.industry_index_name = industry_index_name existing.industry_index_code = industry_index_code existing.institution_hold_ratio = institution_hold_ratio existing.industry_level3 = industry_level3 existing.list_date = list_date existing.updated_at = datetime.utcnow() updated_count += 1 else: stock = StockBasic( code=code, name=name, total_shares=None, float_shares=None, industry_index_name=industry_index_name, industry_index_code=industry_index_code, institution_hold_ratio=institution_hold_ratio, industry_level3=industry_level3, list_date=list_date ) db.add(stock) added_count += 1 success_count += 1 except Exception as e: logger.error(f"导入股票{row.get('code')}失败: {str(e)}") error_count += 1 error_details.append({ "code": str(row.get('code', '')), "name": str(row.get('name', '')) if pd.notna(row.get('name')) else '', "reason": str(e) }) db.commit() return ResponseModel(data={ "success_count": success_count, "error_count": error_count, "total_count": len(df), "added_count": added_count, "updated_count": updated_count, "skipped_count": skipped_count, "skipped_details": skipped_details[:100], "error_details": error_details[:100] }) except Exception as e: logger.error(f"导入股票基础数据失败: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/index-basic", response_model=ResponseModel) async def import_index_basic( file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """导入指数基础数据""" if not file.filename.endswith(('.xls', '.xlsx')): raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") try: df = pd.read_excel(file.file) required_columns = ['code', 'name', 'component_count'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") success_count = 0 error_count = 0 for _, row in df.iterrows(): try: existing = db.query(IndexBasic).filter(IndexBasic.code == str(row['code'])).first() if existing: existing.name = str(row.get('name', existing.name)) existing.component_count = int(row.get('component_count', existing.component_count)) if pd.notna(row.get('component_count')) else existing.component_count existing.updated_at = datetime.utcnow() else: index = IndexBasic( code=str(row['code']), name=str(row.get('name', '')), component_count=int(row['component_count']) if pd.notna(row['component_count']) else None ) db.add(index) success_count += 1 except Exception as e: logger.error(f"导入指数{row.get('code')}失败: {str(e)}") error_count += 1 db.commit() return ResponseModel(data={ "success_count": success_count, "error_count": error_count, "total_count": len(df) }) except Exception as e: logger.error(f"导入指数基础数据失败: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/index-trade", response_model=ResponseModel) async def import_index_trade( file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """导入指数交易数据""" if not file.filename.endswith(('.xls', '.xlsx')): raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") try: df = pd.read_excel(file.file) required_columns = ['index_code', 'trade_date', 'open', 'close', 'high', 'low'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") success_count = 0 error_count = 0 for _, row in df.iterrows(): try: trade_date = None if pd.notna(row['trade_date']): if isinstance(row['trade_date'], datetime): trade_date = row['trade_date'].date() elif isinstance(row['trade_date'], str): trade_date = datetime.strptime(row['trade_date'], '%Y-%m-%d').date() existing = db.query(IndexTrade).filter( IndexTrade.index_code == str(row['index_code']), IndexTrade.trade_date == trade_date ).first() if existing: existing.open = float(row.get('open', existing.open)) if pd.notna(row.get('open')) else existing.open existing.close = float(row.get('close', existing.close)) if pd.notna(row.get('close')) else existing.close existing.high = float(row.get('high', existing.high)) if pd.notna(row.get('high')) else existing.high existing.low = float(row.get('low', existing.low)) if pd.notna(row.get('low')) else existing.low existing.change_pct = float(row.get('change_pct', existing.change_pct)) if pd.notna(row.get('change_pct')) else existing.change_pct existing.volume = int(row.get('volume', existing.volume)) if pd.notna(row.get('volume')) else existing.volume existing.amount = float(row.get('amount', existing.amount)) if pd.notna(row.get('amount')) else existing.amount existing.total_market_value = float(row.get('total_market_value', existing.total_market_value)) if pd.notna(row.get('total_market_value')) else existing.total_market_value existing.float_market_value = float(row.get('float_market_value', existing.float_market_value)) if pd.notna(row.get('float_market_value')) else existing.float_market_value existing.up_count = int(row.get('up_count', existing.up_count)) if pd.notna(row.get('up_count')) else existing.up_count existing.down_count = int(row.get('down_count', existing.down_count)) if pd.notna(row.get('down_count')) else existing.down_count existing.flat_count = int(row.get('flat_count', existing.flat_count)) if pd.notna(row.get('flat_count')) else existing.flat_count existing.limit_up_count = int(row.get('limit_up_count', existing.limit_up_count)) if pd.notna(row.get('limit_up_count')) else existing.limit_up_count existing.limit_down_count = int(row.get('limit_down_count', existing.limit_down_count)) if pd.notna(row.get('limit_down_count')) else existing.limit_down_count existing.suspend_count = int(row.get('suspend_count', existing.suspend_count)) if pd.notna(row.get('suspend_count')) else existing.suspend_count existing.pe_ratio = float(row.get('pe_ratio', existing.pe_ratio)) if pd.notna(row.get('pe_ratio')) else existing.pe_ratio existing.pe_median = float(row.get('pe_median', existing.pe_median)) if pd.notna(row.get('pe_median')) else existing.pe_median existing.is_new_high = bool(row.get('is_new_high', existing.is_new_high)) if pd.notna(row.get('is_new_high')) else existing.is_new_high existing.is_new_low = bool(row.get('is_new_low', existing.is_new_low)) if pd.notna(row.get('is_new_low')) else existing.is_new_low existing.updated_at = datetime.utcnow() else: trade = IndexTrade( index_code=str(row['index_code']), trade_date=trade_date, open=float(row['open']) if pd.notna(row['open']) else None, close=float(row['close']) if pd.notna(row['close']) else None, high=float(row['high']) if pd.notna(row['high']) else None, low=float(row['low']) if pd.notna(row['low']) else None, change_pct=float(row.get('change_pct')) if pd.notna(row.get('change_pct')) else None, volume=int(row.get('volume')) if pd.notna(row.get('volume')) else None, amount=float(row.get('amount')) if pd.notna(row.get('amount')) else None, total_market_value=float(row.get('total_market_value')) if pd.notna(row.get('total_market_value')) else None, float_market_value=float(row.get('float_market_value')) if pd.notna(row.get('float_market_value')) else None, up_count=int(row.get('up_count')) if pd.notna(row.get('up_count')) else None, down_count=int(row.get('down_count')) if pd.notna(row.get('down_count')) else None, flat_count=int(row.get('flat_count')) if pd.notna(row.get('flat_count')) else None, limit_up_count=int(row.get('limit_up_count')) if pd.notna(row.get('limit_up_count')) else None, limit_down_count=int(row.get('limit_down_count')) if pd.notna(row.get('limit_down_count')) else None, suspend_count=int(row.get('suspend_count')) if pd.notna(row.get('suspend_count')) else None, pe_ratio=float(row.get('pe_ratio')) if pd.notna(row.get('pe_ratio')) else None, pe_median=float(row.get('pe_median')) if pd.notna(row.get('pe_median')) else None, is_new_high=bool(row.get('is_new_high')) if pd.notna(row.get('is_new_high')) else False, is_new_low=bool(row.get('is_new_low')) if pd.notna(row.get('is_new_low')) else False ) db.add(trade) success_count += 1 except Exception as e: logger.error(f"导入指数交易{row.get('index_code')}-{row.get('trade_date')}失败: {str(e)}") error_count += 1 db.commit() return ResponseModel(data={ "success_count": success_count, "error_count": error_count, "total_count": len(df) }) except Exception as e: logger.error(f"导入指数交易数据失败: {str(e)}") raise HTTPException(status_code=500, detail=str(e))