You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
3.6 KiB

import logging
from typing import List, Optional

from sqlalchemy.orm import Session

from app.database import SessionLocal
from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
class ContractService:
    """Contract information service.

    Pulls contract lists from the primary data source into the ``ContractInfo``
    table and exposes simple read queries over that table.
    """

    # Exchanges queried by default. The Tushare source returns contracts per
    # exchange, so a full sync has to iterate over every one of them.
    DEFAULT_EXCHANGES = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # Fields copied verbatim from a data-source record onto a ContractInfo row.
    # ``symbol`` is the lookup key and is handled separately.
    _SYNC_FIELDS = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.manager = DataSourceManager()

    def sync_contracts(self, exchanges: Optional[List[str]] = None) -> int:
        """Sync the contract list from the primary data source into the database.

        Args:
            exchanges: Exchange codes to fetch. Defaults to ``DEFAULT_EXCHANGES``.

        Returns:
            Number of distinct symbols inserted or updated.

        Raises:
            Exception: If no data source is available, or if the database write
                fails. In the second case the transaction is rolled back before
                re-raising.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise Exception("没有可用的数据源")

        if exchanges is None:
            exchanges = self.DEFAULT_EXCHANGES

        contracts = self._dedupe_by_symbol(self._fetch_contracts(source, exchanges))
        if not contracts:
            # Every exchange failed or returned nothing; make that visible
            # instead of silently reporting a successful zero-row sync.
            self._logger.warning("No contracts fetched from any exchange; nothing to sync")
            return 0
        return self._upsert_contracts(contracts)

    def _fetch_contracts(self, source, exchanges) -> List[dict]:
        """Fetch contract records for each exchange and concatenate them.

        This is best-effort: a failure on one exchange is logged and skipped,
        so the remaining exchanges are still synced.
        """
        # NOTE(review): records are assumed to be dict-like (indexed by field
        # name), as the rest of the sync code reads them with ``record[...]``.
        records: List[dict] = []
        for exchange in exchanges:
            try:
                records.extend(source.get_contract_list(exchange=exchange))
            except Exception:
                self._logger.warning(
                    "Failed to fetch contract list for exchange %s; skipping",
                    exchange,
                    exc_info=True,
                )
        return records

    @staticmethod
    def _dedupe_by_symbol(contracts: List[dict]) -> List[dict]:
        """Keep the first record seen for each symbol, preserving input order."""
        seen = set()
        unique = []
        for record in contracts:
            symbol = record["symbol"]
            if symbol not in seen:
                seen.add(symbol)
                unique.append(record)
        return unique

    def _upsert_contracts(self, contracts: List[dict]) -> int:
        """Insert or update one ContractInfo row per record in a single transaction.

        Returns the number of records written. On any error the transaction is
        rolled back and the exception is re-raised.
        """
        db = SessionLocal()
        try:
            # Load existing rows once instead of issuing one SELECT per symbol.
            existing = {row.symbol: row for row in db.query(ContractInfo).all()}
            for record in contracts:
                values = {field: record[field] for field in self._SYNC_FIELDS}
                contract = existing.get(record["symbol"])
                if contract is None:
                    db.add(ContractInfo(symbol=record["symbol"], **values))
                else:
                    for field, value in values.items():
                        setattr(contract, field, value)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()
        return len(contracts)

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List[ContractInfo]:
        """Query contracts, optionally filtered, ordered by symbol.

        Args:
            exchange: Only return contracts on this exchange (ignored if falsy).
            product: Only return contracts for this product (ignored if falsy).
            is_active: If not None, only return contracts whose ``is_active``
                equals this value.

        Returns:
            Matching ContractInfo rows sorted by ``symbol``.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            # ``is not None`` so that an explicit False still filters.
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            query = query.order_by(ContractInfo.symbol)
            # NOTE(review): the session is closed before callers see these
            # objects, so they are detached; lazy-loaded relationships would
            # presumably fail to load — confirm if any are added.
            return query.all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional[ContractInfo]:
        """Return the contract with the given symbol, or None if not found."""
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()
# Shared module-level service instance. Constructing it builds a
# DataSourceManager, so this runs when the module is first imported.
contract_service: ContractService = ContractService()