import logging
import re
from datetime import date, datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
from app.database import SessionLocal

logger = logging.getLogger(__name__)

# Trailing "YYMM" of a contract code, e.g. "rb2401" -> ("24", "01").
_YYMM_SUFFIX_RE = re.compile(r"(\d{2})(\d{2})$", re.ASCII)
# Exactly three trailing digits ("YMM", a single year digit), e.g. "SR501" -> ("5", "01").
_YMM_SUFFIX_RE = re.compile(r"(?<!\d)(\d)(\d{2})$", re.ASCII)
# Leading year and month of a textual date such as "20240115" or "2024-01-15".
_DATE_TEXT_RE = re.compile(r"(\d{4})-?(\d{2})", re.ASCII)
# Accepted start-month input: "YYYY-MM" (single separator) or "YYYYMM".
_START_MONTH_RE = re.compile(r"(\d{4})\D?(\d{2})", re.ASCII)


class ContractService:
    """Contract metadata service: syncs contracts from a data source and queries them."""

    # Exchanges queried one at a time during sync (the Tushare source needs
    # the exchange to be enumerated explicitly).
    EXCHANGES = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # Attributes copied from a source record onto a ContractInfo row.
    _SYNC_FIELDS = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    def __init__(self):
        self.manager = DataSourceManager()

    def sync_contracts(self) -> int:
        """Sync the contract list from the primary data source into the database.

        Contracts are fetched per exchange. A failure on one exchange is logged
        and skipped, so the remaining exchanges are still synced. Records are
        de-duplicated by ``symbol`` (first occurrence wins). Each record is then
        upserted into ``ContractInfo``.

        Returns:
            Number of contracts inserted or updated.

        Raises:
            RuntimeError: if no data source is available.
            Exception: any database error, re-raised after rolling back.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise RuntimeError("没有可用的数据源")

        # symbol -> source record; dicts keep insertion order, so the first
        # occurrence across exchanges wins, as before.
        unique_contracts = {}
        for ex in self.EXCHANGES:
            try:
                contracts = list(source.get_contract_list(exchange=ex))
            except Exception:
                # Best effort: one failing exchange must not abort the whole sync,
                # but the failure should be visible in the logs.
                logger.warning("[同步合约] 交易所 %s 获取合约失败,已跳过", ex, exc_info=True)
                continue
            for record in contracts:
                unique_contracts.setdefault(record["symbol"], record)

        if not unique_contracts:
            logger.warning("[同步合约] 未获取到任何合约")
            return 0

        db = SessionLocal()
        count = 0
        try:
            # Load existing rows once instead of issuing one SELECT per contract.
            existing = {row.symbol: row for row in db.query(ContractInfo)}
            for symbol, record in unique_contracts.items():
                values = {field: record[field] for field in self._SYNC_FIELDS}
                contract = existing.get(symbol)
                if contract is None:
                    db.add(ContractInfo(symbol=symbol, **values))
                else:
                    for field, value in values.items():
                        setattr(contract, field, value)
                count += 1
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

        return count

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None,
    ) -> List[ContractInfo]:
        """Query contracts, optionally filtered, ordered by symbol.

        Empty ``exchange``/``product`` values are treated as "no filter";
        ``is_active`` filters only when it is not ``None``.

        Note: the session is closed before returning, so the returned instances
        are detached. Only attributes that are already loaded should be accessed.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            return query.order_by(ContractInfo.symbol).all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional[ContractInfo]:
        """Return the contract with the given symbol, or ``None`` if absent."""
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()

    def get_products(self, exchange: Optional[str] = None) -> List[dict]:
        """List distinct products, one entry per (product, exchange) pair.

        Each entry has ``product``, ``exchange``, ``name``, ``multiplier`` and
        ``price_tick``. The details come from the product's main active contract
        if there is one, else from its active contract with the highest symbol.
        ``name`` falls back to the product code.
        """
        logger.info("[获取品种列表] exchange=%s", exchange)
        db = SessionLocal()
        try:
            query = db.query(ContractInfo.product, ContractInfo.exchange).filter(
                ContractInfo.product.isnot(None),
                ContractInfo.product != "",
            )
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            # GROUP BY already yields each (product, exchange) pair exactly once;
            # the same product code may exist on several exchanges.
            rows = (
                query.group_by(ContractInfo.product, ContractInfo.exchange)
                .order_by(ContractInfo.product)
                .all()
            )

            products = []
            for row in rows:
                info = self._get_product_info(db, row.product, row.exchange)
                products.append({
                    "product": row.product,
                    "exchange": row.exchange,
                    # A stored contract may have an empty name; fall back to the code.
                    "name": info["name"] or row.product,
                    "multiplier": info["multiplier"],
                    "price_tick": info["price_tick"],
                })
            logger.info("[获取品种列表] 返回 %d 个品种", len(products))
            return products
        finally:
            db.close()

    def _get_product_info(self, db: Session, product: str, exchange: str) -> dict:
        """Return ``name``/``multiplier``/``price_tick`` for a product.

        The main active contract is preferred. Otherwise the active contract
        with the highest symbol is used. If neither exists, the product code is
        returned as the name and the numeric fields are ``None``.
        """
        active = db.query(ContractInfo).filter(
            ContractInfo.product == product,
            ContractInfo.exchange == exchange,
            ContractInfo.is_active == True,  # noqa: E712 - SQL expression
        )
        contract = active.filter(ContractInfo.is_main == True).first()  # noqa: E712
        if contract is None:
            contract = active.order_by(ContractInfo.symbol.desc()).first()
        if contract is None:
            return {"name": product, "multiplier": None, "price_tick": None}
        return {
            "name": contract.name,
            "multiplier": contract.multiplier,
            "price_tick": contract.price_tick,
        }

    @staticmethod
    def _parse_start_month(start_month: str) -> tuple:
        """Parse "YYYY-MM" or "YYYYMM" into ``(year, month)``.

        Raises:
            ValueError: if the text does not match either format or the month
                is outside 1..12.
        """
        match = _START_MONTH_RE.fullmatch(start_month)
        if match is None or not 1 <= int(match.group(2)) <= 12:
            raise ValueError(f"月份格式错误: {start_month},应为 YYYY-MM 或 YYYYMM")
        return int(match.group(1)), int(match.group(2))

    def get_contracts_by_month(
        self, product: str, start_month: str, limit: int = 5
    ) -> List[ContractInfo]:
        """Return up to ``limit`` contracts of ``product`` from ``start_month`` on.

        ``start_month`` is "YYYY-MM" or "YYYYMM". Contracts are ordered by the
        (year, month) from ``_extract_contract_month``; unparseable ones are
        skipped. If every contract precedes ``start_month``, the ``limit``
        closest ones (i.e. the latest) are returned, still in ascending order.

        Raises:
            ValueError: if ``start_month`` is malformed.
        """
        logger.info(
            "[按月份查询合约] product=%s, start_month=%s, limit=%s",
            product, start_month, limit,
        )
        # Validate the input before touching the database.
        start_tuple = self._parse_start_month(start_month)
        count = max(limit, 0)

        db = SessionLocal()
        try:
            contracts = db.query(ContractInfo).filter(ContractInfo.product == product).all()
            if not contracts:
                logger.warning("[按月份查询合约] 品种 %s 没有任何合约数据", product)
                return []

            dated = []
            for contract in contracts:
                month = self._extract_contract_month(contract.symbol, contract.expire_date)
                if month:
                    dated.append((contract, month))
            if not dated:
                logger.warning("[按月份查询合约] 品种 %s 的合约无法解析月份", product)
                return []

            dated.sort(key=lambda pair: pair[1])
            later = [pair for pair in dated if pair[1] >= start_tuple]
            if later:
                selected = later[:count]
            else:
                logger.info(
                    "[按月份查询合约] 没有找到 >= %s 的合约,返回最接近的 %s 个合约",
                    start_tuple, limit,
                )
                # Everything precedes start_month, so the closest are the latest.
                selected = dated[max(len(dated) - count, 0):]

            result = [contract for contract, _ in selected]
            logger.info("[按月份查询合约] 返回 %d 个合约", len(result))
            return result
        finally:
            db.close()

    def _extract_contract_month(self, symbol: str, expire_date):
        """Return the contract's ``(year, month)``, or ``None`` if unknown.

        ``expire_date`` is preferred when it is a ``date``/``datetime`` or a
        "YYYYMMDD" / "YYYY-MM-DD" string. Otherwise the month is parsed from the
        trailing digits of ``symbol``.

        NOTE(review): for some products the expiry may fall in a different
        month than the one encoded in the symbol. Confirm that preferring
        expire_date is the intended behaviour.
        """
        month = self._month_from_expire_date(expire_date)
        if month is not None:
            return month
        return self._month_from_symbol(symbol)

    @staticmethod
    def _month_from_expire_date(expire_date):
        """Return ``(year, month)`` from an expiry value, or ``None``."""
        # datetime is a subclass of date, so this accepts both.
        if isinstance(expire_date, date):
            return (expire_date.year, expire_date.month)
        if isinstance(expire_date, str):
            match = _DATE_TEXT_RE.match(expire_date.strip())
            if match and 1 <= int(match.group(2)) <= 12:
                return (int(match.group(1)), int(match.group(2)))
        return None

    @staticmethod
    def _month_from_symbol(symbol):
        """Parse ``(year, month)`` from a code such as "rb2401", or return ``None``."""
        if not symbol:
            return None
        # Ignore a ".EXCHANGE"-style suffix, in case symbols carry one.
        code = symbol.split(".", 1)[0]
        current_year = datetime.now().year

        match = _YYMM_SUFFIX_RE.search(code)
        if match:
            month = int(match.group(2))
            year = current_year - current_year % 100 + int(match.group(1))
            # Pick the century that puts the year closest to today.
            if year > current_year + 50:
                year -= 100
            elif year < current_year - 50:
                year += 100
        else:
            match = _YMM_SUFFIX_RE.search(code)
            if match is None:
                return None
            month = int(match.group(2))
            # A single year digit is ambiguous per decade; resolve it to a year
            # no more than one year ahead of today.
            year = current_year - current_year % 10 + int(match.group(1))
            if year > current_year + 1:
                year -= 10

        if not 1 <= month <= 12:
            return None
        return (year, month)


contract_service = ContractService()