"""FastAPI application: product/contract queries, K-line data and data-source management."""
import json
import logging
from datetime import datetime
from typing import Optional

from fastapi import FastAPI, Depends, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

# Configure logging (kept ahead of the app imports below, as in the original layout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

from app.config import settings
from app.database import get_db, engine, Base
from app.schemas import (
    KlineRequest, KlineResponse, KlineItem,
    ContractInfo as ContractSchema, ContractListResponse,
    DataSourceConfigItem, DataSourceConfigUpdate, DataSourceCreate,
    ApiResponse, HealthResponse, DataSourceStatus,
    BatchSyncRequest, BatchSyncResult,
    ProductInfo as ProductSchema, ProductTreeResponse,
)
from app.services.kline_service import kline_service
from app.services.contract_service import contract_service
from app.services.product_service import product_service
from app.services.datasource.manager import DataSourceManager
from app.models import DataSourceConfig

logger = logging.getLogger(__name__)

app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): the handlers below are `async def` but call synchronous
# service / DB / data-source code (no `await`), so that work runs on the event
# loop and blocks other requests while it executes. Plain `def` handlers would
# be run in FastAPI's threadpool instead; consider switching once the services
# are confirmed to be thread-safe.


# ========== Startup ==========

@app.on_event("startup")
async def startup():
    """Create tables, load the enabled data sources and seed default source configs."""
    # Create database tables
    Base.metadata.create_all(bind=engine)
    # Load data-source configuration
    DataSourceManager.load_enabled_sources()
    # Seed the default data-source configs (if they do not exist yet)
    _init_default_datasource()


# Default data-source rows seeded at startup. They are inserted disabled, so
# seeding after load_enabled_sources() does not change which sources are active.
_DEFAULT_DATASOURCES = (
    {
        "source_name": "tushare",
        "display_name": "Tushare",
        "config": {"token": ""},
        "priority": 1,
    },
    {
        "source_name": "akshare",
        "display_name": "AKShare",
        "config": {"max_retries": 3},
        "priority": 2,
    },
)


def _init_default_datasource():
    """Insert the default data-source configs that are missing from the database.

    Existing rows are left untouched. Missing ones are created disabled with
    status "unknown" and their config serialized as a JSON string; all inserts
    are committed together.
    """
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        for default in _DEFAULT_DATASOURCES:
            existing = db.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == default["source_name"]
            ).first()
            if existing:
                continue
            db.add(DataSourceConfig(
                source_name=default["source_name"],
                display_name=default["display_name"],
                is_enabled=False,
                config_json=json.dumps(default["config"]),
                priority=default["priority"],
                status="unknown",
            ))
        db.commit()
    finally:
        db.close()


# ========== Health check ==========

def _check_redis() -> str:
    """Probe Redis and return "ok", "not configured" or "error: <reason>".

    Redis is optional: a missing client library or an empty REDIS_URL is
    reported as "not configured" rather than as an error.
    """
    try:
        import redis
    except ImportError:
        return "not configured"

    redis_url = getattr(settings, "REDIS_URL", None)
    if not redis_url:
        return "not configured"

    try:
        # Short timeouts so an unreachable Redis cannot stall the health probe.
        client = redis.from_url(
            redis_url, socket_connect_timeout=2, socket_timeout=2
        )
        client.ping()
    except Exception as e:
        return f"error: {str(e)}"
    return "ok"


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report database and Redis connectivity.

    The overall status is "healthy" when every service is "ok", except that an
    unconfigured (optional) Redis does not degrade it; any error does.
    """
    services = {}
    try:
        from sqlalchemy import text
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        services["database"] = "ok"
    except Exception as e:
        services["database"] = f"error: {str(e)}"

    services["redis"] = _check_redis()

    # Redis is optional, so "not configured" is acceptable for overall health.
    all_up = all(v in ("ok", "not configured") for v in services.values())
    status = "healthy" if all_up else "degraded"
    return HealthResponse(
        status=status,
        services=services,
        version=settings.VERSION,
    )


# ========== Product endpoints ==========

@app.get("/api/v1/products")
async def list_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    category: Optional[str] = Query(None, description="品种分类"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List products, optionally filtered by exchange, category and active flag."""
    products = product_service.get_products(
        exchange=exchange, category=category, is_active=is_active
    )
    return {"code": 0, "data": products}


@app.get("/api/v1/products/tree")
async def get_product_tree():
    """Return the product tree, wrapped under a "categories" key."""
    tree = product_service.get_product_tree()
    return {"code": 0, "data": {"categories": tree}}


