You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

549 lines
27 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Data import routes.
"""
import pandas as pd
import logging
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from sqlalchemy.orm import Session
from datetime import datetime, date
from app.db.session import get_db
from app.schemas.base import ResponseModel
from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
from app.core.security import get_current_user
from app.models.user import User
# Router on which the import endpoints of this module are registered.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maps the Chinese column headers of the index spreadsheet to the internal
# field names used by the import code (applied via ``DataFrame.rename``).
INDEX_TRADE_COLUMN_MAP = {
    # Identity: security code, security name, number of constituents.
    '证券代码': 'index_code',
    '证券名称': 'name',
    '成分个数 [交易日期]最新': 'component_count',
    # Prices, turnover and market value.
    '开盘价 [交易日期]最新': 'open',
    '收盘价 [交易日期]最新': 'close',
    '成交量 [交易日期]最新 [单位]股': 'volume',
    '成交额 [交易日期]最新 [单位]百万元': 'amount',
    '总市值 [截止日期]最新 [单位]百万元': 'total_market_value',
    '自由流通市值 [交易日期]最新 [单位]百万元': 'float_market_value',
    '涨跌幅 [交易日期]最新 [单位]%': 'change_pct',
    '最高价 [交易日期]最新': 'high',
    '最低价 [交易日期]最新': 'low',
    # Breadth counters: rising / falling / flat / limit-up / limit-down /
    # suspended constituents.
    '上涨家数 [交易日期]最新': 'up_count',
    '下跌家数 [交易日期]最新': 'down_count',
    '平盘家数 [交易日期]最新': 'flat_count',
    '涨停家数 [交易日期]最新': 'limit_up_count',
    '跌停家数 [交易日期]最新': 'limit_down_count',
    '停牌家数 [交易日期]最新': 'suspend_count',
    # Recent all-time high / low flags.
    '近期创历史新高 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_high',
    '近期创历史新低 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_low',
    # Valuation: PE (TTM) and its median.
    '市盈率PE(TTM) [交易日期]最新 [剔除规则]不调整': 'pe_ratio',
    '市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median',
}
# Maps the Chinese column headers of the stock-basic spreadsheet to the
# internal field names used by the import code (applied via
# ``DataFrame.rename``). Several headers contain embedded newlines, which
# are part of the key and must be matched exactly.
STOCK_BASIC_COLUMN_MAP = {
    # Security code, security name, IPO listing date.
    '证券代码': 'code',
    '证券名称': 'name',
    '首发上市日': 'list_date',
    # Level-2 industry index name and code.
    '所属东财行业指数名称\n[行业类别]2级': 'industry_index_name',
    '所属东财行业指数代码\n[行业类别]2级': 'industry_index_code',
    # Total institutional holding ratio (percent of total share capital).
    '机构持股比例合计\n[报告期]最新一期\n[单位]%\n[比例类型]占总股本比例': 'institution_hold_ratio',
    # Level-3 industry name.
    '所属东财行业名称\n[行业类别]3级': 'industry_level3',
}
# Trade columns grouped by the conversion applied to their spreadsheet cells.
_INDEX_DATA_FLOAT_FIELDS = (
    'open', 'close', 'high', 'low', 'change_pct', 'amount',
    'total_market_value', 'float_market_value', 'pe_ratio', 'pe_median',
)
_INDEX_DATA_INT_FIELDS = (
    'volume', 'up_count', 'down_count', 'flat_count',
    'limit_up_count', 'limit_down_count', 'suspend_count',
)
_INDEX_DATA_BOOL_FIELDS = ('is_new_high', 'is_new_low')


def _index_data_float(value):
    """Return *value* as a float, or None when it is missing or not numeric."""
    if pd.isna(value):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _index_data_int(value):
    """Return *value* as an int (via float), or None when missing or not numeric."""
    if pd.isna(value):
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


def _index_data_bool(value):
    """Return *value* as a bool; missing cells count as False."""
    if pd.isna(value):
        return False
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        # NOTE(review): the empty string is accepted as True here, exactly as
        # in the previous implementation -- confirm this is intentional.
        return value.lower() in ('true', '1', 'yes', '')
    return bool(value)


def _index_data_trade_values(row):
    """Collect the converted trade-column values of one spreadsheet row."""
    values = {
        field: _index_data_float(row.get(field))
        for field in _INDEX_DATA_FLOAT_FIELDS
    }
    values.update(
        (field, _index_data_int(row.get(field)))
        for field in _INDEX_DATA_INT_FIELDS
    )
    values.update(
        (field, _index_data_bool(row.get(field)))
        for field in _INDEX_DATA_BOOL_FIELDS
    )
    return values


@router.post("/index-data", response_model=ResponseModel)
async def import_index_data(
    file: UploadFile = File(...),
    trade_date: str = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import index data, updating the index basic and index trade tables.

    The workbook is read with pandas and its headers are renamed through
    ``INDEX_TRADE_COLUMN_MAP``. Each row is applied inside its own SAVEPOINT,
    so a failing row is rolled back and counted as an error without leaving
    the session unusable for the rows that follow.

    Args:
        file: Uploaded ``.xls`` / ``.xlsx`` workbook.
        trade_date: Trade date for every row, in ``YYYY-MM-DD`` format.
        db: Database session.
        current_user: Resolved by ``get_current_user``; not used in the body.

    Returns:
        ResponseModel whose ``data`` holds the success / error / total row
        counts, the number of index basic rows added and updated, and the
        trade date.

    Raises:
        HTTPException: 400 for a wrong file extension, a missing or malformed
            ``trade_date``, or a missing code column; 500 for any other
            failure while reading the file or committing.
    """
    # A missing filename is rejected like a wrong extension; the suffix check
    # ignores case.
    if not (file.filename or '').lower().endswith(('.xls', '.xlsx')):
        raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
    if not trade_date:
        raise HTTPException(status_code=400, detail="请提供交易日期参数(YYYY-MM-DD格式)")
    try:
        trade_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date()
    except ValueError:
        raise HTTPException(status_code=400, detail="交易日期格式错误请使用YYYY-MM-DD格式")
    try:
        df = pd.read_excel(file.file)
        df.columns = df.columns.str.strip()
        renamed_df = df.rename(columns=INDEX_TRADE_COLUMN_MAP)
        if 'index_code' not in renamed_df.columns:
            raise HTTPException(status_code=400, detail="缺少必要列:证券代码")
        success_count = 0
        error_count = 0
        index_basic_updated = 0
        index_basic_added = 0
        for _, row in renamed_df.iterrows():
            try:
                raw_code = row['index_code']
                # Blank cells arrive as NaN; skip them instead of importing
                # an index whose code is the literal string 'nan'.
                if pd.isna(raw_code):
                    continue
                index_code = str(raw_code).strip()
                if not index_code or index_code.lower() == 'nan':
                    continue
                raw_name = row.get('name')
                name = str(raw_name) if pd.notna(raw_name) else None
                raw_count = row.get('component_count')
                component_count = int(raw_count) if pd.notna(raw_count) else None
                trade_values = _index_data_trade_values(row)
                basic_added = False
                basic_updated = False
                # Counters are only applied after the SAVEPOINT is released,
                # so a rolled-back row never inflates them.
                with db.begin_nested():
                    index_basic = db.query(IndexBasic).filter(
                        IndexBasic.code == index_code
                    ).first()
                    if index_basic:
                        if component_count and index_basic.component_count != component_count:
                            index_basic.component_count = component_count
                            index_basic.name = name or index_basic.name
                            index_basic.updated_at = datetime.utcnow()
                            basic_updated = True
                    else:
                        db.add(IndexBasic(
                            code=index_code,
                            name=name,
                            component_count=component_count
                        ))
                        db.flush()
                        basic_added = True
                    existing_trade = db.query(IndexTrade).filter(
                        IndexTrade.index_code == index_code,
                        IndexTrade.trade_date == trade_date_obj
                    ).first()
                    if existing_trade:
                        for field, value in trade_values.items():
                            setattr(existing_trade, field, value)
                        existing_trade.updated_at = datetime.utcnow()
                    else:
                        db.add(IndexTrade(
                            index_code=index_code,
                            trade_date=trade_date_obj,
                            **trade_values
                        ))
                if basic_added:
                    index_basic_added += 1
                if basic_updated:
                    index_basic_updated += 1
                success_count += 1
            except Exception as e:
                logger.error(f"导入指数{row.get('index_code')}失败: {str(e)}")
                error_count += 1
        db.commit()
        return ResponseModel(data={
            "success_count": success_count,
            "error_count": error_count,
            "total_count": len(df),
            "index_basic_added": index_basic_added,
            "index_basic_updated": index_basic_updated,
            "trade_date": trade_date
        })
    except HTTPException:
        # Let deliberate 4xx responses through instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"导入指数数据失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def _stock_basic_text(value):
    """Return *value* as a string, or None when the cell is missing."""
    return None if pd.isna(value) else str(value)


