import logging
from typing import Any, Dict, Iterable, List, Optional

from sqlalchemy.orm import Session

from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
from app.database import SessionLocal

logger = logging.getLogger(__name__)


class ContractService:
    """Contract metadata service: syncs contracts from a data source and queries them."""

    # Exchanges requested during a sync when the caller does not pass any.
    # The data source is queried once per exchange (original note: Tushare
    # requires iterating over all exchanges).
    DEFAULT_EXCHANGES = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # Columns copied from a source record onto ContractInfo, besides `symbol`.
    _SYNC_FIELDS = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    # Max symbols per `IN (...)` clause when looking up existing rows; kept
    # small so the query stays under typical database bind-parameter limits.
    _LOOKUP_CHUNK_SIZE = 500

    def __init__(self):
        self.manager = DataSourceManager()

    def sync_contracts(self, exchanges: Optional[Iterable[str]] = None) -> int:
        """Sync the contract list from the primary data source into the database.

        Each exchange is fetched independently. A failing exchange is logged
        and skipped, so the others are still synced. Records are
        de-duplicated by ``symbol``, and the first occurrence wins. Existing
        rows (matched by symbol) are updated and new symbols are inserted,
        all in a single transaction.

        Args:
            exchanges: Exchange codes to fetch. Defaults to
                ``DEFAULT_EXCHANGES`` when None.

        Returns:
            The number of unique contracts written (inserted or updated).

        Raises:
            RuntimeError: If the data source manager has no primary source.
            Exception: Any database error (or a malformed source record) is
                re-raised after the transaction is rolled back.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise RuntimeError("没有可用的数据源")

        if exchanges is None:
            exchanges = self.DEFAULT_EXCHANGES
        contracts = self._fetch_unique_contracts(source, exchanges)

        db = SessionLocal()
        try:
            # One batched lookup instead of one SELECT per contract.
            existing = self._load_existing(db, list(contracts))
            for symbol, record in contracts.items():
                values = {field: record[field] for field in self._SYNC_FIELDS}
                contract = existing.get(symbol)
                if contract is None:
                    db.add(ContractInfo(symbol=symbol, **values))
                else:
                    for field, value in values.items():
                        setattr(contract, field, value)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()
        return len(contracts)

    @staticmethod
    def _fetch_unique_contracts(
        source, exchanges: Iterable[str]
    ) -> Dict[str, Dict[str, Any]]:
        """Fetch contract records per exchange, keyed by symbol (first occurrence wins)."""
        unique: Dict[str, Dict[str, Any]] = {}
        for ex in exchanges:
            try:
                fetched = list(source.get_contract_list(exchange=ex))
            except Exception:
                # Best-effort: one exchange failing must not abort the others,
                # but record the failure instead of discarding it silently.
                logger.warning(
                    "Failed to fetch contract list for exchange %s", ex, exc_info=True
                )
                continue
            for record in fetched:
                unique.setdefault(record["symbol"], record)
        if not unique:
            logger.warning("Contract sync fetched no contracts from any exchange")
        return unique

    @classmethod
    def _load_existing(cls, db: Session, symbols: List[str]) -> Dict[str, ContractInfo]:
        """Return existing ContractInfo rows for `symbols`, keyed by symbol."""
        existing: Dict[str, ContractInfo] = {}
        step = cls._LOOKUP_CHUNK_SIZE
        for start in range(0, len(symbols), step):
            chunk = symbols[start:start + step]
            rows = db.query(ContractInfo).filter(ContractInfo.symbol.in_(chunk)).all()
            for row in rows:
                existing.setdefault(row.symbol, row)
        return existing

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List[ContractInfo]:
        """Query contracts, optionally filtered, ordered by symbol.

        Args:
            exchange: Exact exchange code to match; falsy values disable the filter.
            product: Exact product code to match; falsy values disable the filter.
            is_active: When not None, match only rows with this active flag.

        Returns:
            Matching ContractInfo rows. The session is closed before
            returning, so the instances are detached. Attributes loaded by
            the query are readable, but anything requiring a further
            database load (e.g. lazy relationships, if any) is not.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            query = query.order_by(ContractInfo.symbol)
            return query.all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional[ContractInfo]:
        """Return the contract with the given symbol, or None if absent.

        The returned instance is detached because the session is closed before returning.
        """
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()


contract_service = ContractService()