@app.get("/api/v1/products/{product_code}/contracts")
async def get_product_contracts(
    product_code: str,
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List all contracts of the given product."""
    contracts = product_service.get_product_contracts(
        product_code=product_code, is_active=is_active
    )
    return {"code": 0, "data": contracts}


@app.post("/api/v1/contracts/{symbol}/set-main")
async def set_main_contract(symbol: str):
    """Mark the given contract as the main contract; code 1 if it does not exist."""
    success = product_service.set_main_contract(symbol)
    if success:
        return {"code": 0, "message": "设置成功"}
    return {"code": 1, "message": "设置失败,合约不存在"}


@app.post("/api/v1/contracts/update-main")
async def update_main_contracts():
    """Recompute the main-contract flags from open interest (持仓量)."""
    count = product_service.update_main_contracts()
    return {"code": 0, "message": f"更新了 {count} 个主力合约"}


# ========== Contract endpoints ==========
# Static paths (/products, /by-month) must stay registered before /{symbol}
# so they are not captured by the path parameter.

@app.get("/api/v1/contracts", response_model=ContractListResponse)
async def list_contracts(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    product: Optional[str] = Query(None, description="品种代码"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List contracts, optionally filtered by exchange, product and active flag."""
    contracts = contract_service.get_contracts(
        exchange=exchange, product=product, is_active=is_active
    )
    return ContractListResponse(
        total=len(contracts),
        items=[ContractSchema.model_validate(c) for c in contracts],
    )


# Named distinctly from the /api/v1/products handler above; both were
# previously called `list_products`, so this definition shadowed that one.
@app.get("/api/v1/contracts/products", response_model=ApiResponse)
async def list_contract_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
):
    """List the de-duplicated products known from the contract table."""
    logger.info(f"[API-获取品种列表] exchange={exchange}")
    products = contract_service.get_products(exchange=exchange)
    return {"code": 0, "message": "ok", "data": {"items": products, "total": len(products)}}


@app.get("/api/v1/contracts/by-month", response_model=ContractListResponse)
async def get_contracts_by_month(
    product: str = Query(..., description="品种代码"),
    start_month: str = Query(..., description="起始月份 YYYY-MM 或 YYYYMM"),
    limit: int = Query(5, ge=1, le=20, description="返回合约数量"),
):
    """List up to `limit` contracts of a product starting from `start_month`."""
    logger.info(f"[API-按月份查询合约] product={product}, start_month={start_month}, limit={limit}")
    contracts = contract_service.get_contracts_by_month(
        product=product, start_month=start_month, limit=limit
    )
    return ContractListResponse(
        total=len(contracts),
        items=[ContractSchema.model_validate(c) for c in contracts],
    )


@app.get("/api/v1/contracts/{symbol}", response_model=ContractSchema)
async def get_contract(symbol: str):
    """Return a single contract.

    Raises:
        HTTPException: 404 if no contract with this symbol exists.
    """
    contract = contract_service.get_contract(symbol)
    if not contract:
        raise HTTPException(status_code=404, detail="合约不存在")
    return ContractSchema.model_validate(contract)


