You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

256 lines
8.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
品种服务:管理品种元数据、品种树、主力合约计算
"""
import logging
from datetime import datetime, timezone
from typing import List, Optional, Dict

from sqlalchemy.orm import Session
from sqlalchemy import func, and_

from app.models import ProductInfo, ContractInfo
from app.database import SessionLocal
class ProductService:
    """Product service: product metadata, the product tree, and main-contract selection.

    Every public method opens its own session via ``SessionLocal()`` and closes
    it before returning.
    """

    # Reports failures that the write methods deliberately convert into
    # ``False`` / ``0`` return values instead of raising.
    _log = logging.getLogger(__name__)

    def get_products(
        self,
        exchange: Optional[str] = None,
        category: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List[dict]:
        """Return products with their active-contract count and main contract.

        Args:
            exchange: if truthy, only products whose ``exchange`` equals it.
            category: if truthy, only products whose ``category`` equals it.
            is_active: if not None, only products whose ``is_active`` matches.

        Returns:
            One dict per product. ``contract_count`` is the number of the
            product's active contracts; ``main_contract`` is the symbol of its
            active contract flagged ``is_main`` (``None`` if none is flagged).
        """
        db = SessionLocal()
        try:
            query = db.query(
                ProductInfo,
                func.count(ContractInfo.symbol).label('contract_count'),
            ).outerjoin(
                ContractInfo,
                and_(
                    ProductInfo.product_code == ContractInfo.product,
                    ContractInfo.is_active == True,  # noqa: E712 (SQL expression)
                )
            ).group_by(ProductInfo.id)
            if exchange:
                query = query.filter(ProductInfo.exchange == exchange)
            if category:
                query = query.filter(ProductInfo.category == category)
            if is_active is not None:
                query = query.filter(ProductInfo.is_active == is_active)
            results = query.all()
            # Use the is_main flag (as get_product_tree does) rather than
            # max(symbol), which only yields the lexicographically largest symbol.
            main_symbols = self._load_main_symbols(db)
            return [
                {
                    "id": p.id,
                    "product_code": p.product_code,
                    "product_name": p.product_name,
                    "exchange": p.exchange,
                    "multiplier": p.multiplier,
                    "price_tick": p.price_tick,
                    "category": p.category,
                    "is_active": p.is_active,
                    "contract_count": count,
                    "main_contract": main_symbols.get(p.product_code),
                }
                for p, count in results
            ]
        finally:
            db.close()

    @staticmethod
    def _load_main_symbols(db: "Session") -> Dict[str, str]:
        """Map product code -> symbol of its active contract flagged ``is_main``."""
        rows = db.query(ContractInfo.product, ContractInfo.symbol).filter(
            ContractInfo.is_active == True,  # noqa: E712 (SQL expression)
            ContractInfo.is_main == True,  # noqa: E712 (SQL expression)
        ).order_by(ContractInfo.year_month, ContractInfo.symbol).all()
        main_symbols: Dict[str, str] = {}
        for product_code, symbol in rows:
            # If several contracts are flagged (inconsistent data), keep the
            # earliest by year_month, matching get_product_tree's first match.
            main_symbols.setdefault(product_code, symbol)
        return main_symbols

    def get_product_tree(self) -> List[dict]:
        """Return products grouped by category, each with its active contracts.

        Categories and products follow ``ORDER BY category, product_code``;
        contracts within a product are ordered by ``year_month``.
        """
        db = SessionLocal()
        try:
            products = db.query(ProductInfo).order_by(
                ProductInfo.category,
                ProductInfo.product_code
            ).all()
            # Fetch every active contract in one query instead of one query per
            # product; ordering by year_month keeps each product's list sorted.
            active_contracts = db.query(ContractInfo).filter(
                ContractInfo.is_active == True  # noqa: E712 (SQL expression)
            ).order_by(ContractInfo.product, ContractInfo.year_month).all()
            contracts_by_product: Dict[str, List[dict]] = {}
            for c in active_contracts:
                contracts_by_product.setdefault(c.product, []).append({
                    "symbol": c.symbol,
                    "year_month": c.year_month,
                    "delivery_month": c.delivery_month,
                    "is_main": c.is_main,
                    "name": c.name,
                })

            # dict preserves insertion order, so categories keep the query order.
            tree: Dict[Optional[str], List[dict]] = {}
            for p in products:
                contract_list = contracts_by_product.get(p.product_code, [])
                tree.setdefault(p.category, []).append({
                    "product_code": p.product_code,
                    "product_name": p.product_name,
                    "exchange": p.exchange,
                    "contract_count": len(contract_list),
                    "main_contract": next((c["symbol"] for c in contract_list if c["is_main"]), None),
                    "contracts": contract_list,
                })
            return [
                {"category": cat, "products": prods}
                for cat, prods in tree.items()
            ]
        finally:
            db.close()

    def get_product_contracts(
        self,
        product_code: str,
        is_active: Optional[bool] = None
    ) -> List[dict]:
        """Return all contracts of ``product_code``, ordered by ``year_month``.

        Args:
            product_code: product whose contracts are listed.
            is_active: if not None, only contracts whose ``is_active`` matches.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo).filter(
                ContractInfo.product == product_code
            )
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            contracts = query.order_by(ContractInfo.year_month).all()
            return [
                {
                    "id": c.id,
                    "symbol": c.symbol,
                    "name": c.name,
                    "exchange": c.exchange,
                    "year_month": c.year_month,
                    "delivery_month": c.delivery_month,
                    "is_main": c.is_main,
                    "is_active": c.is_active,
                    "expire_date": c.expire_date,
                    "multiplier": c.multiplier,
                    "price_tick": c.price_tick,
                }
                for c in contracts
            ]
        finally:
            db.close()

    def set_main_contract(self, symbol: str) -> bool:
        """Flag ``symbol`` as the main contract of its product.

        Clears ``is_main`` on every other contract of the same product, then
        commits.

        Returns:
            True on success; False if the symbol does not exist or the update
            failed (the transaction is rolled back and the error is logged).
        """
        db = SessionLocal()
        try:
            contract = db.query(ContractInfo).filter(
                ContractInfo.symbol == symbol
            ).first()
            if not contract:
                return False
            # Clear the main flag on the product's other contracts.
            db.query(ContractInfo).filter(
                and_(
                    ContractInfo.product == contract.product,
                    ContractInfo.symbol != symbol
                )
            ).update({"is_main": False})
            contract.is_main = True
            db.commit()
            return True
        except Exception:
            db.rollback()
            self._log.exception("Failed to set main contract %s", symbol)
            return False
        finally:
            db.close()

    def update_main_contracts(self) -> int:
        """Recompute the ``is_main`` flag for every product from its active contracts.

        Rules (see ``_pick_main_contract``):
        1. The active contract with the largest open interest is main.
        2. If open interest is missing (None) or zero for all active contracts,
           the active contract with the nearest future delivery month is main.

        Returns:
            The number of contracts newly flagged as main, or 0 if the update
            failed (the transaction is rolled back and the error is logged).
        """
        db = SessionLocal()
        try:
            products = db.query(ProductInfo).all()
            # One query for all active contracts. Ordering by year_month makes
            # the "last active contract" fallback and tie-breaking deterministic.
            active_contracts = db.query(ContractInfo).filter(
                ContractInfo.is_active == True  # noqa: E712 (SQL expression)
            ).order_by(ContractInfo.product, ContractInfo.year_month).all()
            contracts_by_product: Dict[str, list] = {}
            for c in active_contracts:
                contracts_by_product.setdefault(c.product, []).append(c)

            now = datetime.now(timezone.utc).replace(tzinfo=None)
            updated_count = 0
            for product in products:
                contracts = contracts_by_product.get(product.product_code)
                if not contracts:
                    continue
                main_contract = self._pick_main_contract(contracts, now)
                # Clear the main flag on the product's other contracts.
                db.query(ContractInfo).filter(
                    and_(
                        ContractInfo.product == product.product_code,
                        ContractInfo.symbol != main_contract.symbol
                    )
                ).update({"is_main": False})
                if not main_contract.is_main:
                    main_contract.is_main = True
                    updated_count += 1
            db.commit()
            return updated_count
        except Exception:
            db.rollback()
            self._log.exception("Failed to update main contracts")
            return 0
        finally:
            db.close()

    @staticmethod
    def _parse_year_month(year_month: Optional[str]) -> Optional[datetime]:
        """Parse ``'%Y-%m'`` to the 1st of that month (00:00); None if empty or unparseable."""
        if not year_month:
            return None
        try:
            return datetime.strptime(year_month, '%Y-%m')
        except (TypeError, ValueError):
            return None

    @classmethod
    def _pick_main_contract(
        cls,
        contracts: List["ContractInfo"],
        now: Optional[datetime] = None
    ) -> "ContractInfo":
        """Choose the main contract from a non-empty list of one product's active contracts.

        ``contracts`` is expected in ascending ``year_month`` order (the caller's
        query sorts it); the final fallback returns the last element.
        ``now`` is a naive UTC datetime; it defaults to the current time.
        """
        if now is None:
            now = datetime.now(timezone.utc).replace(tzinfo=None)

        def open_interest(c) -> int:
            # Missing (None) open interest counts as "no position data".
            return c.open_interest or 0

        if any(open_interest(c) > 0 for c in contracts):
            return max(contracts, key=open_interest)

        # year_month parses to the 1st of the month at 00:00, so a contract
        # whose delivery month has already begun is treated as past.
        upcoming = []
        for c in contracts:
            expiry = cls._parse_year_month(c.year_month)
            if expiry is not None and expiry >= now:
                upcoming.append((expiry, c))
        if upcoming:
            # Compare parsed dates: as strings, '2024-10' sorts before '2024-7'.
            return min(upcoming, key=lambda pair: pair[0])[1]
        return contracts[-1]
# Shared module-level instance; ProductService opens a fresh session per call,
# so callers can reuse this one object.
product_service: "ProductService" = ProductService()