"""FastAPI application: health check, contract, K-line and data-source endpoints."""
import json
import logging
from contextlib import closing
from datetime import datetime
from typing import Any, Optional

from fastapi import FastAPI, Depends, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

# Configure logging before the application modules below are imported, so any
# logging done while they are being imported already uses this configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

from app.config import settings  # noqa: E402
from app.database import get_db, engine, Base  # noqa: E402
from app.schemas import (  # noqa: E402
    KlineRequest,
    KlineResponse,
    KlineItem,
    ContractInfo as ContractSchema,
    ContractListResponse,
    DataSourceConfigItem,
    DataSourceConfigUpdate,
    DataSourceCreate,
    ApiResponse,
    HealthResponse,
    DataSourceStatus,
)
from app.services.kline_service import kline_service  # noqa: E402
from app.services.contract_service import contract_service  # noqa: E402
from app.services.datasource.manager import DataSourceManager  # noqa: E402
from app.models import DataSourceConfig  # noqa: E402

logger = logging.getLogger(__name__)

app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site issue credentialed cross-origin requests; restrict allow_origins
# before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Rows seeded into DataSourceConfig at startup when missing:
# (source_name, display_name, default config, priority). All are created disabled.
_DEFAULT_DATASOURCES = (
    ("tushare", "Tushare", {"token": ""}, 1),
    ("akshare", "AKShare", {"max_retries": 3}, 2),
)

# Seconds to wait for Redis during the health check, so an unreachable Redis
# host cannot stall the request indefinitely.
_REDIS_HEALTH_TIMEOUT = 2.0


# ========== Helpers ==========

def _db_session():
    """Return a context manager that opens a new DB session and closes it on exit.

    ``SessionLocal`` is imported at call time rather than at module load, so the
    session factory is looked up on ``app.database`` each time it is needed.
    Uncommitted changes are discarded when the session is closed.
    """
    from app.database import SessionLocal

    return closing(SessionLocal())


def _find_source_config(db, source_name: str) -> Optional[DataSourceConfig]:
    """Return the DataSourceConfig row whose ``source_name`` matches, or ``None``."""
    return db.query(DataSourceConfig).filter(
        DataSourceConfig.source_name == source_name
    ).first()


def _encode_config(value: Any) -> str:
    """Serialize a data-source config for the ``config_json`` column.

    This module stores that column as a JSON-encoded string (``json.dumps`` on
    write, ``json.loads`` on read). A value that is already a string is assumed
    to be JSON text and is stored unchanged (empty string becomes ``"{}"``);
    ``None`` becomes ``"{}"``; anything else is passed through ``json.dumps``.
    """
    if isinstance(value, str):
        return value or "{}"
    return json.dumps({} if value is None else value)


def _set_source_status(source_name: str, status: str, error_msg: Optional[str]) -> None:
    """Store a connection-test result on the matching config row, if the row exists."""
    with _db_session() as db:
        cfg = _find_source_config(db, source_name)
        if cfg:
            cfg.status = status
            cfg.error_msg = error_msg
            db.commit()


# ========== Startup ==========

@app.on_event("startup")
async def startup():
    """Create tables, load enabled data sources and seed default data-source rows."""
    # Create the database tables
    Base.metadata.create_all(bind=engine)
    # Load the enabled data-source configurations
    DataSourceManager.load_enabled_sources()
    # Seed the default data-source configurations if they do not exist yet
    _init_default_datasource()


def _init_default_datasource():
    """Insert any default data-source configuration that is missing from the DB.

    Each entry of ``_DEFAULT_DATASOURCES`` is added (disabled, status
    ``"unknown"``, config stored as a JSON string) only when no row with the
    same ``source_name`` exists. All inserts are committed together.
    """
    with _db_session() as db:
        for source_name, display_name, config, priority in _DEFAULT_DATASOURCES:
            if _find_source_config(db, source_name):
                continue
            db.add(DataSourceConfig(
                source_name=source_name,
                display_name=display_name,
                is_enabled=False,
                config_json=json.dumps(config),
                priority=priority,
                status="unknown",
            ))
        db.commit()


# ========== Health check ==========

