|
|
|
|
|
import logging
import re
from datetime import date
from typing import List, Optional, Tuple

from sqlalchemy.orm import Session

from app.database import SessionLocal
from app.models import ContractInfo
from app.services.datasource.manager import DataSourceManager
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ContractService:
    """Contract information service: syncs contracts from a data source and queries them."""

    # Exchanges queried during sync. Tushare has to be queried exchange by
    # exchange, so every exchange is iterated explicitly.
    EXCHANGES: Tuple[str, ...] = ("CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")

    # Keys copied verbatim from a data-source contract dict onto ContractInfo.
    _CONTRACT_FIELDS: Tuple[str, ...] = (
        "exchange",
        "name",
        "product",
        "multiplier",
        "price_tick",
        "expire_date",
        "is_active",
    )

    # Symbols per ``IN (...)`` lookup during sync; keeps the number of bound
    # parameters per query modest.
    _LOOKUP_CHUNK = 500

    # ``YYYY-MM`` / ``YYYYMM`` (any single non-digit separator is tolerated).
    _START_MONTH_RE = re.compile(r"(\d{4})\D?(\d{2})")
    # Leading ``YYYYMM`` / ``YYYY-MM`` of a date string such as ``20240115``.
    _DATE_PREFIX_RE = re.compile(r"(\d{4})-?(\d{2})")
    # Trailing ``YYMM`` of a contract symbol, e.g. ``rb2401`` -> ``24``, ``01``.
    _SYMBOL_MONTH_RE = re.compile(r"(\d{2})(\d{2})$")

    def __init__(self):
        # Data-source manager used by ``sync_contracts`` to pick the primary source.
        self.manager = DataSourceManager()

    # ------------------------------------------------------------------ sync

    def sync_contracts(self, exchanges: Optional[List[str]] = None) -> int:
        """Sync the contract list from the primary data source into the database.

        Each exchange is fetched independently; a failure on one exchange is
        logged and skipped so the remaining exchanges are still synced.
        Contracts are de-duplicated by ``symbol`` (first occurrence wins) and
        then inserted or updated.

        Args:
            exchanges: exchange codes to query; defaults to ``EXCHANGES``.

        Returns:
            Number of contracts inserted or updated.

        Raises:
            RuntimeError: if no data source is available.
            Database errors are re-raised after the transaction is rolled back.
        """
        source = self.manager.get_primary_source()
        if not source:
            raise RuntimeError("没有可用的数据源")

        targets = self.EXCHANGES if exchanges is None else exchanges
        contracts = self._fetch_contracts(source, targets)
        if not contracts:
            logger.warning("[同步合约] 数据源未返回任何合约")
            return 0

        db = SessionLocal()
        try:
            count = self._upsert_contracts(db, contracts)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

        logger.info(f"[同步合约] 同步 {count} 个合约")
        return count

    @staticmethod
    def _fetch_contracts(source, exchanges) -> List[dict]:
        """Fetch contract dicts for every exchange, de-duplicated by ``symbol``."""
        unique: dict = {}
        for ex in exchanges:
            try:
                batch = list(source.get_contract_list(exchange=ex) or [])
            except Exception:
                # Best effort: one failing exchange must not abort the whole sync,
                # but the failure is logged instead of silently swallowed.
                logger.warning(f"[同步合约] 获取交易所 {ex} 合约失败,已跳过", exc_info=True)
                continue
            for item in batch:
                unique.setdefault(item["symbol"], item)
        return list(unique.values())

    @classmethod
    def _upsert_contracts(cls, db: "Session", contracts: List[dict]) -> int:
        """Insert new contracts and update existing ones (matched by symbol).

        Existing rows are loaded with chunked ``IN`` queries up front instead of
        one SELECT per contract. Returns the number of contracts processed.
        """
        symbols = [item["symbol"] for item in contracts]
        existing: dict = {}
        for start in range(0, len(symbols), cls._LOOKUP_CHUNK):
            chunk = symbols[start:start + cls._LOOKUP_CHUNK]
            for row in db.query(ContractInfo).filter(ContractInfo.symbol.in_(chunk)):
                existing.setdefault(row.symbol, row)

        count = 0
        for item in contracts:
            values = {field: item[field] for field in cls._CONTRACT_FIELDS}
            contract = existing.get(item["symbol"])
            if contract is None:
                db.add(ContractInfo(symbol=item["symbol"], **values))
            else:
                for field, value in values.items():
                    setattr(contract, field, value)
            count += 1
        return count

    # --------------------------------------------------------------- queries

    def get_contracts(
        self,
        exchange: Optional[str] = None,
        product: Optional[str] = None,
        is_active: Optional[bool] = None,
    ) -> List["ContractInfo"]:
        """Query contracts ordered by symbol.

        Args:
            exchange: only contracts on this exchange (ignored when falsy).
            product: only contracts of this product (ignored when falsy).
            is_active: only contracts with this flag (ignored when ``None``).

        Note:
            The session is closed before returning, so the instances are
            detached: loaded column values stay readable, but lazy-loaded
            relationships (if the model defines any) cannot be loaded.
        """
        db = SessionLocal()
        try:
            query = db.query(ContractInfo)
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)
            if product:
                query = query.filter(ContractInfo.product == product)
            if is_active is not None:
                query = query.filter(ContractInfo.is_active == is_active)
            return query.order_by(ContractInfo.symbol).all()
        finally:
            db.close()

    def get_contract(self, symbol: str) -> Optional["ContractInfo"]:
        """Return the contract with ``symbol``, or ``None`` if it does not exist.

        The returned instance is detached (the session is closed before returning).
        """
        db = SessionLocal()
        try:
            return db.query(ContractInfo).filter(ContractInfo.symbol == symbol).first()
        finally:
            db.close()

    def get_products(
        self,
        exchange: Optional[str] = None,
    ) -> List[dict]:
        """List distinct products, one entry per (product, exchange) pair.

        Args:
            exchange: optional exchange code to restrict the listing.

        Returns:
            Dicts with keys ``product``, ``exchange``, ``name``, ``multiplier``
            and ``price_tick``, ordered by product code. Details come from
            ``_get_product_info`` (active main contract first, otherwise the
            active contract with the greatest symbol).
        """
        logger.info(f"[获取品种列表] exchange={exchange}")
        db = SessionLocal()
        try:
            query = db.query(
                ContractInfo.product,
                ContractInfo.exchange,
            ).filter(
                ContractInfo.product.isnot(None),
                ContractInfo.product != "",
            )
            if exchange:
                query = query.filter(ContractInfo.exchange == exchange)

            # GROUP BY already yields unique (product, exchange) pairs; the same
            # product code may legitimately appear under several exchanges.
            rows = (
                query.group_by(ContractInfo.product, ContractInfo.exchange)
                .order_by(ContractInfo.product)
                .all()
            )

            products = []
            for row in rows:
                info = self._get_product_info(db, row.product, row.exchange)
                products.append({
                    "product": row.product,
                    "exchange": row.exchange,
                    # Fall back to the product code when the contract has no name.
                    "name": info.get("name") or row.product,
                    "multiplier": info.get("multiplier"),
                    "price_tick": info.get("price_tick"),
                })

            logger.info(f"[获取品种列表] 返回 {len(products)} 个品种")
            return products
        finally:
            db.close()

    def _get_product_info(self, db: "Session", product: str, exchange: str) -> dict:
        """Return ``name``/``multiplier``/``price_tick`` for one product.

        Prefers the active main contract (greatest symbol if several are
        flagged); otherwise uses the active contract with the greatest symbol,
        which for ``<product><YYMM>`` symbols is typically the latest month.
        If there is no active contract, returns the product code as the name
        with ``None`` details.
        """
        active = db.query(ContractInfo).filter(
            ContractInfo.product == product,
            ContractInfo.exchange == exchange,
            ContractInfo.is_active == True,  # noqa: E712 - SQL expression, not a Python truth test
        )
        contract = (
            active.filter(ContractInfo.is_main == True)  # noqa: E712
            .order_by(ContractInfo.symbol.desc())
            .first()
        )
        if contract is None:
            contract = active.order_by(ContractInfo.symbol.desc()).first()

        if contract is None:
            return {"name": product, "multiplier": None, "price_tick": None}
        return {
            "name": contract.name,
            "multiplier": contract.multiplier,
            "price_tick": contract.price_tick,
        }

    def get_contracts_by_month(
        self,
        product: str,
        start_month: str,
        limit: int = 5,
    ) -> List["ContractInfo"]:
        """Return up to ``limit`` contracts of ``product`` starting at ``start_month``.

        Args:
            product: product code to filter on.
            start_month: ``YYYY-MM`` or ``YYYYMM``.
            limit: maximum number of contracts to return (``<= 0`` returns ``[]``).

        Returns:
            Contracts in ascending contract-month order, beginning with the first
            month >= ``start_month``. If every contract is earlier than
            ``start_month``, the ``limit`` latest contracts (the closest ones) are
            returned instead. Contracts whose month cannot be determined are skipped.

        Raises:
            ValueError: if ``start_month`` is malformed or its month is not 1-12.
        """
        logger.info(f"[按月份查询合约] product={product}, start_month={start_month}, limit={limit}")
        # Validate before touching the database so bad input fails fast.
        start_tuple = self._parse_start_month(start_month)

        db = SessionLocal()
        try:
            contracts = db.query(ContractInfo).filter(
                ContractInfo.product == product
            ).all()
            if not contracts:
                logger.warning(f"[按月份查询合约] 品种 {product} 没有任何合约数据")
                return []

            dated = []
            for contract in contracts:
                month = self._extract_contract_month(contract.symbol, contract.expire_date)
                if month:
                    dated.append((contract, month))
            if not dated:
                logger.warning(f"[按月份查询合约] 品种 {product} 的合约无法解析月份")
                return []

            dated.sort(key=lambda pair: pair[1])
            if dated[-1][1] < start_tuple:
                logger.info(
                    f"[按月份查询合约] 没有找到 >= {start_tuple} 的合约,返回最接近的 {limit} 个合约"
                )
            result = self._pick_window(dated, start_tuple, limit)
            logger.info(f"[按月份查询合约] 返回 {len(result)} 个合约")
            return result
        finally:
            db.close()

    # --------------------------------------------------------------- helpers

    @classmethod
    def _parse_start_month(cls, start_month: str) -> Tuple[int, int]:
        """Parse ``YYYY-MM`` / ``YYYYMM`` into ``(year, month)``.

        Raises:
            ValueError: on any other format or a month outside 1-12.
        """
        match = (
            cls._START_MONTH_RE.fullmatch(start_month.strip())
            if isinstance(start_month, str)
            else None
        )
        if match is None or not 1 <= int(match.group(2)) <= 12:
            raise ValueError(f"月份格式错误: {start_month},应为 YYYY-MM 或 YYYYMM")
        return int(match.group(1)), int(match.group(2))

    @staticmethod
    def _pick_window(dated: List[tuple], start: Tuple[int, int], limit: int) -> list:
        """Select contracts from month-sorted ``(contract, (year, month))`` pairs.

        Returns the first ``limit`` contracts whose month is >= ``start``; if
        none qualify, the ``limit`` latest contracts (closest to ``start``),
        still in ascending order.
        """
        if limit <= 0:
            return []
        later = [contract for contract, month in dated if month >= start]
        if later:
            return later[:limit]
        # Nothing at or after ``start``: the closest contracts are the latest ones.
        return [contract for contract, _ in dated[-limit:]]

    def _extract_contract_month(
        self,
        symbol: str,
        expire_date,
        today: Optional[date] = None,
    ) -> Optional[Tuple[int, int]]:
        """Return the contract month as ``(year, month)``, or ``None`` if unknown.

        ``expire_date`` takes precedence when it is a ``date``/``datetime``
        (``datetime`` subclasses ``date``) or a ``YYYYMMDD``/``YYYY-MM-DD``
        string. Otherwise the trailing ``YYMM`` of ``symbol`` is used
        (``rb2401`` -> ``(2024, 1)``).

        Args:
            symbol: contract symbol.
            expire_date: expiry as ``date``, ``datetime``, string or ``None``.
            today: reference date for resolving the two-digit year
                (defaults to ``date.today()``).
        """
        if isinstance(expire_date, date):
            return (expire_date.year, expire_date.month)
        if isinstance(expire_date, str):
            prefix = self._DATE_PREFIX_RE.match(expire_date.strip())
            if prefix and 1 <= int(prefix.group(2)) <= 12:
                return (int(prefix.group(1)), int(prefix.group(2)))

        if not symbol:
            return None
        match = self._SYMBOL_MONTH_RE.search(symbol)
        if match is None:
            return None
        year_suffix, month = int(match.group(1)), int(match.group(2))
        if not 1 <= month <= 12:
            return None

        # Resolve the two-digit year into a window around the current year.
        # Futures symbols sit within a few years of "now", so the month must NOT
        # influence the century (that previously mapped e.g. rb2506 -> 1925).
        current_year = (today or date.today()).year
        year = current_year // 100 * 100 + year_suffix
        if year > current_year + 10:
            year -= 100
        elif year <= current_year - 90:
            year += 100
        return (year, month)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level ContractService instance, created when this module is imported.
contract_service: ContractService = ContractService()
|