@app.post("/api/v1/contracts/sync")
async def sync_contracts():
    """Sync the contract list from the data source; failures are returned as code 1."""
    try:
        count = contract_service.sync_contracts()
        return {"code": 0, "message": "同步成功", "data": {"synced": count}}
    except Exception as e:
        # Log the traceback (as sync_kline does) instead of only echoing it to the client.
        logger.error(f"[API-同步合约] 同步失败: {e}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}


# ========== K-line endpoints ==========

@app.get("/api/v1/kline", response_model=KlineResponse)
async def get_kline(
    symbol: str = Query(..., description="合约代码"),
    period: str = Query("daily", description="周期: daily/weekly/5m/15m/30m/60m"),
    start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
    limit: int = Query(500, ge=1, le=5000, description="返回条数"),
):
    """Return stored K-line bars for a contract and period."""
    logger.info(f"[API-查询K线] 请求参数: symbol={symbol}, period={period}, start_date={start_date}, end_date={end_date}, limit={limit}")
    items = kline_service.get_kline(
        symbol=symbol,
        period=period,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    logger.info(f"[API-查询K线] 返回 {len(items)} 条记录")
    return KlineResponse(
        symbol=symbol,
        period=period,
        total=len(items),
        items=[KlineItem(**item) for item in items],
    )


@app.post("/api/v1/kline/sync")
async def sync_kline(req: KlineRequest):
    """Sync K-line data for one symbol from the data source.

    Missing dates default to 2020-01-01 .. today. "daily" and "weekly" have
    dedicated sync paths; any other period goes through the intraday sync.
    Failures are returned as code 1.
    """
    logger.info(f"[API-同步K线] 请求参数: symbol={req.symbol}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        start = req.start_date or "2020-01-01"
        end = req.end_date or datetime.now().strftime("%Y-%m-%d")
        logger.info(f"[API-同步K线] 使用日期范围: {start} ~ {end}")
        if req.period == "daily":
            count = kline_service.sync_daily(req.symbol, start, end)
        elif req.period == "weekly":
            count = kline_service.sync_weekly(req.symbol, start, end)
        else:
            count = kline_service.sync_intraday(req.symbol, req.period, start, end)
        logger.info(f"[API-同步K线] 同步成功,共同步 {count} 条记录")
        return {"code": 0, "message": "同步成功", "data": {"synced": count}}
    except Exception as e:
        logger.error(f"[API-同步K线] 同步失败: {e}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}


@app.post("/api/v1/kline/batch-sync", response_model=BatchSyncResult)
async def batch_sync_kline(req: BatchSyncRequest):
    """Batch-sync K-line data for several symbols.

    Raises:
        HTTPException: 500 if the batch sync itself raises.
    """
    logger.info(f"[API-批量同步K线] 请求参数: symbols={req.symbols}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        result = kline_service.batch_sync(
            symbols=req.symbols,
            period=req.period,
            start_date=req.start_date,
            end_date=req.end_date,
        )
        logger.info(f"[API-批量同步K线] 同步完成: 成功={result['success']}, 失败={result['failed']}, 总记录={result['total_records']}")
        return BatchSyncResult(**result)
    except Exception as e:
        logger.error(f"[API-批量同步K线] 同步失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量同步失败: {str(e)}")


# ========== Data-source management endpoints ==========

@app.get("/api/v1/datasources")
async def list_datasources():
    """Return the status of every data source."""
    sources = DataSourceManager.get_all_sources_status()
    return {"code": 0, "data": sources}


@app.post("/api/v1/datasources")
async def create_datasource(req: DataSourceCreate):
    """Create a data-source config (always created disabled, status "unknown")."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        existing = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == req.source_name
        ).first()
        if existing:
            return {"code": 1, "message": "数据源已存在"}

        # Store the config as a JSON string, the same way _init_default_datasource
        # and update_datasource do, because test_datasource json.loads() it back.
        # A value that is already a string is stored as-is to avoid double encoding.
        config = req.config_json or {}
        config_json = config if isinstance(config, str) else json.dumps(config)

        cfg = DataSourceConfig(
            source_name=req.source_name,
            display_name=req.display_name or req.source_name,
            is_enabled=False,
            config_json=config_json,
            priority=req.priority,
            status="unknown",
        )
        db.add(cfg)
        db.commit()
        return {"code": 0, "message": "创建成功", "data": {"id": cfg.id}}
    except Exception as e:
        db.rollback()
        return {"code": 1, "message": f"创建失败: {str(e)}"}
    finally:
        db.close()


@app.put("/api/v1/datasources/{source_name}")
async def update_datasource(source_name: str, req: DataSourceConfigUpdate):
    """Update the given fields of a data-source config, then reload the enabled sources."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        cfg = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == source_name
        ).first()
        if not cfg:
            return {"code": 1, "message": "数据源不存在"}

        # Only fields explicitly provided (not None) are changed.
        if req.is_enabled is not None:
            cfg.is_enabled = req.is_enabled
        if req.config_json is not None:
            cfg.config_json = json.dumps(req.config_json)
        if req.priority is not None:
            cfg.priority = req.priority

        db.commit()
        # Reload the data sources so enable/config changes take effect
        DataSourceManager.load_enabled_sources()
        return {"code": 0, "message": "更新成功"}
    except Exception as e:
        db.rollback()
        return {"code": 1, "message": f"更新失败: {str(e)}"}
    finally:
        db.close()


def _save_datasource_status(source_name: str, status: str, error_msg: Optional[str]) -> None:
    """Persist a connection-test result on the source's config row, if the row exists."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        cfg = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == source_name
        ).first()
        if cfg:
            cfg.status = status
            cfg.error_msg = error_msg
            db.commit()
    finally:
        db.close()


@app.post("/api/v1/datasources/{source_name}/test")
async def test_datasource(source_name: str):
    """Test a data source's connection and record the outcome on its config row.

    If the source is not currently loaded, a temporary instance is built from
    its stored config. An exception raised by the probe is treated as a failed
    connection (code 1) rather than surfacing as a server error.
    """
    source = DataSourceManager.get_source(source_name)
    if not source:
        # Not loaded: try to create a temporary instance for the test
        from app.database import SessionLocal

        db = SessionLocal()
        try:
            cfg = db.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == source_name
            ).first()
            if not cfg:
                return {"code": 1, "message": "数据源不存在"}
            config = json.loads(cfg.config_json) if cfg.config_json else {}
            # Resolve the implementation class dynamically.
            # NOTE(review): relies on the manager's private _source_map.
            source_class = DataSourceManager._source_map.get(source_name)
            if not source_class:
                return {"code": 1, "message": "不支持的数据源类型"}
            source = source_class(config)
        finally:
            db.close()

    try:
        ok, msg = source.health_check()
    except Exception as e:
        logger.error(f"[API-测试数据源] {source_name} 健康检查异常: {e}", exc_info=True)
        ok, msg = False, str(e)

    if ok:
        _save_datasource_status(source_name, "ok", None)
        return {"code": 0, "message": "连接成功", "data": {"status": "ok"}}

    _save_datasource_status(source_name, "error", msg)
    return {"code": 1, "message": f"连接失败: {msg}", "data": {"status": "error"}}