def _stock_basic_list_date(value):
    """Convert a listing-date cell to a ``date``; None when missing or unparseable."""
    if pd.isna(value):
        return None
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    if isinstance(value, str):
        try:
            return datetime.strptime(value.strip(), '%Y-%m-%d').date()
        except ValueError:
            return None
    if hasattr(value, 'date'):
        return value.date()
    return None


def _stock_basic_hold_ratio(value):
    """Convert the holding-ratio cell to a float.

    Missing cells give None; the '--' placeholder and any non-numeric value
    give 0.0.
    """
    if pd.isna(value):
        return None
    if str(value).strip() == '--':
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _stock_basic_unchanged(existing, fields):
    """Return True when *existing* already holds every value in *fields*."""
    for attr, new_value in fields.items():
        old_value = getattr(existing, attr)
        if attr == 'institution_hold_ratio':
            if old_value is None or new_value is None:
                if old_value is not new_value:
                    return False
            # Ratios are compared with a small tolerance rather than exactly.
            elif abs(float(old_value) - new_value) >= 0.0001:
                return False
        elif old_value != new_value:
            return False
    return True


@router.post("/stock-basic", response_model=ResponseModel)
async def import_stock_basic(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import stock basic data from a template-format workbook.

    Headers are renamed through ``STOCK_BASIC_COLUMN_MAP``. Rows whose data
    already matches the stored record are reported as skipped. An industry
    index referenced by a row is created when it does not exist yet. Each row
    is applied inside its own SAVEPOINT so a failing row is rolled back and
    reported without affecting the others.

    Args:
        file: Uploaded ``.xls`` / ``.xlsx`` workbook.
        db: Database session.
        current_user: Resolved by ``get_current_user``; not used in the body.

    Returns:
        ResponseModel whose ``data`` holds success / error / total / added /
        updated / skipped counts plus up to 100 skipped and error details.

    Raises:
        HTTPException: 400 for a wrong file extension or a missing code or
            name column; 500 for any other failure while reading the file or
            committing.
    """
    # A missing filename is rejected like a wrong extension; the suffix check
    # ignores case.
    if not (file.filename or '').lower().endswith(('.xls', '.xlsx')):
        raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
    try:
        df = pd.read_excel(file.file)
        df.columns = df.columns.str.strip()
        renamed_df = df.rename(columns=STOCK_BASIC_COLUMN_MAP)
        if 'code' not in renamed_df.columns:
            raise HTTPException(status_code=400, detail="缺少必要列:证券代码")
        if 'name' not in renamed_df.columns:
            raise HTTPException(status_code=400, detail="缺少必要列:证券名称")
        success_count = 0
        error_count = 0
        added_count = 0
        updated_count = 0
        skipped_count = 0
        skipped_details = []
        error_details = []
        for _, row in renamed_df.iterrows():
            try:
                code_val = row.get('code')
                if pd.isna(code_val):
                    continue
                code = str(code_val).strip()
                if not code or code.lower() == 'nan':
                    continue
                fields = {
                    'name': _stock_basic_text(row.get('name')),
                    'industry_index_name': _stock_basic_text(row.get('industry_index_name')),
                    'industry_index_code': _stock_basic_text(row.get('industry_index_code')),
                    'institution_hold_ratio': _stock_basic_hold_ratio(row.get('institution_hold_ratio')),
                    'industry_level3': _stock_basic_text(row.get('industry_level3')),
                    'list_date': _stock_basic_list_date(row.get('list_date')),
                }
                industry_index_code = fields['industry_index_code']
                outcome = None
                # The outcome is recorded inside the SAVEPOINT and the
                # counters are applied only after it has been released.
                with db.begin_nested():
                    existing = db.query(StockBasic).filter(StockBasic.code == code).first()
                    if industry_index_code:
                        index_basic = db.query(IndexBasic).filter(
                            IndexBasic.code == industry_index_code
                        ).first()
                        if not index_basic:
                            db.add(IndexBasic(
                                code=industry_index_code,
                                name=fields['industry_index_name'] or industry_index_code
                            ))
                            db.flush()
                    if existing:
                        if _stock_basic_unchanged(existing, fields):
                            outcome = 'skipped'
                        else:
                            for attr, value in fields.items():
                                setattr(existing, attr, value)
                            existing.updated_at = datetime.utcnow()
                            outcome = 'updated'
                    else:
                        db.add(StockBasic(
                            code=code,
                            total_shares=None,
                            float_shares=None,
                            **fields
                        ))
                        outcome = 'added'
                if outcome == 'skipped':
                    skipped_count += 1
                    skipped_details.append({
                        "code": code,
                        "name": fields['name'],
                        "reason": "数据相同,无需更新"
                    })
                elif outcome == 'updated':
                    updated_count += 1
                else:
                    added_count += 1
                success_count += 1
            except Exception as e:
                logger.error(f"导入股票{row.get('code')}失败: {str(e)}")
                error_count += 1
                error_details.append({
                    "code": str(row.get('code', '')),
                    "name": str(row.get('name', '')) if pd.notna(row.get('name')) else '',
                    "reason": str(e)
                })
        db.commit()
        return ResponseModel(data={
            "success_count": success_count,
            "error_count": error_count,
            "total_count": len(df),
            "added_count": added_count,
            "updated_count": updated_count,
            "skipped_count": skipped_count,
            "skipped_details": skipped_details[:100],
            "error_details": error_details[:100]
        })
    except HTTPException:
        # Let deliberate 4xx responses through instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"导入股票基础数据失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/index-basic", response_model=ResponseModel)
async def import_index_basic(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import index basic data.

    The workbook must contain the columns ``code``, ``name`` and
    ``component_count``. Existing indexes are updated in place and unknown
    codes are inserted. Each row is applied inside its own SAVEPOINT so a
    failing row is rolled back and counted without affecting the others.

    Args:
        file: Uploaded ``.xls`` / ``.xlsx`` workbook.
        db: Database session.
        current_user: Resolved by ``get_current_user``; not used in the body.

    Returns:
        ResponseModel whose ``data`` holds the success, error and total row
        counts.

    Raises:
        HTTPException: 400 for a wrong file extension or missing required
            columns; 500 for any other failure while reading the file or
            committing.
    """
    # A missing filename is rejected like a wrong extension; the suffix check
    # ignores case.
    if not (file.filename or '').lower().endswith(('.xls', '.xlsx')):
        raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
    try:
        df = pd.read_excel(file.file)
        required_columns = ['code', 'name', 'component_count']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}")
        success_count = 0
        error_count = 0
        for _, row in df.iterrows():
            try:
                raw_code = row['code']
                # Blank cells arrive as NaN; skip them instead of storing an
                # index whose code is the literal string 'nan'.
                if pd.isna(raw_code):
                    continue
                code = str(raw_code).strip()
                if not code or code.lower() == 'nan':
                    continue
                # A missing name stays None rather than becoming 'nan'.
                name = str(row['name']) if pd.notna(row['name']) else None
                raw_count = row['component_count']
                component_count = int(raw_count) if pd.notna(raw_count) else None
                with db.begin_nested():
                    existing = db.query(IndexBasic).filter(IndexBasic.code == code).first()
                    if existing:
                        # Empty cells leave the stored value untouched.
                        if name is not None:
                            existing.name = name
                        if component_count is not None:
                            existing.component_count = component_count
                        existing.updated_at = datetime.utcnow()
                    else:
                        db.add(IndexBasic(
                            code=code,
                            name=name,
                            component_count=component_count
                        ))
                success_count += 1
            except Exception as e:
                logger.error(f"导入指数{row.get('code')}失败: {str(e)}")
                error_count += 1
        db.commit()
        return ResponseModel(data={
            "success_count": success_count,
            "error_count": error_count,
            "total_count": len(df)
        })
    except HTTPException:
        # Let deliberate 4xx responses through instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"导入指数基础数据失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def _index_trade_bool(value):
    """Convert a flag cell to bool, reading 'false' / '0' / 'no' / '' strings as False."""
    if isinstance(value, str):
        return value.strip().lower() not in ('', '0', 'false', 'no')
    return bool(value)


