""" 品种服务:管理品种元数据、品种树、主力合约计算 """ from typing import List, Optional, Dict from sqlalchemy.orm import Session from sqlalchemy import func, and_ from app.models import ProductInfo, ContractInfo from app.database import SessionLocal class ProductService: """品种信息服务""" def get_products( self, exchange: Optional[str] = None, category: Optional[str] = None, is_active: Optional[bool] = None ) -> List[dict]: """获取品种列表,包含合约数量和主力合约信息""" db = SessionLocal() try: query = db.query( ProductInfo, func.count(ContractInfo.symbol).label('contract_count'), func.max(ContractInfo.symbol).label('main_contract') # 临时使用,后续优化 ).outerjoin( ContractInfo, and_( ProductInfo.product_code == ContractInfo.product, ContractInfo.is_active == True ) ).group_by(ProductInfo.id) if exchange: query = query.filter(ProductInfo.exchange == exchange) if category: query = query.filter(ProductInfo.category == category) if is_active is not None: query = query.filter(ProductInfo.is_active == is_active) results = query.all() return [ { "id": p.id, "product_code": p.product_code, "product_name": p.product_name, "exchange": p.exchange, "multiplier": p.multiplier, "price_tick": p.price_tick, "category": p.category, "is_active": p.is_active, "contract_count": count, "main_contract": main, } for p, count, main in results ] finally: db.close() def get_product_tree(self) -> List[dict]: """获取品种树结构(按分类分组)""" db = SessionLocal() try: # 获取所有品种 products = db.query(ProductInfo).order_by( ProductInfo.category, ProductInfo.product_code ).all() # 按分类分组 tree = {} for p in products: if p.category not in tree: tree[p.category] = [] # 获取该品种的活跃合约 contracts = db.query(ContractInfo).filter( and_( ContractInfo.product == p.product_code, ContractInfo.is_active == True ) ).order_by(ContractInfo.year_month).all() contract_list = [ { "symbol": c.symbol, "year_month": c.year_month, "delivery_month": c.delivery_month, "is_main": c.is_main, "name": c.name, } for c in contracts ] tree[p.category].append({ "product_code": p.product_code, "product_name": p.product_name, "exchange": p.exchange, "contract_count": len(contract_list), "main_contract": next((c["symbol"] for c in contract_list if c["is_main"]), None), "contracts": contract_list, }) # 转换为列表格式 return [ {"category": cat, "products": prods} for cat, prods in tree.items() ] finally: db.close() def get_product_contracts( self, product_code: str, is_active: Optional[bool] = None ) -> List[dict]: """获取指定品种的所有合约""" db = SessionLocal() try: query = db.query(ContractInfo).filter( ContractInfo.product == product_code ) if is_active is not None: query = query.filter(ContractInfo.is_active == is_active) contracts = query.order_by(ContractInfo.year_month).all() return [ { "id": c.id, "symbol": c.symbol, "name": c.name, "exchange": c.exchange, "year_month": c.year_month, "delivery_month": c.delivery_month, "is_main": c.is_main, "is_active": c.is_active, "expire_date": c.expire_date, "multiplier": c.multiplier, "price_tick": c.price_tick, } for c in contracts ] finally: db.close() def set_main_contract(self, symbol: str) -> bool: """设置主力合约""" db = SessionLocal() try: # 获取合约信息 contract = db.query(ContractInfo).filter( ContractInfo.symbol == symbol ).first() if not contract: return False # 取消同品种其他合约的主力标识 db.query(ContractInfo).filter( and_( ContractInfo.product == contract.product, ContractInfo.symbol != symbol ) ).update({"is_main": False}) # 设置当前合约为主力 contract.is_main = True db.commit() return True except Exception: db.rollback() return False finally: db.close() def update_main_contracts(self) -> int: """根据持仓量自动更新主力合约标识 规则: 1. 优先按持仓量最大的活跃合约为主力 2. 当持仓量数据缺失(全为0)时,按交割月取最近的活跃合约 """ db = SessionLocal() try: products = db.query(ProductInfo).all() updated_count = 0 for product in products: # 获取该品种所有活跃合约 active_contracts = db.query(ContractInfo).filter( and_( ContractInfo.product == product.product_code, ContractInfo.is_active == True ) ).all() if not active_contracts: continue # 检查是否有持仓量数据 has_volume_data = any(c.open_interest > 0 for c in active_contracts) if has_volume_data: # 方案1:按持仓量排序 main_contract = max(active_contracts, key=lambda c: c.open_interest) else: # 方案2:按交割月取最近的(当前时间之后最近的合约) from datetime import datetime now = datetime.utcnow() # 过滤出未来交割的合约 future_contracts = [] for c in active_contracts: if c.year_month: try: expiry = datetime.strptime(c.year_month, '%Y-%m') if expiry >= now: future_contracts.append(c) except ValueError: pass if future_contracts: # 取最近的 main_contract = min(future_contracts, key=lambda c: c.year_month) else: # 如果没有未来合约,取最后一个活跃的 main_contract = active_contracts[-1] # 取消同品种其他合约的主力标识 db.query(ContractInfo).filter( and_( ContractInfo.product == product.product_code, ContractInfo.symbol != main_contract.symbol ) ).update({"is_main": False}) # 设置主力合约 if not main_contract.is_main: main_contract.is_main = True updated_count += 1 db.commit() return updated_count except Exception: db.rollback() return 0 finally: db.close() product_service = ProductService()