You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

256 lines
8.8 KiB

"""
品种服务：管理品种元数据、品种树、主力合约计算。
Product service: product metadata, product tree, and main-contract calculation.
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional, Dict

from sqlalchemy.orm import Session
from sqlalchemy import func, and_

from app.models import ProductInfo, ContractInfo
from app.database import SessionLocal
class ProductService:
    """Product (品种) information service.

    Every public method opens its own ``SessionLocal()`` session and closes it
    before returning; the service itself keeps no per-instance state.
    """

    def get_products(
        self,
        exchange: Optional[str] = None,
        category: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List[dict]:
        """Return the product list with contract counts and main-contract info.

        Args:
            exchange: Only include products on this exchange (ignored if falsy).
            category: Only include products in this category (ignored if falsy).
            is_active: When not None, filter products by their ``is_active`` flag.

        Returns:
            One dict per product. ``contract_count`` is the number of active
            contracts of the product; ``main_contract`` is the symbol of the
            active contract flagged ``is_main``, or None if there is none.
        """
        db = SessionLocal()
        try:
            # Count active contracts per product; the outer join keeps
            # products without any active contract (count 0).
            query = db.query(
                ProductInfo,
                func.count(ContractInfo.symbol).label('contract_count'),
            ).outerjoin(
                ContractInfo,
                and_(
                    ProductInfo.product_code == ContractInfo.product,
                    ContractInfo.is_active == True,  # noqa: E712 (SQL expression)
                )
            ).group_by(ProductInfo.id)
            if exchange:
                query = query.filter(ProductInfo.exchange == exchange)
            if category:
                query = query.filter(ProductInfo.category == category)
            if is_active is not None:
                query = query.filter(ProductInfo.is_active == is_active)
            results = query.all()

            # The main contract is the one carrying the is_main flag (the same
            # source get_product_tree uses), not the largest symbol string.
            main_rows = db.query(ContractInfo.product, ContractInfo.symbol).filter(
                and_(
                    ContractInfo.is_main == True,  # noqa: E712
                    ContractInfo.is_active == True,  # noqa: E712
                )
            ).all()
            main_by_product = {product: symbol for product, symbol in main_rows}

            return [
                {
                    "id": p.id,
                    "product_code": p.product_code,
                    "product_name": p.product_name,
                    "exchange": p.exchange,
                    "multiplier": p.multiplier,
                    "price_tick": p.price_tick,
                    "category": p.category,
                    "is_active": p.is_active,
                    "contract_count": count,
                    "main_contract": main_by_product.get(p.product_code),
                }
                for p, count in results
            ]
        finally:
            db.close()

    def get_product_tree(self) -> List[dict]:
        """Return all products grouped by category, each with its active contracts.

        Returns:
            A list of ``{"category": ..., "products": [...]}`` dicts in
            (category, product_code) order. Each product entry lists its
            active contracts ordered by ``year_month``, their count, and the
            symbol of the one flagged ``is_main`` (or None).
        """
        db = SessionLocal()
        try:
            products = db.query(ProductInfo).order_by(
                ProductInfo.category,
                ProductInfo.product_code
            ).all()

            # Load every active contract in a single query (instead of one
            # query per product) and bucket them by product code. The global
            # year_month ordering is preserved inside each bucket.
            active_contracts = db.query(ContractInfo).filter(
                ContractInfo.is_active == True  # noqa: E712
            ).order_by(ContractInfo.year_month).all()
            contracts_by_product: Dict[str, List[dict]] = {}
            for c in active_contracts:
                contracts_by_product.setdefault(c.product, []).append({
                    "symbol": c.symbol,
                    "year_month": c.year_month,
                    "delivery_month": c.delivery_month,
                    "is_main": c.is_main,
                    "name": c.name,
                })

            # Group products by category; dict insertion order keeps the
            # category ordering from the query.
            tree: Dict[Optional[str], List[dict]] = {}
            for p in products:
                contract_list = contracts_by_product.get(p.product_code, [])
                tree.setdefault(p.category, []).append({
                    "product_code": p.product_code,
                    "product_name": p.product_name,
                    "exchange": p.exchange,
                    "contract_count": len(contract_list),
                    "main_contract": next(
                        (c["symbol"] for c in contract_list if c["is_main"]), None
                    ),
                    "contracts": contract_list,
                })

            return [
                {"category": cat, "products": prods}
                for cat, prods in tree.items()
            ]
        finally:
            db.close()

    def get_product_contracts(
        self,
        product_code: str,
        is_active: Optional[bool] = None
    ) -> List[dict]:
        """Return all contracts of ``product_code`` ordered by ``year_month``.

        Args:
            product_code: Product whose contracts are returned.
            is_active: When not None, filter contracts by their ``is_active`` flag.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo).filter(
                ContractInfo.product == product_code
            )
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            contracts = query.order_by(ContractInfo.year_month).all()
            return [
                {
                    "id": c.id,
                    "symbol": c.symbol,
                    "name": c.name,
                    "exchange": c.exchange,
                    "year_month": c.year_month,
                    "delivery_month": c.delivery_month,
                    "is_main": c.is_main,
                    "is_active": c.is_active,
                    "expire_date": c.expire_date,
                    "multiplier": c.multiplier,
                    "price_tick": c.price_tick,
                }
                for c in contracts
            ]
        finally:
            db.close()

    def set_main_contract(self, symbol: str) -> bool:
        """Flag ``symbol`` as the main contract of its product.

        Clears ``is_main`` on every other contract of the same product and sets
        it on ``symbol``, committing both changes in one transaction.

        Returns:
            True on success; False if the symbol does not exist or the database
            operation failed (the transaction is rolled back and the error logged).
        """
        db = SessionLocal()
        try:
            contract = db.query(ContractInfo).filter(
                ContractInfo.symbol == symbol
            ).first()
            if not contract:
                return False
            # Clear the main flag on the product's other contracts.
            db.query(ContractInfo).filter(
                and_(
                    ContractInfo.product == contract.product,
                    ContractInfo.symbol != symbol
                )
            ).update({"is_main": False})
            contract.is_main = True
            db.commit()
            return True
        except Exception:
            db.rollback()
            # Keep the bool contract for callers, but do not hide the cause.
            logging.getLogger(__name__).exception(
                "Failed to set main contract %s", symbol
            )
            return False
        finally:
            db.close()

    def update_main_contracts(self) -> int:
        """Recompute the main-contract flag of every product.

        For each product with at least one active contract, the main contract is
        chosen by ``_pick_main_contract``. The flag is then set on that contract
        and cleared on all other contracts of the product.

        Returns:
            The number of products whose main contract was newly flagged. Returns
            0 if the database operation failed; in that case all changes are
            rolled back and the error is logged.
        """
        # Naive UTC "now": the same value datetime.utcnow() used to produce.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        db = SessionLocal()
        try:
            products = db.query(ProductInfo).all()
            updated_count = 0
            for product in products:
                # Ordered by year_month so the "last contract" fallback is
                # deterministic (the latest delivery month) rather than DB order.
                active_contracts = db.query(ContractInfo).filter(
                    and_(
                        ContractInfo.product == product.product_code,
                        ContractInfo.is_active == True  # noqa: E712
                    )
                ).order_by(ContractInfo.year_month).all()
                if not active_contracts:
                    continue

                main_contract = self._pick_main_contract(active_contracts, now)

                # Clear the main flag on the product's other contracts.
                db.query(ContractInfo).filter(
                    and_(
                        ContractInfo.product == product.product_code,
                        ContractInfo.symbol != main_contract.symbol
                    )
                ).update({"is_main": False})

                if not main_contract.is_main:
                    main_contract.is_main = True
                    updated_count += 1
            db.commit()
            return updated_count
        except Exception:
            db.rollback()
            logging.getLogger(__name__).exception("Failed to update main contracts")
            return 0
        finally:
            db.close()

    @staticmethod
    def _pick_main_contract(contracts, now: Optional[datetime] = None):
        """Choose the main contract among a product's (non-empty) active contracts.

        Rules:
          1. If any contract has open interest > 0 (a None value counts as 0),
             pick the one with the largest open interest (first one on ties).
          2. Otherwise pick the contract with the earliest parseable
             ``year_month`` (format ``%Y-%m``) whose month start is >= ``now``.
          3. Otherwise fall back to the last element of ``contracts``.

        Args:
            contracts: Non-empty sequence of objects with ``open_interest`` and
                ``year_month`` attributes.
            now: Naive UTC reference time; defaults to the current time.
        """
        def open_interest(c):
            return c.open_interest or 0

        if any(open_interest(c) > 0 for c in contracts):
            return max(contracts, key=open_interest)

        if now is None:
            now = datetime.now(timezone.utc).replace(tzinfo=None)

        future = []
        for c in contracts:
            if not c.year_month:
                continue
            try:
                expiry = datetime.strptime(c.year_month, '%Y-%m')
            except (TypeError, ValueError):
                continue  # skip malformed delivery months
            if expiry >= now:
                future.append((expiry, c))
        if future:
            # Compare parsed dates, not strings: '2024-5' is accepted by %m but
            # would sort after '2024-10' as text.
            return min(future, key=lambda item: item[0])[1]
        return contracts[-1]
# Shared module-level instance; ProductService keeps no per-instance state,
# so a single instance can be reused.
product_service = ProductService()