def _index_trade_date(value):
    """Convert a trade-date cell to a ``date``.

    Raises:
        ValueError: if the cell is missing, is a string that is not in
            ``YYYY-MM-DD`` format, or has an unsupported type.
    """
    if pd.isna(value):
        raise ValueError("trade_date is missing")
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    if isinstance(value, str):
        return datetime.strptime(value.strip(), '%Y-%m-%d').date()
    raise ValueError(f"unsupported trade_date value: {value!r}")


# Converter applied to each non-empty trade cell, keyed by column name.
_INDEX_TRADE_CONVERTERS = {
    'open': float,
    'close': float,
    'high': float,
    'low': float,
    'change_pct': float,
    'volume': int,
    'amount': float,
    'total_market_value': float,
    'float_market_value': float,
    'up_count': int,
    'down_count': int,
    'flat_count': int,
    'limit_up_count': int,
    'limit_down_count': int,
    'suspend_count': int,
    'pe_ratio': float,
    'pe_median': float,
    'is_new_high': _index_trade_bool,
    'is_new_low': _index_trade_bool,
}
# Flag columns default to False (not None) when a new row has no value.
_INDEX_TRADE_FLAG_FIELDS = ('is_new_high', 'is_new_low')


@router.post("/index-trade", response_model=ResponseModel)
async def import_index_trade(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import index trade data.

    The workbook must contain ``index_code``, ``trade_date``, ``open``,
    ``close``, ``high`` and ``low``; the remaining trade columns are optional.
    For an existing (index_code, trade_date) record only non-empty cells
    overwrite stored values; otherwise a new record is inserted. Each row is
    applied inside its own SAVEPOINT so a failing row is rolled back and
    counted without affecting the others.

    Args:
        file: Uploaded ``.xls`` / ``.xlsx`` workbook.
        db: Database session.
        current_user: Resolved by ``get_current_user``; not used in the body.

    Returns:
        ResponseModel whose ``data`` holds the success, error and total row
        counts.

    Raises:
        HTTPException: 400 for a wrong file extension or missing required
            columns; 500 for any other failure while reading the file or
            committing.
    """
    # A missing filename is rejected like a wrong extension; the suffix check
    # ignores case.
    if not (file.filename or '').lower().endswith(('.xls', '.xlsx')):
        raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
    try:
        df = pd.read_excel(file.file)
        required_columns = ['index_code', 'trade_date', 'open', 'close', 'high', 'low']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}")
        success_count = 0
        error_count = 0
        for _, row in df.iterrows():
            try:
                raw_code = row['index_code']
                # Blank cells arrive as NaN; skip them instead of storing a
                # trade whose index code is the literal string 'nan'.
                if pd.isna(raw_code):
                    continue
                index_code = str(raw_code).strip()
                if not index_code or index_code.lower() == 'nan':
                    continue
                # A row without a usable date fails here and is counted as an
                # error instead of being stored with a null trade date.
                trade_date = _index_trade_date(row['trade_date'])
                # Convert every non-empty cell up front so conversion errors
                # surface before the database is touched.
                converted = {}
                for field, convert in _INDEX_TRADE_CONVERTERS.items():
                    cell = row.get(field)
                    if pd.notna(cell):
                        converted[field] = convert(cell)
                with db.begin_nested():
                    existing = db.query(IndexTrade).filter(
                        IndexTrade.index_code == index_code,
                        IndexTrade.trade_date == trade_date
                    ).first()
                    if existing:
                        # Empty cells leave the stored value untouched.
                        for field, value in converted.items():
                            setattr(existing, field, value)
                        existing.updated_at = datetime.utcnow()
                    else:
                        values = {field: None for field in _INDEX_TRADE_CONVERTERS}
                        values.update((field, False) for field in _INDEX_TRADE_FLAG_FIELDS)
                        values.update(converted)
                        db.add(IndexTrade(
                            index_code=index_code,
                            trade_date=trade_date,
                            **values
                        ))
                success_count += 1
            except Exception as e:
                logger.error(f"导入指数交易{row.get('index_code')}-{row.get('trade_date')}失败: {str(e)}")
                error_count += 1
        db.commit()
        return ResponseModel(data={
            "success_count": success_count,
            "error_count": error_count,
            "total_count": len(df)
        })
    except HTTPException:
        # Let deliberate 4xx responses through instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"导入指数交易数据失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))