You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

238 lines
8.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
数据库迁移脚本:为合约管理优化添加新字段和品种表
"""
import sys
import os
import re
from datetime import datetime
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.database import engine, Base, SessionLocal
from app.models import ProductInfo, ContractInfo
from sqlalchemy import text
# 品种分类映射(根据交易所和品种代码)
CATEGORY_MAP = {
# 能源化工
"rb": "能源化工", "hc": "能源化工", "fu": "能源化工", "bu": "能源化工",
"ru": "能源化工", "nr": "能源化工", "sp": "能源化工", "pg": "能源化工",
"pp": "能源化工", "l": "能源化工", "v": "能源化工", "eg": "能源化工",
"sc": "能源化工", "lu": "能源化工", "low": "能源化工",
# 金属
"cu": "金属", "al": "金属", "zn": "金属", "pb": "金属", "ni": "金属",
"sn": "金属", "au": "金属", "ag": "金属", "ss": "能源化工",
# 农产品
"a": "农产品", "b": "农产品", "c": "农产品", "cs": "农产品",
"m": "农产品", "y": "农产品", "p": "农产品", "rm": "农产品",
"oi": "农产品", "cf": "农产品", "sr": "农产品", "ta": "能源化工",
"ma": "能源化工", "fg": "能源化工", "sa": "农产品", "ur": "农产品",
"sf": "金属", "sm": "金属", "ap": "农产品", "cj": "农产品",
"rr": "农产品", "jd": "农产品", "lh": "农产品",
# 金融
"if": "金融", "ih": "金融", "ic": "金融", "im": "金融",
"t": "金融", "tf": "金融", "ts": "金融", "tl": "金融",
# 其他
"px": "能源化工", "pr": "能源化工", "br": "能源化工",
"lc": "金属", "si": "金属", "ec": "金融",
}
# 品种中文名映射
PRODUCT_NAME_MAP = {
# 能源化工
"rb": "螺纹钢", "hc": "热卷", "fu": "燃料油", "bu": "沥青",
"ru": "橡胶", "nr": "20号胶", "sp": "纸浆", "pg": "液化气",
"pp": "聚丙烯", "l": "塑料", "v": "PVC", "eg": "乙二醇",
"sc": "原油", "lu": "低硫燃料油", "low": "低硫燃料油",
"ta": "PTA", "ma": "甲醇", "fg": "玻璃", "ur": "尿素",
"sa": "纯碱", "px": "对二甲苯", "pr": "丙二醇", "br": "合成橡胶",
# 金属
"cu": "", "al": "", "zn": "", "pb": "", "ni": "",
"sn": "", "au": "黄金", "ag": "白银", "ss": "不锈钢",
"sf": "硅铁", "sm": "锰硅", "lc": "碳酸锂", "si": "工业硅",
# 农产品
"a": "豆一", "b": "豆二", "c": "玉米", "cs": "淀粉",
"m": "豆粕", "y": "豆油", "p": "棕榈油", "rm": "菜粕",
"oi": "菜油", "cf": "棉花", "sr": "白糖", "ap": "苹果",
"cj": "红枣", "rr": "粳米", "jd": "鸡蛋", "lh": "生猪",
# 金融
"if": "沪深300", "ih": "上证50", "ic": "中证500", "im": "中证1000",
"t": "10年国债", "tf": "5年国债", "ts": "2年国债", "tl": "30年国债", "ec": "工业硅",
# 能源
"sc": "原油", "lu": "低硫燃油",
}
# 交易所映射
EXCHANGE_MAP = {
"SHFE": "上海期货交易所",
"DCE": "大连商品交易所",
"CZCE": "郑州商品交易所",
"CFFEX": "中国金融期货交易所",
"INE": "上海国际能源交易中心",
"GFEX": "广州期货交易所",
}
def extract_year_month(symbol: str) -> tuple:
"""从合约代码解析交割年月
示例:
rb2401 -> (2024-01, 1)
cu2312 -> (2023-12, 12)
IF2403 -> (2024-03, 3)
"""
# 匹配末尾的数字(年份+月份)
match = re.search(r'(\d{2})(\d{2})$', symbol)
if not match:
return None, None
year_suffix, month = match.groups()
year_suffix = int(year_suffix)
month = int(month)
# 年份处理:假设是 20xx 年
year = 2000 + year_suffix if year_suffix < 100 else year_suffix
return f"{year}-{month:02d}", month
def extract_product_code(symbol: str) -> str:
"""从合约代码提取品种代码
示例:
rb2401 -> rb
cu2312 -> cu
IF2403 -> IF
"""
# 去掉末尾的数字
return re.sub(r'\d+$', '', symbol)
def migrate():
"""执行迁移"""
print("🚀 开始数据库迁移...")
db = SessionLocal()
try:
# 1. 创建新表(如果不存在)
print("📦 创建 product_info 表...")
ProductInfo.__table__.create(engine, checkfirst=True)
# 2. 添加新字段到 contract_infoSQLite 不支持 ALTER TABLE ADD COLUMN 的某些操作,需要检查)
print("🔧 检查 contract_info 表结构...")
# 检查字段是否存在
inspector_result = db.execute(text("PRAGMA table_info(contract_info)")).fetchall()
existing_columns = [row[1] for row in inspector_result]
new_columns = {
"year_month": "ALTER TABLE contract_info ADD COLUMN year_month VARCHAR(7)",
"delivery_month": "ALTER TABLE contract_info ADD COLUMN delivery_month INTEGER",
"is_main": "ALTER TABLE contract_info ADD COLUMN is_main BOOLEAN DEFAULT 0",
"listing_date": "ALTER TABLE contract_info ADD COLUMN listing_date DATETIME",
"volume": "ALTER TABLE contract_info ADD COLUMN volume BIGINT DEFAULT 0",
"open_interest": "ALTER TABLE contract_info ADD COLUMN open_interest BIGINT DEFAULT 0",
}
for col_name, sql in new_columns.items():
if col_name not in existing_columns:
print(f" 添加字段: {col_name}")
db.execute(text(sql))
else:
print(f" ✅ 字段已存在: {col_name}")
db.commit()
# 3. 数据迁移:填充衍生字段
print("📝 迁移现有合约数据...")
contracts = db.query(ContractInfo).all()
updated_count = 0
for contract in contracts:
# 解析 year_month 和 delivery_month
if not contract.year_month and contract.symbol:
year_month, delivery_month = extract_year_month(contract.symbol)
if year_month:
contract.year_month = year_month
contract.delivery_month = delivery_month
# 提取 product_code
if not contract.product and contract.symbol:
contract.product = extract_product_code(contract.symbol)
updated_count += 1
if updated_count > 0:
db.commit()
print(f" ✅ 更新 {updated_count} 条合约记录")
# 4. 生成品种元数据
print(" 生成品种元数据...")
products = db.query(
ContractInfo.product,
ContractInfo.exchange,
ContractInfo.multiplier,
ContractInfo.price_tick
).distinct().all()
product_count = 0
for prod_code, exchange, multiplier, price_tick in products:
if not prod_code:
continue
# 检查是否已存在
existing = db.query(ProductInfo).filter(
ProductInfo.product_code == prod_code
).first()
if existing:
continue
# 查找中文名
# 优先使用映射表
product_name = PRODUCT_NAME_MAP.get(prod_code.lower(), prod_code)
category = CATEGORY_MAP.get(prod_code.lower(), "其他")
product = ProductInfo(
product_code=prod_code,
product_name=product_name,
exchange=exchange,
multiplier=multiplier or 10,
price_tick=price_tick,
category=category,
is_active=True,
)
db.add(product)
db.commit() # 逐条提交避免批量冲突
product_count += 1
if product_count > 0:
print(f" ✅ 创建 {product_count} 个品种元数据")
# 5. 创建索引
print("📑 创建索引...")
try:
db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_year_month ON contract_info(year_month)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_is_main ON contract_info(is_main)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_category ON product_info(category)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_exchange ON product_info(exchange)"))
db.commit()
except Exception as e:
print(f" ⚠️ 索引创建警告: {e}")
print("\n✅ 迁移完成!")
except Exception as e:
db.rollback()
print(f"\n❌ 迁移失败: {e}")
import traceback
traceback.print_exc()
finally:
db.close()
if __name__ == "__main__":
migrate()