fix: 修复导入股票基本信息时,新增和更新记录数的统计问题

master
Lxy 1 month ago
parent 6521e10c3f
commit cbaefd4230

@ -41,6 +41,16 @@ INDEX_TRADE_COLUMN_MAP = {
'市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median'
}
STOCK_BASIC_COLUMN_MAP = {
'证券代码': 'code',
'证券名称': 'name',
'首发上市日': 'list_date',
'所属东财行业指数名称\n[行业类别]2级': 'industry_index_name',
'所属东财行业指数代码\n[行业类别]2级': 'industry_index_code',
'机构持股比例合计\n[报告期]最新一期\n[单位]%\n[比例类型]占总股本比例': 'institution_hold_ratio',
'所属东财行业名称\n[行业类别]3级': 'industry_level3'
}
@router.post("/index-data", response_model=ResponseModel)
async def import_index_data(
@ -231,70 +241,152 @@ async def import_stock_basic(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""导入股票基础数据"""
"""导入股票基础数据(支持模板格式)"""
if not file.filename.endswith(('.xls', '.xlsx')):
raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
try:
df = pd.read_excel(file.file)
required_columns = ['code', 'name', 'total_shares', 'float_shares',
'industry_index_name', 'industry_index_code',
'institution_hold_ratio', 'industry_level3', 'list_date']
df.columns = df.columns.str.strip()
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}")
renamed_df = df.rename(columns=STOCK_BASIC_COLUMN_MAP)
if 'code' not in renamed_df.columns:
raise HTTPException(status_code=400, detail="缺少必要列:证券代码")
if 'name' not in renamed_df.columns:
raise HTTPException(status_code=400, detail="缺少必要列:证券名称")
success_count = 0
error_count = 0
added_count = 0
updated_count = 0
skipped_count = 0
skipped_details = []
error_details = []
for _, row in df.iterrows():
for _, row in renamed_df.iterrows():
try:
existing = db.query(StockBasic).filter(StockBasic.code == str(row['code'])).first()
code_val = row.get('code')
if pd.isna(code_val):
continue
code = str(code_val).strip()
if not code or code.lower() == 'nan':
continue
existing = db.query(StockBasic).filter(StockBasic.code == code).first()
list_date = None
if pd.notna(row['list_date']):
if isinstance(row['list_date'], datetime):
list_date = row['list_date'].date()
elif isinstance(row['list_date'], str):
list_date = datetime.strptime(row['list_date'], '%Y-%m-%d').date()
list_date_val = row.get('list_date')
if pd.notna(list_date_val):
if isinstance(list_date_val, datetime):
list_date = list_date_val.date()
elif isinstance(list_date_val, str):
try:
list_date = datetime.strptime(list_date_val, '%Y-%m-%d').date()
except:
pass
elif hasattr(list_date_val, 'date'):
list_date = list_date_val.date()
name = str(row.get('name', '')) if pd.notna(row.get('name')) else None
industry_index_name = str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None
industry_index_code = str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None
institution_hold_ratio = None
ratio_val = row.get('institution_hold_ratio')
if pd.notna(ratio_val):
try:
if str(ratio_val).strip() == '--':
institution_hold_ratio = 0.0
else:
institution_hold_ratio = float(ratio_val)
except:
institution_hold_ratio = 0.0
industry_level3 = str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None
if industry_index_code:
index_basic = db.query(IndexBasic).filter(IndexBasic.code == industry_index_code).first()
if not index_basic:
index_basic = IndexBasic(
code=industry_index_code,
name=industry_index_name or industry_index_code
)
db.add(index_basic)
db.flush()
if existing:
existing.name = str(row.get('name', existing.name))
existing.total_shares = int(row.get('total_shares', existing.total_shares)) if pd.notna(row.get('total_shares')) else existing.total_shares
existing.float_shares = int(row.get('float_shares', existing.float_shares)) if pd.notna(row.get('float_shares')) else existing.float_shares
existing.industry_index_name = str(row.get('industry_index_name', existing.industry_index_name)) if pd.notna(row.get('industry_index_name')) else existing.industry_index_name
existing.industry_index_code = str(row.get('industry_index_code', existing.industry_index_code)) if pd.notna(row.get('industry_index_code')) else existing.industry_index_code
existing.institution_hold_ratio = float(row.get('institution_hold_ratio', existing.institution_hold_ratio)) if pd.notna(row.get('institution_hold_ratio')) else existing.institution_hold_ratio
existing.industry_level3 = str(row.get('industry_level3', existing.industry_level3)) if pd.notna(row.get('industry_level3')) else existing.industry_level3
def is_same_data():
def compare_ratio():
if existing.institution_hold_ratio is None and institution_hold_ratio is None:
return True
if existing.institution_hold_ratio is None or institution_hold_ratio is None:
return False
return abs(float(existing.institution_hold_ratio) - institution_hold_ratio) < 0.0001
return (
(existing.name == name or (existing.name is None and name is None)) and
(existing.industry_index_name == industry_index_name or (existing.industry_index_name is None and industry_index_name is None)) and
(existing.industry_index_code == industry_index_code or (existing.industry_index_code is None and industry_index_code is None)) and
compare_ratio() and
(existing.industry_level3 == industry_level3 or (existing.industry_level3 is None and industry_level3 is None)) and
(existing.list_date == list_date or (existing.list_date is None and list_date is None))
)
if is_same_data():
skipped_count += 1
skipped_details.append({
"code": code,
"name": name,
"reason": "数据相同,无需更新"
})
else:
existing.name = name
existing.industry_index_name = industry_index_name
existing.industry_index_code = industry_index_code
existing.institution_hold_ratio = institution_hold_ratio
existing.industry_level3 = industry_level3
existing.list_date = list_date
existing.updated_at = datetime.utcnow()
updated_count += 1
else:
stock = StockBasic(
code=str(row['code']),
name=str(row.get('name', '')),
total_shares=int(row['total_shares']) if pd.notna(row['total_shares']) else None,
float_shares=int(row['float_shares']) if pd.notna(row['float_shares']) else None,
industry_index_name=str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None,
industry_index_code=str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None,
institution_hold_ratio=float(row['institution_hold_ratio']) if pd.notna(row['institution_hold_ratio']) else None,
industry_level3=str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None,
code=code,
name=name,
total_shares=None,
float_shares=None,
industry_index_name=industry_index_name,
industry_index_code=industry_index_code,
institution_hold_ratio=institution_hold_ratio,
industry_level3=industry_level3,
list_date=list_date
)
db.add(stock)
added_count += 1
success_count += 1
except Exception as e:
logger.error(f"导入股票{row.get('code')}失败: {str(e)}")
error_count += 1
error_details.append({
"code": str(row.get('code', '')),
"name": str(row.get('name', '')) if pd.notna(row.get('name')) else '',
"reason": str(e)
})
db.commit()
return ResponseModel(data={
"success_count": success_count,
"error_count": error_count,
"total_count": len(df)
"total_count": len(df),
"added_count": added_count,
"updated_count": updated_count,
"skipped_count": skipped_count,
"skipped_details": skipped_details[:100],
"error_details": error_details[:100]
})
except Exception as e:

@ -7,6 +7,7 @@ from app.models.realtime import RealtimeSnapshot
from app.models.finance import FinanceBalanceSheet, FinanceCashFlow, FinanceIncome
from app.models.cache import CacheTask, CacheTaskDetail
from app.models.test import APITestLog
from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
__all__ = [
"User",
@ -25,4 +26,7 @@ __all__ = [
"CacheTask",
"CacheTaskDetail",
"APITestLog",
"StockBasic",
"IndexBasic",
"IndexTrade",
]

@ -1,96 +1,18 @@
"""
创建股票基础数据相关表
创建股票基础数据相关表使用 SQLAlchemy ORM兼容 SQLite PostgreSQL
"""
from sqlalchemy import text
from app.db.session import SessionLocal
from app.db.session import engine, SessionLocal
from app.db.base import Base
from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
db = SessionLocal()
def create_tables():
"""创建所有股票基础数据相关表"""
print("开始创建数据库表...")
try:
# 创建股票基础数据表
db.execute(text("""
CREATE TABLE IF NOT EXISTS stock_basic (
id BIGSERIAL PRIMARY KEY,
code VARCHAR(20) UNIQUE NOT NULL,
name VARCHAR(50),
total_shares BIGINT,
float_shares BIGINT,
industry_index_name VARCHAR(100),
industry_index_code VARCHAR(20),
institution_hold_ratio DECIMAL(10, 4),
industry_level3 VARCHAR(100),
list_date DATE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""))
Base.metadata.create_all(bind=engine)
# 创建指数基础表
db.execute(text("""
CREATE TABLE IF NOT EXISTS index_basic (
id BIGSERIAL PRIMARY KEY,
code VARCHAR(20) UNIQUE NOT NULL,
name VARCHAR(100),
component_count INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""))
print("数据库表创建完成")
print(f"创建的表: stock_basic, index_basic, index_trade")
# 创建指数交易表
db.execute(text("""
CREATE TABLE IF NOT EXISTS index_trade (
id BIGSERIAL PRIMARY KEY,
index_code VARCHAR(20) NOT NULL,
trade_date DATE NOT NULL,
open DECIMAL(10, 3),
close DECIMAL(10, 3),
high DECIMAL(10, 3),
low DECIMAL(10, 3),
change_pct DECIMAL(10, 4),
volume BIGINT,
amount DECIMAL(18, 2),
total_market_value DECIMAL(18, 2),
float_market_value DECIMAL(18, 2),
up_count INTEGER,
down_count INTEGER,
flat_count INTEGER,
limit_up_count INTEGER,
limit_down_count INTEGER,
suspend_count INTEGER,
pe_ratio DECIMAL(10, 4),
pe_median DECIMAL(10, 4),
is_new_high BOOLEAN DEFAULT FALSE,
is_new_low BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(index_code, trade_date)
)
"""))
# 创建索引
db.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_basic_code ON stock_basic(code)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_basic_code ON index_basic(code)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_code ON index_trade(index_code)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_date ON index_trade(trade_date)"))
# 添加外键约束
db.execute(text("""
ALTER TABLE stock_basic
ADD CONSTRAINT fk_stock_basic_index_code
FOREIGN KEY (industry_index_code) REFERENCES index_basic(code)
"""))
db.execute(text("""
ALTER TABLE index_trade
ADD CONSTRAINT fk_index_trade_index_code
FOREIGN KEY (index_code) REFERENCES index_basic(code)
"""))
db.commit()
print("表创建成功")
except Exception as e:
print(f"创建表失败: {str(e)}")
db.rollback()
finally:
db.close()
if __name__ == "__main__":
create_tables()

@ -58,9 +58,9 @@
<el-form label-width="120px">
<el-form-item label="文件格式说明">
<el-text>
必须包含以下列code, name, total_shares, float_shares,
industry_index_name, industry_index_code, institution_hold_ratio,
industry_level3, list_date
支持模板格式stock_info_template.xlsx自动适配以下列<br/>
证券代码证券名称首发上市日所属东财行业指数名称[行业类别]2<br/>
所属东财行业指数代码[行业类别]2机构持股比例合计所属东财行业名称[行业类别]3
</el-text>
</el-form-item>
<el-form-item label="选择文件">
@ -83,7 +83,18 @@
</el-button>
</el-form-item>
</el-form>
<el-alert v-if="stockBasicResult" :title="stockBasicResult.title" :type="stockBasicResult.type" show-icon />
<div v-if="stockBasicResult" class="result-area">
<el-alert :title="stockBasicResult.title" :type="stockBasicResult.type" show-icon />
<el-button
v-if="stockBasicResult.skipped_count > 0 || stockBasicResult.error_count > 0"
type="primary"
link
@click="showStockBasicDetail"
style="margin-left: 10px;"
>
查看详情
</el-button>
</div>
</el-tab-pane>
<el-tab-pane label="指数基础数据" name="indexBasic">
@ -150,6 +161,27 @@
</el-tab-pane>
</el-tabs>
</el-card>
<el-dialog v-model="detailDialogVisible" title="导入详情" width="800px">
<el-tabs v-model="detailTab">
<el-tab-pane label="跳过数据" name="skipped">
<el-table :data="stockBasicDetailData.skipped_details" border stripe max-height="400">
<el-table-column prop="code" label="股票代码" width="120" />
<el-table-column prop="name" label="股票名称" width="150" />
<el-table-column prop="reason" label="跳过原因" />
</el-table>
<div class="detail-count"> {{ stockBasicDetailData.skipped_details?.length || 0 }} 条跳过数据</div>
</el-tab-pane>
<el-tab-pane label="失败数据" name="error">
<el-table :data="stockBasicDetailData.error_details" border stripe max-height="400">
<el-table-column prop="code" label="股票代码" width="120" />
<el-table-column prop="name" label="股票名称" width="150" />
<el-table-column prop="reason" label="失败原因" />
</el-table>
<div class="detail-count"> {{ stockBasicDetailData.error_details?.length || 0 }} 条失败数据</div>
</el-tab-pane>
</el-tabs>
</el-dialog>
</div>
</template>
@ -175,6 +207,10 @@ const indexBasicResult = ref<any>(null)
const indexTradeResult = ref<any>(null)
const indexDataResult = ref<any>(null)
const stockBasicDetailData = ref<any>({})
const detailDialogVisible = ref(false)
const detailTab = ref('skipped')
const handleStockBasicChange = (file: any) => {
stockBasicFile.value = file.raw
}
@ -238,15 +274,20 @@ const handleImportStockBasic = async () => {
const res: any = await importStockBasic(stockBasicFile.value)
if (res.data) {
stockBasicResult.value = {
title: `导入完成:成功${res.data.success_count}条,失败${res.data.error_count}条,共${res.data.total_count}`,
type: res.data.error_count > 0 ? 'warning' : 'success'
title: `导入完成:新增${res.data.added_count}条,更新${res.data.updated_count}条,跳过${res.data.skipped_count}条,失败${res.data.error_count}`,
type: res.data.error_count > 0 ? 'warning' : 'success',
skipped_count: res.data.skipped_count,
error_count: res.data.error_count
}
stockBasicDetailData.value = res.data
ElMessage.success('导入完成')
}
} catch (error: any) {
stockBasicResult.value = {
title: `导入失败:${error.response?.data?.detail || error.message}`,
type: 'error'
type: 'error',
skipped_count: 0,
error_count: 0
}
ElMessage.error('导入失败')
} finally {
@ -254,6 +295,11 @@ const handleImportStockBasic = async () => {
}
}
const showStockBasicDetail = () => {
detailTab.value = stockBasicDetailData.value.error_count > 0 ? 'error' : 'skipped'
detailDialogVisible.value = true
}
const handleImportIndexBasic = async () => {
if (!indexBasicFile.value) {
ElMessage.warning('请先选择文件')
@ -317,4 +363,14 @@ const handleImportIndexTrade = async () => {
.data-import {
padding: 20px;
}
.result-area {
display: flex;
align-items: center;
margin-top: 10px;
}
.detail-count {
margin-top: 10px;
color: #666;
font-size: 14px;
}
</style>

Binary file not shown.

Binary file not shown.
Loading…
Cancel
Save