@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report database and Redis connectivity.

    ``status`` is ``"healthy"`` only when every probed service reports ``"ok"``,
    otherwise ``"degraded"``.
    """
    services = {}

    try:
        from sqlalchemy import text

        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        services["database"] = "ok"
    except Exception as e:
        services["database"] = f"error: {str(e)}"

    try:
        import redis

        client = redis.from_url(
            settings.REDIS_URL,
            socket_connect_timeout=_REDIS_HEALTH_TIMEOUT,
            socket_timeout=_REDIS_HEALTH_TIMEOUT,
        )
        try:
            client.ping()
        finally:
            # A new client is built on every check; release its sockets.
            client.connection_pool.disconnect()
        services["redis"] = "ok"
    except Exception:
        # Redis is optional, so any failure is reported as "not configured".
        services["redis"] = "not configured"

    # NOTE(review): Redis is treated as optional above, yet a missing Redis still
    # turns the overall status into "degraded" — confirm this is intended.
    status = "healthy" if all(v == "ok" for v in services.values()) else "degraded"
    return HealthResponse(
        status=status,
        services=services,
        version=settings.VERSION,
    )


# ========== Contract endpoints ==========

@app.get("/api/v1/contracts", response_model=ContractListResponse)
async def list_contracts(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    product: Optional[str] = Query(None, description="品种代码"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List contracts, optionally filtered by exchange, product and active flag."""
    contracts = contract_service.get_contracts(
        exchange=exchange, product=product, is_active=is_active
    )
    return ContractListResponse(
        total=len(contracts),
        items=[ContractSchema.model_validate(c) for c in contracts],
    )


@app.get("/api/v1/contracts/{symbol}", response_model=ContractSchema)
async def get_contract(symbol: str):
    """Return a single contract by symbol; responds 404 when it does not exist."""
    contract = contract_service.get_contract(symbol)
    if not contract:
        raise HTTPException(status_code=404, detail="合约不存在")
    return ContractSchema.model_validate(contract)


@app.post("/api/v1/contracts/sync")
async def sync_contracts():
    """Sync the contract list from the data source."""
    # NOTE(review): contract_service.sync_contracts() is synchronous and runs on
    # the event loop, blocking other requests while it works; consider a plain
    # `def` handler once the service is confirmed thread-safe.
    try:
        count = contract_service.sync_contracts()
    except Exception as e:
        logger.exception("[API-同步合约] 同步失败: %s", e)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}
    return {"code": 0, "message": "同步成功", "data": {"synced": count}}


# ========== K-line endpoints ==========

@app.get("/api/v1/kline", response_model=KlineResponse)
async def get_kline(
    symbol: str = Query(..., description="合约代码"),
    period: str = Query("daily", description="周期: daily/weekly/5m/15m/30m/60m"),
    start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
    limit: int = Query(500, ge=1, le=5000, description="返回条数"),
):
    """Query stored K-line bars for a symbol and period."""
    logger.info(
        "[API-查询K线] 请求参数: symbol=%s, period=%s, start_date=%s, end_date=%s, limit=%s",
        symbol, period, start_date, end_date, limit,
    )
    items = kline_service.get_kline(
        symbol=symbol,
        period=period,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    logger.info("[API-查询K线] 返回 %s 条记录", len(items))
    return KlineResponse(
        symbol=symbol,
        period=period,
        total=len(items),
        items=[KlineItem(**item) for item in items],
    )


@app.post("/api/v1/kline/sync")
async def sync_kline(req: KlineRequest):
    """Sync K-line data from the data source.

    Missing dates default to 2020-01-01 .. today. ``daily`` and ``weekly`` have
    dedicated sync paths; any other period is passed to the intraday sync.
    """
    logger.info(
        "[API-同步K线] 请求参数: symbol=%s, period=%s, start_date=%s, end_date=%s",
        req.symbol, req.period, req.start_date, req.end_date,
    )
    # NOTE(review): the kline_service sync calls are synchronous and run on the
    # event loop, blocking other requests for the duration of the sync.
    try:
        start = req.start_date or "2020-01-01"
        end = req.end_date or datetime.now().strftime("%Y-%m-%d")
        logger.info("[API-同步K线] 使用日期范围: %s ~ %s", start, end)

        if req.period == "daily":
            count = kline_service.sync_daily(req.symbol, start, end)
        elif req.period == "weekly":
            count = kline_service.sync_weekly(req.symbol, start, end)
        else:
            count = kline_service.sync_intraday(req.symbol, req.period, start, end)
    except Exception as e:
        logger.exception("[API-同步K线] 同步失败: %s", e)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}

    logger.info("[API-同步K线] 同步成功,共同步 %s 条记录", count)
    return {"code": 0, "message": "同步成功", "data": {"synced": count}}


