You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.6 KiB
111 lines
3.6 KiB
|
1 month ago
|
import logging
from typing import Iterable, List, Optional, Sequence

from sqlalchemy.orm import Session

from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
from app.database import SessionLocal
|
||
|
|
|
||
|
|
|
||
|
|
class ContractService:
    """Contract information service.

    Syncs contract lists from the primary data source into the
    ``ContractInfo`` table and offers simple read queries over that table.
    """

    # Exchanges queried by ``sync_contracts`` by default. Tushare has to be
    # queried exchange by exchange, so every exchange is iterated.
    EXCHANGES = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # ContractInfo columns copied from each source record on insert and on
    # update. ``symbol`` is the lookup key and is only set on insert.
    _SYNC_FIELDS = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.manager = DataSourceManager()

    def sync_contracts(self, exchanges: Optional[Sequence[str]] = None) -> int:
        """Sync the contract list from the primary data source into the database.

        Each exchange is fetched independently; an exchange whose fetch fails
        is logged and skipped so the others still sync. Records are
        de-duplicated by ``symbol`` (first occurrence wins). Existing rows are
        updated in place and new symbols are inserted, all in one transaction.

        Args:
            exchanges: Exchange codes to query. Defaults to ``EXCHANGES``.

        Returns:
            The number of contracts written (inserted plus updated).

        Raises:
            RuntimeError: If the data source manager has no primary source.
            Exception: Any error while writing is re-raised after rollback.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise RuntimeError("没有可用的数据源")

        if exchanges is None:
            exchanges = self.EXCHANGES
        contracts = self._dedupe_by_symbol(self._fetch_contracts(source, exchanges))
        if not contracts:
            # Nothing fetched: skip opening a session that would commit nothing.
            return 0

        db = SessionLocal()
        try:
            # Load existing rows once, keyed by symbol, instead of issuing one
            # SELECT per contract (the previous N+1 pattern).
            existing = {row.symbol: row for row in db.query(ContractInfo).all()}
            for record in contracts:
                values = {field: record[field] for field in self._SYNC_FIELDS}
                contract = existing.get(record["symbol"])
                if contract is None:
                    db.add(ContractInfo(symbol=record["symbol"], **values))
                else:
                    for field, value in values.items():
                        setattr(contract, field, value)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

        return len(contracts)

    @classmethod
    def _fetch_contracts(cls, source, exchanges: Iterable[str]) -> List[dict]:
        """Collect the contract records ``source`` returns for each exchange.

        A failure on one exchange is logged (with traceback) and skipped, so
        one bad exchange does not prevent the others from syncing.
        """
        collected: List[dict] = []
        for exchange in exchanges:
            try:
                collected.extend(source.get_contract_list(exchange=exchange))
            except Exception:
                cls._logger.warning(
                    "Failed to fetch contract list for exchange %s; skipping",
                    exchange,
                    exc_info=True,
                )
        return collected

    @staticmethod
    def _dedupe_by_symbol(contracts: Iterable[dict]) -> List[dict]:
        """Drop records whose ``symbol`` was already seen; first occurrence wins, order kept."""
        unique: dict = {}
        for record in contracts:
            unique.setdefault(record["symbol"], record)
        return list(unique.values())

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List["ContractInfo"]:
        """Query contracts, optionally filtered, ordered by symbol.

        Args:
            exchange: Only contracts on this exchange (ignored if empty/None).
            product: Only contracts of this product (ignored if empty/None).
            is_active: Only contracts with this ``is_active`` value (ignored if None).

        Returns:
            Matching ``ContractInfo`` rows. The session is closed before they
            are returned. NOTE(review): lazy-loaded attributes/relationships
            may therefore be unavailable to callers; confirm usage.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)

            return query.order_by(ContractInfo.symbol).all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional["ContractInfo"]:
        """Return the contract with ``symbol``, or None if there is none.

        The session is closed before the row is returned (see ``get_contracts``).
        """
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level service instance. Creating it runs ContractService.__init__,
# which constructs a DataSourceManager when this module is imported.
contract_service: ContractService = ContractService()
|