""" 数据库迁移脚本:为合约管理优化添加新字段和品种表 """ import sys import os import re from datetime import datetime sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.database import engine, Base, SessionLocal from app.models import ProductInfo, ContractInfo from sqlalchemy import text # 品种分类映射(根据交易所和品种代码) CATEGORY_MAP = { # 能源化工 "rb": "能源化工", "hc": "能源化工", "fu": "能源化工", "bu": "能源化工", "ru": "能源化工", "nr": "能源化工", "sp": "能源化工", "pg": "能源化工", "pp": "能源化工", "l": "能源化工", "v": "能源化工", "eg": "能源化工", "sc": "能源化工", "lu": "能源化工", "low": "能源化工", # 金属 "cu": "金属", "al": "金属", "zn": "金属", "pb": "金属", "ni": "金属", "sn": "金属", "au": "金属", "ag": "金属", "ss": "能源化工", # 农产品 "a": "农产品", "b": "农产品", "c": "农产品", "cs": "农产品", "m": "农产品", "y": "农产品", "p": "农产品", "rm": "农产品", "oi": "农产品", "cf": "农产品", "sr": "农产品", "ta": "能源化工", "ma": "能源化工", "fg": "能源化工", "sa": "农产品", "ur": "农产品", "sf": "金属", "sm": "金属", "ap": "农产品", "cj": "农产品", "rr": "农产品", "jd": "农产品", "lh": "农产品", # 金融 "if": "金融", "ih": "金融", "ic": "金融", "im": "金融", "t": "金融", "tf": "金融", "ts": "金融", "tl": "金融", # 其他 "px": "能源化工", "pr": "能源化工", "br": "能源化工", "lc": "金属", "si": "金属", "ec": "金融", } # 品种中文名映射 PRODUCT_NAME_MAP = { # 能源化工 "rb": "螺纹钢", "hc": "热卷", "fu": "燃料油", "bu": "沥青", "ru": "橡胶", "nr": "20号胶", "sp": "纸浆", "pg": "液化气", "pp": "聚丙烯", "l": "塑料", "v": "PVC", "eg": "乙二醇", "sc": "原油", "lu": "低硫燃料油", "low": "低硫燃料油", "ta": "PTA", "ma": "甲醇", "fg": "玻璃", "ur": "尿素", "sa": "纯碱", "px": "对二甲苯", "pr": "丙二醇", "br": "合成橡胶", # 金属 "cu": "铜", "al": "铝", "zn": "锌", "pb": "铅", "ni": "镍", "sn": "锡", "au": "黄金", "ag": "白银", "ss": "不锈钢", "sf": "硅铁", "sm": "锰硅", "lc": "碳酸锂", "si": "工业硅", # 农产品 "a": "豆一", "b": "豆二", "c": "玉米", "cs": "淀粉", "m": "豆粕", "y": "豆油", "p": "棕榈油", "rm": "菜粕", "oi": "菜油", "cf": "棉花", "sr": "白糖", "ap": "苹果", "cj": "红枣", "rr": "粳米", "jd": "鸡蛋", "lh": "生猪", # 金融 "if": "沪深300", "ih": "上证50", "ic": "中证500", "im": "中证1000", "t": "10年国债", "tf": "5年国债", "ts": "2年国债", "tl": "30年国债", "ec": "工业硅", # 能源 "sc": "原油", "lu": "低硫燃油", } # 交易所映射 EXCHANGE_MAP = { "SHFE": "上海期货交易所", "DCE": "大连商品交易所", "CZCE": "郑州商品交易所", "CFFEX": "中国金融期货交易所", "INE": "上海国际能源交易中心", "GFEX": "广州期货交易所", } def extract_year_month(symbol: str) -> tuple: """从合约代码解析交割年月 示例: rb2401 -> (2024-01, 1) cu2312 -> (2023-12, 12) IF2403 -> (2024-03, 3) """ # 匹配末尾的数字(年份+月份) match = re.search(r'(\d{2})(\d{2})$', symbol) if not match: return None, None year_suffix, month = match.groups() year_suffix = int(year_suffix) month = int(month) # 年份处理:假设是 20xx 年 year = 2000 + year_suffix if year_suffix < 100 else year_suffix return f"{year}-{month:02d}", month def extract_product_code(symbol: str) -> str: """从合约代码提取品种代码 示例: rb2401 -> rb cu2312 -> cu IF2403 -> IF """ # 去掉末尾的数字 return re.sub(r'\d+$', '', symbol) def migrate(): """执行迁移""" print("🚀 开始数据库迁移...") db = SessionLocal() try: # 1. 创建新表(如果不存在) print("📦 创建 product_info 表...") ProductInfo.__table__.create(engine, checkfirst=True) # 2. 添加新字段到 contract_info(SQLite 不支持 ALTER TABLE ADD COLUMN 的某些操作,需要检查) print("🔧 检查 contract_info 表结构...") # 检查字段是否存在 inspector_result = db.execute(text("PRAGMA table_info(contract_info)")).fetchall() existing_columns = [row[1] for row in inspector_result] new_columns = { "year_month": "ALTER TABLE contract_info ADD COLUMN year_month VARCHAR(7)", "delivery_month": "ALTER TABLE contract_info ADD COLUMN delivery_month INTEGER", "is_main": "ALTER TABLE contract_info ADD COLUMN is_main BOOLEAN DEFAULT 0", "listing_date": "ALTER TABLE contract_info ADD COLUMN listing_date DATETIME", "volume": "ALTER TABLE contract_info ADD COLUMN volume BIGINT DEFAULT 0", "open_interest": "ALTER TABLE contract_info ADD COLUMN open_interest BIGINT DEFAULT 0", } for col_name, sql in new_columns.items(): if col_name not in existing_columns: print(f" 添加字段: {col_name}") db.execute(text(sql)) else: print(f" ✅ 字段已存在: {col_name}") db.commit() # 3. 数据迁移:填充衍生字段 print("📝 迁移现有合约数据...") contracts = db.query(ContractInfo).all() updated_count = 0 for contract in contracts: # 解析 year_month 和 delivery_month if not contract.year_month and contract.symbol: year_month, delivery_month = extract_year_month(contract.symbol) if year_month: contract.year_month = year_month contract.delivery_month = delivery_month # 提取 product_code if not contract.product and contract.symbol: contract.product = extract_product_code(contract.symbol) updated_count += 1 if updated_count > 0: db.commit() print(f" ✅ 更新 {updated_count} 条合约记录") # 4. 生成品种元数据 print(" 生成品种元数据...") products = db.query( ContractInfo.product, ContractInfo.exchange, ContractInfo.multiplier, ContractInfo.price_tick ).distinct().all() product_count = 0 for prod_code, exchange, multiplier, price_tick in products: if not prod_code: continue # 检查是否已存在 existing = db.query(ProductInfo).filter( ProductInfo.product_code == prod_code ).first() if existing: continue # 查找中文名 # 优先使用映射表 product_name = PRODUCT_NAME_MAP.get(prod_code.lower(), prod_code) category = CATEGORY_MAP.get(prod_code.lower(), "其他") product = ProductInfo( product_code=prod_code, product_name=product_name, exchange=exchange, multiplier=multiplier or 10, price_tick=price_tick, category=category, is_active=True, ) db.add(product) db.commit() # 逐条提交避免批量冲突 product_count += 1 if product_count > 0: print(f" ✅ 创建 {product_count} 个品种元数据") # 5. 创建索引 print("📑 创建索引...") try: db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_year_month ON contract_info(year_month)")) db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_is_main ON contract_info(is_main)")) db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_category ON product_info(category)")) db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_exchange ON product_info(exchange)")) db.commit() except Exception as e: print(f" ⚠️ 索引创建警告: {e}") print("\n✅ 迁移完成!") except Exception as e: db.rollback() print(f"\n❌ 迁移失败: {e}") import traceback traceback.print_exc() finally: db.close() if __name__ == "__main__": migrate()