You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
11 KiB

import logging
import re
from datetime import date, datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
from app.database import SessionLocal
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ContractService:
    """Contract information service.

    Syncs the contract list from the primary data source into the
    ``ContractInfo`` table and answers read queries against it. Each public
    method opens its own session via ``SessionLocal`` and closes it before
    returning, so returned ORM objects are detached from that session.
    """

    # Exchanges polled during a sync: the data source (Tushare) has to be
    # queried one exchange at a time.
    EXCHANGES = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # Fields copied verbatim from a data-source record onto a ContractInfo row.
    _SYNC_FIELDS = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    # Trailing 3- or 4-digit contract code: "rb2401" -> "2401", "SR405" -> "405".
    _SYMBOL_DIGITS_RE = re.compile(r"(\d{3,4})$")

    def __init__(self):
        self.manager = DataSourceManager()

    def sync_contracts(self) -> int:
        """Sync the contract list from the primary data source into the database.

        Contracts are fetched exchange by exchange; an exchange whose fetch
        fails is logged and skipped. Records are de-duplicated by ``symbol``
        (first occurrence wins) and then upserted: rows with a known symbol
        are updated in place, new symbols are inserted. All writes are
        committed in a single transaction.

        Returns:
            Number of contracts inserted or updated.

        Raises:
            Exception: if no data source is available. Any database error is
                re-raised after the transaction is rolled back.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise Exception("没有可用的数据源")

        contracts = self._fetch_unique_contracts(source)
        if not contracts:
            logger.warning("[同步合约] 未从任何交易所获取到合约数据")
            return 0

        db = SessionLocal()
        try:
            # One SELECT for the whole table instead of one query per contract.
            # NOTE(review): assumes the contract table comfortably fits in memory.
            existing = {}
            for row in db.query(ContractInfo):
                existing.setdefault(row.symbol, row)

            for c in contracts:
                values = {field: c[field] for field in self._SYNC_FIELDS}
                contract = existing.get(c["symbol"])
                if contract is None:
                    db.add(ContractInfo(symbol=c["symbol"], **values))
                else:
                    for field, value in values.items():
                        setattr(contract, field, value)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()
        return len(contracts)

    def _fetch_unique_contracts(self, source) -> List[dict]:
        """Fetch contracts for every exchange in ``EXCHANGES``, one record per symbol.

        Args:
            source: Data source exposing ``get_contract_list(exchange=...)``.

        Returns:
            Contract dicts in fetch order; for duplicate symbols the first
            record seen is kept.
        """
        by_symbol = {}
        for exchange in self.EXCHANGES:
            try:
                fetched = list(source.get_contract_list(exchange=exchange))
            except Exception:
                # Best effort: one failing exchange must not abort the whole
                # sync, but the failure is recorded instead of silently dropped.
                logger.warning(
                    "[同步合约] 获取交易所 %s 的合约列表失败,已跳过",
                    exchange,
                    exc_info=True,
                )
                continue
            for c in fetched:
                by_symbol.setdefault(c["symbol"], c)
        return list(by_symbol.values())

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None,
    ) -> List[ContractInfo]:
        """List contracts ordered by symbol, optionally filtered.

        Args:
            exchange: Exact exchange code to match; falsy means no filter.
            product: Exact product code to match; falsy means no filter.
            is_active: Active flag to match; ``None`` means no filter.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            return query.order_by(ContractInfo.symbol).all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional[ContractInfo]:
        """Return the contract with the given symbol, or ``None`` if absent."""
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()

    def get_products(
        self,
        exchange: Optional[str] = None,
    ) -> List[dict]:
        """List distinct (product, exchange) pairs with display details.

        Rows with a NULL or empty product are ignored. The same product code
        listed on two exchanges yields two entries.

        Args:
            exchange: Restrict to this exchange; falsy means all exchanges.

        Returns:
            Dicts with keys ``product``, ``exchange``, ``name``,
            ``multiplier`` and ``price_tick``, ordered by product code.
        """
        logger.info("[获取品种列表] exchange=%s", exchange)
        db = SessionLocal()
        try:
            query = db.query(
                ContractInfo.product,
                ContractInfo.exchange,
            ).filter(
                ContractInfo.product.isnot(None),
                ContractInfo.product != "",
            )
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            # GROUP BY already yields each (product, exchange) pair exactly once.
            rows = (
                query.group_by(ContractInfo.product, ContractInfo.exchange)
                .order_by(ContractInfo.product)
                .all()
            )

            products = []
            for row in rows:
                # Details come from the main contract, else the latest active one.
                info = self._get_product_info(db, row.product, row.exchange)
                products.append({
                    "product": row.product,
                    "exchange": row.exchange,
                    "name": info["name"],
                    "multiplier": info["multiplier"],
                    "price_tick": info["price_tick"],
                })
            logger.info("[获取品种列表] 返回 %s 个品种", len(products))
            return products
        finally:
            db.close()

    def _get_product_info(self, db: Session, product: str, exchange: str) -> dict:
        """Return display details for one product on one exchange.

        Uses the active main contract (``is_main``) if there is one, otherwise
        the active contract with the highest symbol. If no active contract
        exists, or the chosen contract has no name, the product code is used
        as the name.

        Returns:
            Dict with keys ``name``, ``multiplier`` and ``price_tick``
            (the latter two are ``None`` when no active contract exists).
        """
        active = db.query(ContractInfo).filter(
            ContractInfo.product == product,
            ContractInfo.exchange == exchange,
            ContractInfo.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
        )
        contract = active.filter(ContractInfo.is_main == True).first()  # noqa: E712
        if contract is None:
            contract = active.order_by(ContractInfo.symbol.desc()).first()

        if contract is None:
            return {"name": product, "multiplier": None, "price_tick": None}
        return {
            "name": contract.name or product,
            "multiplier": contract.multiplier,
            "price_tick": contract.price_tick,
        }

    def get_contracts_by_month(
        self,
        product: str,
        start_month: str,
        limit: int = 5,
    ) -> List[ContractInfo]:
        """Return up to ``limit`` contracts of ``product`` starting at ``start_month``.

        Args:
            product: Product code matched against ``ContractInfo.product``.
            start_month: ``YYYY-MM`` or ``YYYYMM``.
            limit: Maximum number of contracts; ``<= 0`` returns ``[]``.

        Returns:
            Contracts sorted by contract month ascending, beginning with the
            first month ``>= start_month``. If every contract is earlier than
            ``start_month``, the ``limit`` latest contracts (the closest ones)
            are returned instead, still ascending. Contracts whose month
            cannot be determined are skipped.

        Raises:
            ValueError: if ``start_month`` is malformed or its month is not 1-12.
        """
        logger.info(
            "[按月份查询合约] product=%s, start_month=%s, limit=%s",
            product, start_month, limit,
        )
        # Validate before touching the database so bad input always raises.
        start_tuple = self._parse_start_month(start_month)
        if limit <= 0:
            return []

        db = SessionLocal()
        try:
            contracts = db.query(ContractInfo).filter(
                ContractInfo.product == product
            ).all()
            if not contracts:
                logger.warning("[按月份查询合约] 品种 %s 没有任何合约数据", product)
                return []

            dated = []
            for contract in contracts:
                month = self._extract_contract_month(contract.symbol, contract.expire_date)
                if month:
                    dated.append((contract, month))
            if not dated:
                logger.warning("[按月份查询合约] 品种 %s 的合约无法解析月份", product)
                return []

            # Stable sort: contracts in the same month keep their query order.
            dated.sort(key=lambda pair: pair[1])

            result = [c for c, month in dated if month >= start_tuple][:limit]
            if not result:
                # Everything predates start_month, so the closest contracts are
                # the latest ones: take the tail of the ascending list.
                logger.info(
                    "[按月份查询合约] 没有找到 >= %s 的合约,返回最接近的 %s 个合约",
                    start_tuple, limit,
                )
                result = [c for c, _ in dated[-limit:]]

            logger.info("[按月份查询合约] 返回 %s 个合约", len(result))
            return result
        finally:
            db.close()

    @staticmethod
    def _parse_start_month(start_month: str) -> tuple:
        """Parse ``YYYY-MM`` / ``YYYYMM`` into a ``(year, month)`` tuple.

        The separator character of the 7-character form is not checked.

        Raises:
            ValueError: on any other length, non-digit content, or a month
                outside 1-12.
        """
        if len(start_month) == 7:  # YYYY-MM
            digits = start_month[:4] + start_month[5:7]
        elif len(start_month) == 6:  # YYYYMM
            digits = start_month
        else:
            digits = ""
        month = int(digits[4:]) if digits.isdecimal() else 0
        if not 1 <= month <= 12:
            raise ValueError(f"月份格式错误: {start_month},应为 YYYY-MM 或 YYYYMM")
        return (int(digits[:4]), month)

    def _extract_contract_month(self, symbol: str, expire_date) -> Optional[tuple]:
        """Return the contract's ``(year, month)``, or ``None`` if undeterminable.

        ``expire_date`` takes precedence when it is a ``date``/``datetime``.
        Otherwise the trailing digits of ``symbol`` are parsed, ignoring any
        ``.SUFFIX``:

        * 4 digits (``rb2401``): ``YYMM``. The century is chosen so the year
          lies within 50 years of the current year.
        * 3 digits (``SR405``): ``YMM``. The decade is chosen so the year
          lies in ``[current_year - 7, current_year + 2]``.

        Codes whose month is not 1-12 (e.g. ``IF8888``) yield ``None``.
        """
        # datetime is a subclass of date, so this accepts both.
        if isinstance(expire_date, date):
            return (expire_date.year, expire_date.month)

        code = (symbol or "").split(".", 1)[0]
        match = self._SYMBOL_DIGITS_RE.search(code)
        if not match:
            return None
        digits = match.group(1)
        month = int(digits[-2:])
        if not 1 <= month <= 12:
            return None

        current_year = datetime.now().year
        if len(digits) == 4:
            year = current_year // 100 * 100 + int(digits[:2])
            if year > current_year + 50:
                year -= 100
            elif year <= current_year - 50:
                year += 100
        else:
            year = current_year // 10 * 10 + int(digits[0])
            if year > current_year + 2:
                year -= 10
        return (year, month)
# Shared module-level service instance. Constructing it runs
# ContractService.__init__, which creates a DataSourceManager at import time.
contract_service = ContractService()