# ========== Data-source management endpoints ==========

@app.get("/api/v1/datasources")
async def list_datasources():
    """Return the status of all data sources."""
    sources = DataSourceManager.get_all_sources_status()
    return {"code": 0, "data": sources}


@app.post("/api/v1/datasources")
async def create_datasource(req: DataSourceCreate):
    """Create a data-source configuration; new sources always start disabled."""
    with _db_session() as db:
        try:
            if _find_source_config(db, req.source_name):
                return {"code": 1, "message": "数据源已存在"}

            cfg = DataSourceConfig(
                source_name=req.source_name,
                display_name=req.display_name or req.source_name,
                is_enabled=False,
                # Stored as a JSON string, matching how the rest of this module
                # writes (json.dumps) and reads (json.loads) this column.
                config_json=_encode_config(req.config_json),
                priority=req.priority,
                status="unknown",
            )
            db.add(cfg)
            db.commit()
            return {"code": 0, "message": "创建成功", "data": {"id": cfg.id}}
        except Exception as e:
            db.rollback()
            logger.exception("[API-创建数据源] 创建失败: %s", e)
            return {"code": 1, "message": f"创建失败: {str(e)}"}


@app.put("/api/v1/datasources/{source_name}")
async def update_datasource(source_name: str, req: DataSourceConfigUpdate):
    """Update a data-source configuration and reload the enabled sources.

    Only fields that are not ``None`` in the request are applied.
    """
    with _db_session() as db:
        try:
            cfg = _find_source_config(db, source_name)
            if not cfg:
                return {"code": 1, "message": "数据源不存在"}

            if req.is_enabled is not None:
                cfg.is_enabled = req.is_enabled
            if req.config_json is not None:
                cfg.config_json = _encode_config(req.config_json)
            if req.priority is not None:
                cfg.priority = req.priority

            db.commit()

            # Reload the data sources so the change takes effect.
            # NOTE(review): if this reload raises, the update is already
            # committed but the endpoint still reports "更新失败".
            DataSourceManager.load_enabled_sources()
            return {"code": 0, "message": "更新成功"}
        except Exception as e:
            db.rollback()
            logger.exception("[API-更新数据源] 更新失败: %s", e)
            return {"code": 1, "message": f"更新失败: {str(e)}"}


@app.post("/api/v1/datasources/{source_name}/test")
async def test_datasource(source_name: str):
    """Test a data source's connection and record the outcome on its config row.

    Uses the instance already held by ``DataSourceManager`` when there is one;
    otherwise builds a temporary instance from the stored configuration. The
    result is written to the row's ``status`` / ``error_msg`` columns.
    """
    source = DataSourceManager.get_source(source_name)
    source_class = None
    raw_config = None

    if not source:
        # No loaded instance: prepare a temporary one from the stored config.
        with _db_session() as db:
            cfg = _find_source_config(db, source_name)
            if not cfg:
                return {"code": 1, "message": "数据源不存在"}
            raw_config = cfg.config_json

        # Resolve the implementation class registered for this source name.
        # NOTE(review): relies on the manager's private _source_map.
        source_class = DataSourceManager._source_map.get(source_name)
        if not source_class:
            return {"code": 1, "message": "不支持的数据源类型"}

    try:
        if source_class is not None:
            config = json.loads(raw_config) if raw_config else {}
            source = source_class(config)
        ok, msg = source.health_check()
    except Exception as e:
        # A malformed stored config, a failing constructor or a raising
        # health_check() is reported as a failed test rather than an HTTP 500.
        logger.exception("[API-测试数据源] %s 测试异常", source_name)
        ok, msg = False, str(e)

    if ok:
        _set_source_status(source_name, "ok", None)
        return {"code": 0, "message": "连接成功", "data": {"status": "ok"}}

    _set_source_status(source_name, "error", msg)
    return {"code": 1, "message": f"连接失败: {msg}", "data": {"status": "error"}}