You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

238 lines
8.8 KiB

"""
数据库迁移脚本为合约管理优化添加新字段和品种表
"""
import sys
import os
import re
from datetime import datetime
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.database import engine, Base, SessionLocal
from app.models import ProductInfo, ContractInfo
from sqlalchemy import text
# 品种分类映射(根据交易所和品种代码)
CATEGORY_MAP = {
# 能源化工
"rb": "能源化工", "hc": "能源化工", "fu": "能源化工", "bu": "能源化工",
"ru": "能源化工", "nr": "能源化工", "sp": "能源化工", "pg": "能源化工",
"pp": "能源化工", "l": "能源化工", "v": "能源化工", "eg": "能源化工",
"sc": "能源化工", "lu": "能源化工", "low": "能源化工",
# 金属
"cu": "金属", "al": "金属", "zn": "金属", "pb": "金属", "ni": "金属",
"sn": "金属", "au": "金属", "ag": "金属", "ss": "能源化工",
# 农产品
"a": "农产品", "b": "农产品", "c": "农产品", "cs": "农产品",
"m": "农产品", "y": "农产品", "p": "农产品", "rm": "农产品",
"oi": "农产品", "cf": "农产品", "sr": "农产品", "ta": "能源化工",
"ma": "能源化工", "fg": "能源化工", "sa": "农产品", "ur": "农产品",
"sf": "金属", "sm": "金属", "ap": "农产品", "cj": "农产品",
"rr": "农产品", "jd": "农产品", "lh": "农产品",
# 金融
"if": "金融", "ih": "金融", "ic": "金融", "im": "金融",
"t": "金融", "tf": "金融", "ts": "金融", "tl": "金融",
# 其他
"px": "能源化工", "pr": "能源化工", "br": "能源化工",
"lc": "金属", "si": "金属", "ec": "金融",
}
# 品种中文名映射
PRODUCT_NAME_MAP = {
# 能源化工
"rb": "螺纹钢", "hc": "热卷", "fu": "燃料油", "bu": "沥青",
"ru": "橡胶", "nr": "20号胶", "sp": "纸浆", "pg": "液化气",
"pp": "聚丙烯", "l": "塑料", "v": "PVC", "eg": "乙二醇",
"sc": "原油", "lu": "低硫燃料油", "low": "低硫燃料油",
"ta": "PTA", "ma": "甲醇", "fg": "玻璃", "ur": "尿素",
"sa": "纯碱", "px": "对二甲苯", "pr": "丙二醇", "br": "合成橡胶",
# 金属
"cu": "", "al": "", "zn": "", "pb": "", "ni": "",
"sn": "", "au": "黄金", "ag": "白银", "ss": "不锈钢",
"sf": "硅铁", "sm": "锰硅", "lc": "碳酸锂", "si": "工业硅",
# 农产品
"a": "豆一", "b": "豆二", "c": "玉米", "cs": "淀粉",
"m": "豆粕", "y": "豆油", "p": "棕榈油", "rm": "菜粕",
"oi": "菜油", "cf": "棉花", "sr": "白糖", "ap": "苹果",
"cj": "红枣", "rr": "粳米", "jd": "鸡蛋", "lh": "生猪",
# 金融
"if": "沪深300", "ih": "上证50", "ic": "中证500", "im": "中证1000",
"t": "10年国债", "tf": "5年国债", "ts": "2年国债", "tl": "30年国债", "ec": "工业硅",
# 能源
"sc": "原油", "lu": "低硫燃油",
}
# 交易所映射
EXCHANGE_MAP = {
"SHFE": "上海期货交易所",
"DCE": "大连商品交易所",
"CZCE": "郑州商品交易所",
"CFFEX": "中国金融期货交易所",
"INE": "上海国际能源交易中心",
"GFEX": "广州期货交易所",
}
def extract_year_month(symbol: str) -> tuple:
"""从合约代码解析交割年月
示例:
rb2401 -> (2024-01, 1)
cu2312 -> (2023-12, 12)
IF2403 -> (2024-03, 3)
"""
# 匹配末尾的数字(年份+月份)
match = re.search(r'(\d{2})(\d{2})$', symbol)
if not match:
return None, None
year_suffix, month = match.groups()
year_suffix = int(year_suffix)
month = int(month)
# 年份处理:假设是 20xx 年
year = 2000 + year_suffix if year_suffix < 100 else year_suffix
return f"{year}-{month:02d}", month
def extract_product_code(symbol: str) -> str:
"""从合约代码提取品种代码
示例:
rb2401 -> rb
cu2312 -> cu
IF2403 -> IF
"""
# 去掉末尾的数字
return re.sub(r'\d+$', '', symbol)
def migrate():
"""执行迁移"""
print("🚀 开始数据库迁移...")
db = SessionLocal()
try:
# 1. 创建新表(如果不存在)
print("📦 创建 product_info 表...")
ProductInfo.__table__.create(engine, checkfirst=True)
# 2. 添加新字段到 contract_infoSQLite 不支持 ALTER TABLE ADD COLUMN 的某些操作,需要检查)
print("🔧 检查 contract_info 表结构...")
# 检查字段是否存在
inspector_result = db.execute(text("PRAGMA table_info(contract_info)")).fetchall()
existing_columns = [row[1] for row in inspector_result]
new_columns = {
"year_month": "ALTER TABLE contract_info ADD COLUMN year_month VARCHAR(7)",
"delivery_month": "ALTER TABLE contract_info ADD COLUMN delivery_month INTEGER",
"is_main": "ALTER TABLE contract_info ADD COLUMN is_main BOOLEAN DEFAULT 0",
"listing_date": "ALTER TABLE contract_info ADD COLUMN listing_date DATETIME",
"volume": "ALTER TABLE contract_info ADD COLUMN volume BIGINT DEFAULT 0",
"open_interest": "ALTER TABLE contract_info ADD COLUMN open_interest BIGINT DEFAULT 0",
}
for col_name, sql in new_columns.items():
if col_name not in existing_columns:
print(f" 添加字段: {col_name}")
db.execute(text(sql))
else:
print(f" ✅ 字段已存在: {col_name}")
db.commit()
# 3. 数据迁移:填充衍生字段
print("📝 迁移现有合约数据...")
contracts = db.query(ContractInfo).all()
updated_count = 0
for contract in contracts:
# 解析 year_month 和 delivery_month
if not contract.year_month and contract.symbol:
year_month, delivery_month = extract_year_month(contract.symbol)
if year_month:
contract.year_month = year_month
contract.delivery_month = delivery_month
# 提取 product_code
if not contract.product and contract.symbol:
contract.product = extract_product_code(contract.symbol)
updated_count += 1
if updated_count > 0:
db.commit()
print(f" ✅ 更新 {updated_count} 条合约记录")
# 4. 生成品种元数据
print(" 生成品种元数据...")
products = db.query(
ContractInfo.product,
ContractInfo.exchange,
ContractInfo.multiplier,
ContractInfo.price_tick
).distinct().all()
product_count = 0
for prod_code, exchange, multiplier, price_tick in products:
if not prod_code:
continue
# 检查是否已存在
existing = db.query(ProductInfo).filter(
ProductInfo.product_code == prod_code
).first()
if existing:
continue
# 查找中文名
# 优先使用映射表
product_name = PRODUCT_NAME_MAP.get(prod_code.lower(), prod_code)
category = CATEGORY_MAP.get(prod_code.lower(), "其他")
product = ProductInfo(
product_code=prod_code,
product_name=product_name,
exchange=exchange,
multiplier=multiplier or 10,
price_tick=price_tick,
category=category,
is_active=True,
)
db.add(product)
db.commit() # 逐条提交避免批量冲突
product_count += 1
if product_count > 0:
print(f" ✅ 创建 {product_count} 个品种元数据")
# 5. 创建索引
print("📑 创建索引...")
try:
db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_year_month ON contract_info(year_month)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_contract_is_main ON contract_info(is_main)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_category ON product_info(category)"))
db.execute(text("CREATE INDEX IF NOT EXISTS idx_product_exchange ON product_info(exchange)"))
db.commit()
except Exception as e:
print(f" ⚠️ 索引创建警告: {e}")
print("\n✅ 迁移完成!")
except Exception as e:
db.rollback()
print(f"\n❌ 迁移失败: {e}")
import traceback
traceback.print_exc()
finally:
db.close()
if __name__ == "__main__":
migrate()