"""Stock business service.

Python counterpart of the Go service ``internal/service/stock.go``.
"""
import asyncio
import math
from datetime import datetime, timedelta
from typing import Any, Awaitable, List, Optional

from sqlalchemy.orm import Session

from app.models import (
    KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
    BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
    TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info


class StockService:
    """Stock business service.

    Serves K-lines, symbol lists and trading calendars from the database via
    ``StockRepository``. When the database has no rows for a request, it falls
    back to the active "stock" data-source adapter and persists what it gets.
    """

    # Frequency enum -> frequency string passed to ``adapter.fetch_klines``.
    _FREQ_TO_STR = {
        Frequency.FREQ_1M: "1m",
        Frequency.FREQ_5M: "5m",
        Frequency.FREQ_15M: "15m",
        Frequency.FREQ_30M: "30m",
        Frequency.FREQ_60M: "60m",
        Frequency.FREQ_1D: "1d",
        Frequency.FREQ_1W: "1w",
        Frequency.FREQ_1MONTH: "1month",
    }

    def __init__(self, db: Session):
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_coroutine(coro: Awaitable[Any]) -> Any:
        """Run ``coro`` to completion on a fresh event loop and return its result.

        The new loop is installed as the current loop while it runs (as the
        previous inline code did) and is always closed afterwards, even when
        the coroutine raises. Like ``run_until_complete`` itself, this raises
        ``RuntimeError`` if the calling thread already has a running loop.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def _get_stock_adapter(
        self,
        adapter_service: AdapterService,
        missing_msg: str = "No active adapter available",
    ):
        """Return the active "stock" adapter, connecting the configured one if needed.

        Args:
            adapter_service: service used to look up / connect adapters.
            missing_msg: message logged when no adapter is available.

        Returns:
            The adapter, or None if none is active even after trying to connect
            the adapter named in ``config.sources.stock.active``.
        """
        adapter = adapter_service.get_active_adapter("stock")
        if adapter:
            return adapter

        # Nothing active yet: connect the adapter named in the configuration.
        from app.core.config import get_config
        active_source = get_config().sources.stock.active
        info(f"Connecting to configured adapter: {active_source}")
        self._run_coroutine(adapter_service._connect_adapter(active_source))

        adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            error(missing_msg)
        return adapter

    @staticmethod
    def _date_key(value: Any) -> Optional[str]:
        """Normalize a date-like value to a ``YYYYMMDD`` string.

        Accepts date/datetime objects and strings such as ``"2024-01-02"``,
        ``"2024/01/02"`` or ``"20240102"`` (a trailing time part is ignored).
        Returns None for empty values.
        """
        if not value:
            return None
        if hasattr(value, "strftime"):
            return value.strftime("%Y%m%d")
        return str(value).strip().replace("-", "").replace("/", "")[:8]

    @staticmethod
    def _valid_factor(value: float) -> float:
        """Return ``value`` if it is a finite, positive factor; otherwise neutral 1.0."""
        return value if math.isfinite(value) and value > 0 else 1.0

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line bars for one symbol.

        Reads from the database first. If it has no rows, fetches from the
        adapter and saves them. Price adjustment is applied last when
        ``req.adjust`` is not ``AdjustType.NONE``.

        Raises:
            ValueError: if ``req.start`` / ``req.end`` are not ``YYYYMMDD``.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Extend the end bound to the last second of the end day (inclusive).
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(
            req.symbol, req.freq, start, end, req.adjust
        )

        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items,
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-lines for ``symbol`` from the active adapter.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            adapter = self._get_stock_adapter(AdapterService())
            if not adapter:
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_coroutine(
                adapter.fetch_klines(symbol, start, end, freq_str)
            )

            fetched_at = datetime.now()
            items = [self._kline_to_item(symbol, k, fetched_at) for k in (klines or [])]
            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            error(f"Failed to fetch from adapter: {e}")
            return []

    @staticmethod
    def _kline_to_item(symbol: str, k: Any, fetched_at: datetime) -> KLineItem:
        """Convert one adapter bar into a ``KLineItem``; optional fields default to None.

        ``k.time`` is passed to ``datetime.fromtimestamp`` (presumably epoch
        seconds — TODO confirm against the adapter).
        """
        return KLineItem(
            symbol=symbol,
            time=datetime.fromtimestamp(k.time),
            open=k.open,
            high=k.high,
            low=k.low,
            close=k.close,
            volume=k.volume,
            amount=k.amount,
            trade_date=getattr(k, 'trade_date', None),
            is_limit_up=getattr(k, 'is_limit_up', None),
            is_limit_down=getattr(k, 'is_limit_down', None),
            total_market_cap=getattr(k, 'total_market_cap', None),
            float_market_cap=getattr(k, 'float_market_cap', None),
            inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
            trading_days=getattr(k, 'trading_days', None),
            created_at=fetched_at,
        )

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a ``Frequency`` to the adapter's frequency string (default ``"1d"``)."""
        return self._FREQ_TO_STR.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist K-lines (best-effort); on failure log and roll back the session."""
        try:
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")
            # Leave the session usable for the rest of this request.
            self.db.rollback()

    # ------------------------------------------------------------------
    # Price adjustment
    # ------------------------------------------------------------------

    def _apply_adjust(
        self,
        symbol: str,
        items: List[KLineItem],
        adjust_type: AdjustType
    ) -> List[KLineItem]:
        """Apply price-adjustment factors to K-line bars.

        - Forward adjustment (QFQ): the latest price is the baseline;
          historical prices are scaled down.
        - Backward adjustment (HFQ): the earliest price is the baseline;
          later prices are scaled up.

        Factors come from the DB for the items' date range, or else from the
        adapter (and are then saved). Bars without a usable factor use 1.0.
        On any error the original items are returned unchanged.
        """
        if not items or adjust_type == AdjustType.NONE:
            return items

        try:
            # Items are not guaranteed to be sorted, so take the real min/max.
            times = [item.time for item in items]
            start_date = min(times).strftime("%Y%m%d")
            end_date = max(times).strftime("%Y%m%d")

            factors = self.repository.get_adjust_factors(symbol, start_date, end_date)
            if not factors:
                factors = self._fetch_adjust_factors_from_adapter(symbol, start_date, end_date)
                if factors:
                    # Best-effort save: a DB failure must not discard factors
                    # we already have for this response.
                    self._save_adjust_factors_to_db(symbol, factors)

            # Normalize keys so "2024-01-02", "20240102" and date objects all match.
            factor_map = {self._date_key(f["trade_date"]): f for f in factors or []}
            factor_key = "qfq_factor" if adjust_type == AdjustType.QFQ else "hfq_factor"

            adjusted_items = []
            for item in items:
                day_key = self._date_key(getattr(item, 'trade_date', None) or item.time)
                factor = factor_map.get(day_key)
                # Missing, None or zero factor -> neutral 1.0 (bar left unadjusted).
                adj_factor = (factor.get(factor_key) if factor else None) or 1.0

                adjusted_items.append(KLineItem(
                    symbol=item.symbol,
                    time=item.time,
                    open=round(item.open * adj_factor, 4),
                    high=round(item.high * adj_factor, 4),
                    low=round(item.low * adj_factor, 4),
                    close=round(item.close * adj_factor, 4),
                    volume=item.volume,
                    # NOTE(review): turnover is scaled by the price factor here;
                    # confirm this is intended (it is usually left unadjusted).
                    amount=round(item.amount * adj_factor, 4) if item.amount else item.amount,
                    trade_date=getattr(item, 'trade_date', None),
                    is_limit_up=getattr(item, 'is_limit_up', None),
                    is_limit_down=getattr(item, 'is_limit_down', None),
                    total_market_cap=getattr(item, 'total_market_cap', None),
                    float_market_cap=getattr(item, 'float_market_cap', None),
                    inst_holding_ratio=getattr(item, 'inst_holding_ratio', None),
                    trading_days=getattr(item, 'trading_days', None),
                    adj_factor=adj_factor
                ))

            return adjusted_items

        except Exception as e:
            error(f"Failed to apply adjust factor for {symbol}: {e}")
            # Fall back to the unadjusted data.
            return items

    def _save_adjust_factors_to_db(self, symbol: str, factors: List[dict]) -> None:
        """Persist adjust factors (best-effort); on failure log and roll back."""
        try:
            self.repository.save_adjust_factors(symbol, factors)
        except Exception as e:
            error(f"Failed to save adjust factors to DB for {symbol}: {e}")
            self.db.rollback()

    def _fetch_adjust_factors_from_adapter(
        self,
        symbol: str,
        start_date: str,
        end_date: str
    ) -> List[dict]:
        """Fetch forward/backward adjust factors for ``symbol`` from the adapter.

        Args:
            symbol: stock code; expected as a column of the adapter's frames.
            start_date: inclusive lower bound, ``YYYYMMDD``.
            end_date: inclusive upper bound, ``YYYYMMDD``.

        Returns:
            A list of ``{"trade_date": "YYYY-MM-DD", "qfq_factor": float,
            "hfq_factor": float}`` dicts, or ``[]`` if the adapter is missing,
            lacks the factor methods, or anything fails.
        """
        try:
            adapter = self._get_stock_adapter(
                AdapterService(),
                "No active adapter available for fetching adjust factors",
            )
            if not adapter:
                return []

            # Both the forward and backward factor methods are required.
            if not (hasattr(adapter, 'get_adj_factor') and hasattr(adapter, 'get_backward_factor')):
                return []

            async def _fetch_both():
                # Run both requests on the same loop, as before.
                qfq = await adapter.get_adj_factor([symbol])
                hfq = await adapter.get_backward_factor([symbol])
                return qfq, hfq

            qfq_df, hfq_df = self._run_coroutine(_fetch_both())
            if qfq_df is None or hfq_df is None:
                return []

            # Assumes DataFrames indexed by date with one column per symbol.
            has_qfq = symbol in qfq_df.columns
            has_hfq = symbol in hfq_df.columns

            factors = []
            for idx in qfq_df.index:
                date_obj = idx if hasattr(idx, 'strftime') else datetime.strptime(str(idx), "%Y%m%d")
                date_key = date_obj.strftime("%Y%m%d")

                # Keep only dates inside the requested range.
                if not (start_date <= date_key <= end_date):
                    continue

                qfq_factor = float(qfq_df.loc[idx, symbol]) if has_qfq else 1.0
                # A date missing from the backward frame gets the neutral factor
                # instead of aborting the whole fetch with a KeyError.
                hfq_factor = (
                    float(hfq_df.loc[idx, symbol])
                    if has_hfq and idx in hfq_df.index
                    else 1.0
                )

                factors.append({
                    "trade_date": date_obj.strftime("%Y-%m-%d"),
                    "qfq_factor": self._valid_factor(qfq_factor),
                    "hfq_factor": self._valid_factor(hfq_factor)
                })

            info(f"Fetched {len(factors)} adjust factors from adapter for {symbol}")
            return factors

        except Exception as e:
            error(f"Failed to fetch adjust factors from adapter: {e}")
            return []

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """List symbols with paging (page >= 1, 1 <= size <= 100; default size 20).

        Note: normalizes ``req.page`` / ``req.size`` in place. If the DB is
        empty, symbols are fetched from the adapter, saved, and re-queried.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        if not symbols:
            info("No symbols in DB, fetching from adapter...")
            symbols = self._fetch_symbols_from_adapter()
            if symbols:
                self._save_symbols_to_db(symbols)
                # Re-query so paging/filters are applied by the repository.
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock list from the adapter as ``Symbol`` models (best-effort)."""
        try:
            adapter = self._get_stock_adapter(AdapterService())
            if not adapter:
                return []

            symbols_info = self._run_coroutine(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType, Exchange
            symbols = []
            for s in symbols_info:
                try:
                    exchange_enum = Exchange(s.exchange)
                except ValueError:
                    exchange_enum = Exchange.SH  # unknown exchange -> default to Shanghai

                symbols.append(Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=exchange_enum,
                    name=s.name,
                    underlying=s.underlying,
                    status="active"  # required field
                ))

            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist symbols (best-effort); on failure log and roll back the session."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")
            self.db.rollback()

    # ------------------------------------------------------------------
    # Batch K-lines
    # ------------------------------------------------------------------

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Query K-lines for several symbols; each symbol succeeds or fails independently."""
        results = []
        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust
            )
            try:
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items)
                ))
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e)
                ))

        return BatchKLineData(results=results)

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return trading dates in ``[req.start, req.end]``, filling the DB from the adapter if empty."""
        data = self.repository.get_trading_dates(req.start, req.end)

        if not data.trading_dates:
            info(f"No trading dates in DB for {req.start}~{req.end}, fetching from adapter...")
            adapter_dates = self._fetch_trading_dates_from_adapter(req.start, req.end)
            if adapter_dates:
                self._save_trading_dates_to_db(adapter_dates)
                data = self.repository.get_trading_dates(req.start, req.end)

        return data

    def _fetch_trading_dates_from_adapter(self, start: str, end: str) -> List[str]:
        """Fetch trading days (``YYYYMMDD`` strings) from the adapter (best-effort).

        Uses the "SH" exchange calendar.
        """
        try:
            adapter = self._get_stock_adapter(AdapterService())
            if not adapter:
                return []

            calendar_data = self._run_coroutine(
                adapter.fetch_trading_calendar("SH", start, end)
            )

            dates = [
                cal.date.strftime("%Y%m%d")
                for cal in calendar_data
                if cal.is_trading_day
            ]

            info(f"Fetched {len(dates)} trading dates from adapter")
            return dates
        except Exception as e:
            error(f"Failed to fetch trading dates from adapter: {e}")
            return []

    def _save_trading_dates_to_db(self, dates: List[str]) -> None:
        """Insert trading days (``YYYYMMDD`` strings) that are not already stored.

        Existing rows in the dates' range are loaded with a single query instead
        of one query per date. On failure the error is logged and the session
        is rolled back.
        """
        try:
            from app.repositories.models import StockTradingCalendar

            unique_dates = sorted(set(dates))
            if not unique_dates:
                return

            rows = self.db.query(StockTradingCalendar.trade_date).filter(
                StockTradingCalendar.trade_date >= unique_dates[0],
                StockTradingCalendar.trade_date <= unique_dates[-1],
            ).all()
            # Normalize so the check works whether the column holds strings or dates.
            existing = {self._date_key(row[0]) for row in rows}

            added = 0
            for date_str in unique_dates:
                if date_str in existing:
                    continue
                self.db.add(StockTradingCalendar(
                    trade_date=date_str,
                    is_trading_day=True,
                    # ISO weekday: 1 = Monday ... 7 = Sunday.
                    week_day=datetime.strptime(date_str, "%Y%m%d").isoweekday()
                ))
                added += 1

            self.db.commit()
            info(f"Saved {added} trading dates to DB")
        except Exception as e:
            error(f"Failed to save trading dates to DB: {e}")
            self.db.rollback()