|
|
|
|
from fastapi import FastAPI, Depends, HTTPException, Query
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
# Configure logging: INFO and above, "time - logger name - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
|
|
|
|
from app.config import settings
|
|
|
|
|
from app.database import get_db, engine, Base
|
|
|
|
|
from app.schemas import (
|
|
|
|
|
KlineRequest, KlineResponse, KlineItem,
|
|
|
|
|
ContractInfo as ContractSchema, ContractListResponse,
|
|
|
|
|
DataSourceConfigItem, DataSourceConfigUpdate, DataSourceCreate,
|
|
|
|
|
ApiResponse, HealthResponse, DataSourceStatus,
|
|
|
|
|
BatchSyncRequest, BatchSyncResult,
|
|
|
|
|
ProductInfo as ProductSchema, ProductTreeResponse,
|
|
|
|
|
)
|
|
|
|
|
from app.services.kline_service import kline_service
|
|
|
|
|
from app.services.contract_service import contract_service
|
|
|
|
|
from app.services.product_service import product_service
|
|
|
|
|
from app.services.datasource.manager import DataSourceManager
|
|
|
|
|
from app.models import DataSourceConfig
|
|
|
|
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The FastAPI application. Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(
    redoc_url="/redoc",
    docs_url="/docs",
    version=settings.VERSION,
    title=settings.PROJECT_NAME,
)
|
|
|
|
|
|
|
|
|
|
# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# pairing (see the Starlette CORSMiddleware docs); depending on the Starlette
# version the request's Origin may be reflected for credentialed requests, so any
# site could call this API with the user's cookies. If the frontend does not need
# credentials, consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Startup events ==========
# NOTE(review): newer FastAPI versions deprecate on_event in favour of lifespan handlers.
@app.on_event("startup")
async def startup():
    """Prepare the database and data sources when the application starts."""
    # Create database tables
    Base.metadata.create_all(bind=engine)
    # Load data source configuration
    DataSourceManager.load_enabled_sources()
    # Seed the default data source configuration rows (only if they are missing).
    # They are seeded after loading; the seeded rows are created with
    # is_enabled=False, so skipping a reload here does not lose anything.
    _init_default_datasource()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_default_datasource():
    """Seed the default data source configuration rows if they are missing.

    Ensures a ``tushare`` row and an ``akshare`` row exist in
    ``DataSourceConfig``. Missing rows are created disabled, with status
    ``"unknown"`` and a JSON-encoded default config. Rows that already exist
    are left untouched. All inserts are committed in a single commit.
    """
    import json

    from app.database import SessionLocal

    # (source_name, display_name, default config, priority), checked in this order.
    defaults = (
        ("tushare", "Tushare", {"token": ""}, 1),
        ("akshare", "AKShare", {"max_retries": 3}, 2),
    )

    session = SessionLocal()
    try:
        for name, label, default_config, rank in defaults:
            already_there = session.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == name
            ).first()
            if already_there:
                continue
            session.add(
                DataSourceConfig(
                    source_name=name,
                    display_name=label,
                    is_enabled=False,
                    config_json=json.dumps(default_config),
                    priority=rank,
                    status="unknown",
                )
            )
        session.commit()
    finally:
        session.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Health check ==========
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report the health of the database and Redis.

    ``services`` maps each component to ``"ok"`` or a short failure label.
    ``status`` is ``"healthy"`` only when every component reports ``"ok"``;
    otherwise it is ``"degraded"``.
    """
    from contextlib import suppress

    # NOTE(review): the probes below are synchronous calls inside an async
    # handler, so they block the event loop while they run.
    services = {}

    # Database: a trivial round-trip query proves the connection works.
    try:
        from sqlalchemy import text

        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        services["database"] = "ok"
    except Exception as e:
        services["database"] = f"error: {str(e)}"

    # Redis is optional, so any failure is labelled "not configured".
    # Bounded socket timeouts keep an unreachable Redis host from stalling the
    # probe. The throw-away client's pool is disconnected afterwards so repeated
    # health checks do not leave connections open.
    redis_timeout_s = 2.0
    client = None
    try:
        import redis

        client = redis.from_url(
            settings.REDIS_URL,
            socket_connect_timeout=redis_timeout_s,
            socket_timeout=redis_timeout_s,
        )
        client.ping()
        services["redis"] = "ok"
    except Exception:
        services["redis"] = "not configured"
    finally:
        if client is not None:
            # Best-effort cleanup; the probe result has already been recorded.
            with suppress(Exception):
                client.connection_pool.disconnect()

    # NOTE(review): "not configured" is not "ok", so an unavailable (optional)
    # Redis still makes the overall status "degraded" — confirm that is intended.
    status = "healthy" if all(v == "ok" for v in services.values()) else "degraded"

    return HealthResponse(
        status=status,
        services=services,
        version=settings.VERSION,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Product endpoints ==========
@app.get("/api/v1/products")
async def list_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    category: Optional[str] = Query(None, description="品种分类"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List products, optionally filtered by exchange, category and activity."""
    # NOTE(review): a later ``list_products`` (GET /api/v1/contracts/products)
    # rebinds this module-level name. Both routes remain registered because the
    # decorator runs at definition time.
    filters = {"exchange": exchange, "category": category, "is_active": is_active}
    return {"code": 0, "data": product_service.get_products(**filters)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/v1/products/tree")
async def get_product_tree():
    """Return the products arranged as a tree, wrapped under ``categories``."""
    return {"code": 0, "data": {"categories": product_service.get_product_tree()}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/v1/products/{product_code}/contracts")
async def get_product_contracts(
    product_code: str,
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """Return all contracts of the given product, optionally filtered by activity."""
    return {
        "code": 0,
        "data": product_service.get_product_contracts(
            product_code=product_code,
            is_active=is_active,
        ),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/contracts/{symbol}/set-main")
async def set_main_contract(symbol: str):
    """Set the given contract as the main contract.

    Returns ``code=1`` when the product service reports failure.
    """
    if not product_service.set_main_contract(symbol):
        return {"code": 1, "message": "设置失败,合约不存在"}
    return {"code": 0, "message": "设置成功"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/contracts/update-main")
async def update_main_contracts():
    """Automatically update the main-contract flags based on open interest."""
    updated = product_service.update_main_contracts()
    message = f"更新了 {updated} 个主力合约"
    return {"code": 0, "message": message}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Contract endpoints ==========
@app.get("/api/v1/contracts", response_model=ContractListResponse)
async def list_contracts(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    product: Optional[str] = Query(None, description="品种代码"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List contracts, optionally filtered by exchange, product and activity."""
    rows = contract_service.get_contracts(
        exchange=exchange,
        product=product,
        is_active=is_active,
    )
    items = [ContractSchema.model_validate(row) for row in rows]
    return ContractListResponse(total=len(rows), items=items)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/v1/contracts/products", response_model=ApiResponse)
async def list_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
):
    """List products (deduplicated product information)."""
    # NOTE(review): this rebinds the module-level name ``list_products`` already
    # used by the GET /api/v1/products handler. Both routes stay registered,
    # but only this function is reachable by that name.
    logger.info(f"[API-获取品种列表] exchange={exchange}")
    products = contract_service.get_products(exchange=exchange)
    payload = {"items": products, "total": len(products)}
    return {"code": 0, "message": "ok", "data": payload}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/v1/contracts/by-month", response_model=ContractListResponse)
async def get_contracts_by_month(
    product: str = Query(..., description="品种代码"),
    start_month: str = Query(..., description="起始月份 YYYY-MM 或 YYYYMM"),
    limit: int = Query(5, ge=1, le=20, description="返回合约数量"),
):
    """List up to ``limit`` contracts of a product, starting from ``start_month``."""
    logger.info(f"[API-按月份查询合约] product={product}, start_month={start_month}, limit={limit}")
    found = contract_service.get_contracts_by_month(
        product=product, start_month=start_month, limit=limit
    )
    return ContractListResponse(
        total=len(found),
        items=list(map(ContractSchema.model_validate, found)),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Keep this route after the fixed GET /api/v1/contracts/... routes above, because
# routes are matched in registration order and {symbol} would otherwise capture them.
@app.get("/api/v1/contracts/{symbol}", response_model=ContractSchema)
async def get_contract(symbol: str):
    """Return a single contract by symbol, or respond 404 if it is not found."""
    contract = contract_service.get_contract(symbol)
    if contract:
        return ContractSchema.model_validate(contract)
    raise HTTPException(status_code=404, detail="合约不存在")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/contracts/sync")
async def sync_contracts():
    """Sync the contract list from the data source.

    Returns ``code=0`` with the number of synced contracts, or ``code=1``
    with the error message if the sync fails.
    """
    try:
        count = contract_service.sync_contracts()
    except Exception as e:
        # Record the failure with its traceback (as the K-line sync endpoint
        # does) instead of silently folding it into the response.
        logger.error(f"[API-同步合约] 同步失败: {e}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}
    return {"code": 0, "message": "同步成功", "data": {"synced": count}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== K-line endpoints ==========
@app.get("/api/v1/kline", response_model=KlineResponse)
async def get_kline(
    symbol: str = Query(..., description="合约代码"),
    period: str = Query("daily", description="周期: daily/weekly/5m/15m/30m/60m"),
    start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
    limit: int = Query(500, ge=1, le=5000, description="返回条数"),
):
    """Query K-line bars for a contract and period."""
    logger.info(f"[API-查询K线] 请求参数: symbol={symbol}, period={period}, start_date={start_date}, end_date={end_date}, limit={limit}")
    query = dict(
        symbol=symbol,
        period=period,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    bars = kline_service.get_kline(**query)
    logger.info(f"[API-查询K线] 返回 {len(bars)} 条记录")
    return KlineResponse(
        symbol=symbol,
        period=period,
        total=len(bars),
        items=[KlineItem(**bar) for bar in bars],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/kline/sync")
async def sync_kline(req: KlineRequest):
    """Sync K-line data for one contract from the data source.

    Missing dates default to the range 2020-01-01 through today. The ``daily``
    and ``weekly`` periods use dedicated sync routines; any other period is
    passed to the intraday sync along with the period value. Failures are
    logged and reported as ``code=1``.
    """
    logger.info(f"[API-同步K线] 请求参数: symbol={req.symbol}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        start = req.start_date or "2020-01-01"
        end = req.end_date or datetime.now().strftime("%Y-%m-%d")
        logger.info(f"[API-同步K线] 使用日期范围: {start} ~ {end}")

        period = req.period
        if period in ("daily", "weekly"):
            syncer = kline_service.sync_daily if period == "daily" else kline_service.sync_weekly
            count = syncer(req.symbol, start, end)
        else:
            count = kline_service.sync_intraday(req.symbol, period, start, end)

        logger.info(f"[API-同步K线] 同步成功,共同步 {count} 条记录")
        return {"code": 0, "message": "同步成功", "data": {"synced": count}}
    except Exception as e:
        logger.error(f"[API-同步K线] 同步失败: {e}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/kline/batch-sync", response_model=BatchSyncResult)
async def batch_sync_kline(req: BatchSyncRequest):
    """Sync K-line data for several contracts in one request.

    Raises:
        HTTPException: status 500 when the batch sync, or building its result, fails.
    """
    logger.info(f"[API-批量同步K线] 请求参数: symbols={req.symbols}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        result = kline_service.batch_sync(
            symbols=req.symbols,
            period=req.period,
            start_date=req.start_date,
            end_date=req.end_date,
        )
        logger.info(f"[API-批量同步K线] 同步完成: 成功={result['success']}, 失败={result['failed']}, 总记录={result['total_records']}")
        return BatchSyncResult(**result)
    except Exception as e:
        logger.error(f"[API-批量同步K线] 同步失败: {e}", exc_info=True)
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=f"批量同步失败: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Data source management endpoints ==========
@app.get("/api/v1/datasources")
async def list_datasources():
    """Return the status of all data sources."""
    return {"code": 0, "data": DataSourceManager.get_all_sources_status()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/datasources")
async def create_datasource(req: DataSourceCreate):
    """Create a data source configuration.

    The new source starts disabled with status ``"unknown"``. Returns
    ``code=1`` if a source with the same name already exists or the insert
    fails; otherwise returns ``code=0`` with the new row id.
    """
    import json

    from app.database import SessionLocal

    db = SessionLocal()
    try:
        existing = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == req.source_name
        ).first()
        if existing:
            return {"code": 1, "message": "数据源已存在"}

        # config_json is stored as a JSON-encoded string elsewhere in this module.
        # Default seeding and the update endpoint json.dumps it, and the test
        # endpoint json.loads it. Encode it here too rather than storing the raw
        # object. A value that is already a string is stored unchanged.
        raw_config = req.config_json or {}
        config_text = raw_config if isinstance(raw_config, str) else json.dumps(raw_config)

        cfg = DataSourceConfig(
            source_name=req.source_name,
            display_name=req.display_name or req.source_name,
            is_enabled=False,
            config_json=config_text,
            priority=req.priority,
            status="unknown",
        )
        db.add(cfg)
        db.commit()
        return {"code": 0, "message": "创建成功", "data": {"id": cfg.id}}
    except Exception as e:
        db.rollback()
        logger.error(f"[API-创建数据源] 创建失败: {e}", exc_info=True)
        return {"code": 1, "message": f"创建失败: {str(e)}"}
    finally:
        db.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.put("/api/v1/datasources/{source_name}")
async def update_datasource(source_name: str, req: DataSourceConfigUpdate):
    """Update a data source configuration, then reload the enabled sources.

    Only request fields that are not ``None`` are applied. ``config_json``
    is stored JSON-encoded. Returns ``code=1`` if the source does not exist
    or the update fails.
    """
    from app.database import SessionLocal

    session = SessionLocal()
    try:
        cfg = session.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == source_name
        ).first()
        if not cfg:
            return {"code": 1, "message": "数据源不存在"}

        # Plain fields are copied over as-is when provided.
        for field in ("is_enabled", "priority"):
            value = getattr(req, field)
            if value is not None:
                setattr(cfg, field, value)
        if req.config_json is not None:
            import json

            cfg.config_json = json.dumps(req.config_json)

        session.commit()

        # Reload data sources so the new configuration takes effect.
        # NOTE(review): if this reload raises, the commit above has already
        # happened, but the response still reports failure.
        DataSourceManager.load_enabled_sources()
        return {"code": 0, "message": "更新成功"}
    except Exception as e:
        session.rollback()
        return {"code": 1, "message": f"更新失败: {str(e)}"}
    finally:
        session.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_datasource_status(source_name: str, status: str, error_msg: Optional[str]) -> None:
    """Write a health-check outcome onto the source's config row, if that row exists."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        cfg = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == source_name
        ).first()
        if cfg:
            cfg.status = status
            cfg.error_msg = error_msg
            db.commit()
    finally:
        db.close()


@app.post("/api/v1/datasources/{source_name}/test")
async def test_datasource(source_name: str):
    """Test the connection to a data source.

    Uses the already-loaded source instance when available. Otherwise it builds
    a temporary instance from the stored configuration. The outcome is saved
    to the config row's ``status`` / ``error_msg``. Returns ``code=0`` on
    success and ``code=1`` otherwise.
    """
    source = DataSourceManager.get_source(source_name)
    if not source:
        # Not loaded (e.g. disabled): build a temporary instance for the test.
        from app.database import SessionLocal
        import json

        db = SessionLocal()
        try:
            cfg = db.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == source_name
            ).first()
            if not cfg:
                return {"code": 1, "message": "数据源不存在"}

            config = json.loads(cfg.config_json) if cfg.config_json else {}

            # Look up the data source class by name.
            # NOTE(review): relies on the manager's private _source_map.
            source_class = DataSourceManager._source_map.get(source_name)
            if not source_class:
                return {"code": 1, "message": "不支持的数据源类型"}

            source = source_class(config)
        finally:
            db.close()

    # A health check that raises is treated as a failed connection, so the
    # error is recorded and reported rather than surfacing as an HTTP 500.
    try:
        ok, msg = source.health_check()
    except Exception as e:
        logger.error(f"[API-测试数据源] 健康检查异常: {e}", exc_info=True)
        ok, msg = False, str(e)

    if ok:
        _persist_datasource_status(source_name, "ok", None)
        return {"code": 0, "message": "连接成功", "data": {"status": "ok"}}

    _persist_datasource_status(source_name, "error", msg)
    return {"code": 1, "message": f"连接失败: {msg}", "data": {"status": "error"}}
|