""" 基础数据服务 """ from typing import List, Optional, Dict from datetime import date from sqlalchemy.orm import Session import pandas as pd from app.services.sdk_manager import sdk_manager from app.models.config import SDKConfig class BaseDataService: """基础数据服务""" def __init__(self, db: Session): self.db = db def _get_adapter(self): """获取SDK适配器(使用连接管理器)""" return sdk_manager.get_default_connection() def get_code_list(self, security_type: str) -> List[str]: """ 获取代码列表 Args: security_type: 证券类型 - EXTRA_STOCK_A: 沪深A股 - EXTRA_FUTURE: 期货 - EXTRA_ETF: ETF - EXTRA_INDEX_A: 指数 Returns: 代码列表 """ adapter = self._get_adapter() if not adapter: raise RuntimeError("SDK连接失败,请先测试连接") return adapter.get_code_list(security_type) def get_code_info(self, security_type: str) -> pd.DataFrame: """获取证券信息""" adapter = self._get_adapter() if not adapter: raise RuntimeError("SDK连接失败,请先测试连接") return adapter.get_code_info(security_type) def get_trading_calendar( self, market: str, start_date: date = None, end_date: date = None ) -> List[date]: """ 获取交易日历 Args: market: 市场代码 (SH, SZ, CFE) start_date: 开始日期 end_date: 结束日期 Returns: 交易日列表 """ adapter = self._get_adapter() if not adapter: raise RuntimeError("SDK连接失败,请先测试连接") calendar_ints = adapter.get_trading_calendar(market) # 转换为date对象 from app.utils.date_utils import int_to_date dates = [int_to_date(d) for d in calendar_ints] # 过滤日期范围 if start_date: dates = [d for d in dates if d >= start_date] if end_date: dates = [d for d in dates if d <= end_date] return dates def get_adj_factor(self, codes: List[str]) -> pd.DataFrame: """获取复权因子""" adapter = self._get_adapter() if not adapter: raise RuntimeError("SDK连接失败,请先测试连接") return adapter.get_adj_factor(codes) def get_backward_factor(self, codes: List[str]) -> pd.DataFrame: """获取后复权因子""" adapter = self._get_adapter() if not adapter: raise RuntimeError("SDK连接失败,请先测试连接") return adapter.get_backward_factor(codes) def get_security_type(self, code: str) -> str: """ 根据代码判断证券类型 Args: code: 证券代码 Returns: 证券类型 (stock, future, index, etf, unknown) """ if code.endswith(".CFE"): return "future" elif code.endswith((".SH", ".SZ", ".BJ")): # 根据代码前缀判断 prefix = code[:3] if prefix in ["000", "001", "002", "003", "300", "301", "600", "601", "603", "605", "688"]: return "stock" elif prefix in ["510", "511", "512", "513", "515", "516", "518", "560", "563", "588"]: return "etf" elif prefix in ["000", "880"]: return "index" return "unknown"