"""Stock business service (Python port of the Go ``internal/service/stock.go``).

Serves K-line, symbol-list and trading-calendar queries from the database and,
when the database has no rows for a query, falls back to the active "stock"
data adapter and persists what it fetched.
"""
import asyncio
from datetime import datetime, timedelta
from typing import Any, Awaitable, List, Optional, TypeVar

from sqlalchemy.orm import Session

from app.models import (
    KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
    BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
    TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info

_T = TypeVar("_T")


class StockService:
    """Stock business service backed by a DB repository with adapter fallback."""

    def __init__(self, db: Session):
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Async / adapter plumbing
    # ------------------------------------------------------------------

    @staticmethod
    def _run_async(coro: Awaitable[_T]) -> _T:
        """Run *coro* to completion on a fresh event loop from synchronous code.

        The loop is always closed, even when *coro* raises, and is unset as the
        thread's current loop afterwards so no later code picks up a closed
        loop. Like the previous inline pattern, ``run_until_complete`` raises
        RuntimeError if another event loop is already running in this thread.
        """
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def _get_stock_adapter(self, adapter_service: AdapterService) -> Optional[Any]:
        """Return the active "stock" adapter, connecting "amazingdata" on demand.

        Returns None (after logging an error) when no adapter is available.
        """
        adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            # Lazily connect the default adapter and look it up once more.
            self._run_async(adapter_service._connect_adapter("amazingdata"))
            adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            error("No active adapter available")
            return None
        return adapter

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line data for one symbol.

        ``req.start`` / ``req.end`` are ``YYYYMMDD`` strings; the end date is
        inclusive (the whole day is covered). If the DB has no rows, data is
        fetched from the adapter and saved. Adjustment is delegated to
        ``_apply_adjust`` (currently a pass-through).

        Raises:
            ValueError: if either date does not match ``%Y%m%d``.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
            # Extend to the last representable instant of the end day so that
            # sub-second timestamps in its final second are still included.
            end = end + timedelta(days=1) - timedelta(microseconds=1)
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e

        # Read from the database first.
        items = self.repository.get_klines(
            req.symbol, req.freq, start, end, req.adjust
        )

        # Fall back to the adapter when the DB has nothing for this range.
        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        # Price adjustment (simplified; a real implementation needs an
        # adjustment-factor table).
        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-lines for *symbol* from the stock adapter.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            adapter_service = AdapterService()
            adapter = self._get_stock_adapter(adapter_service)
            if not adapter:
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_async(
                adapter.fetch_klines(symbol, start, end, freq_str)
            )

            # NOTE(review): k.time is treated as a Unix timestamp and converted
            # in the server's local time zone — confirm against the adapter.
            items = [
                KLineItem(
                    symbol=symbol,
                    time=datetime.fromtimestamp(k.time),
                    open=k.open,
                    high=k.high,
                    low=k.low,
                    close=k.close,
                    volume=k.volume,
                    amount=k.amount,
                    # Optional extended fields; absent on some adapters.
                    trade_date=getattr(k, 'trade_date', None),
                    is_limit_up=getattr(k, 'is_limit_up', None),
                    is_limit_down=getattr(k, 'is_limit_down', None),
                    total_market_cap=getattr(k, 'total_market_cap', None),
                    float_market_cap=getattr(k, 'float_market_cap', None),
                    inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
                    trading_days=getattr(k, 'trading_days', None),
                    created_at=datetime.now()
                )
                for k in klines
            ]

            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items

        except Exception as e:
            error(f"Failed to fetch from adapter: {e}")
            return []

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a Frequency enum to the adapter's string form (default "1d")."""
        mapping = {
            Frequency.FREQ_1M: "1m",
            Frequency.FREQ_5M: "5m",
            Frequency.FREQ_15M: "15m",
            Frequency.FREQ_30M: "30m",
            Frequency.FREQ_60M: "60m",
            Frequency.FREQ_1D: "1d",
            Frequency.FREQ_1W: "1w",
            Frequency.FREQ_1MONTH: "1month",
        }
        return mapping.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist K-lines via the repository; failures are logged, not raised."""
        try:
            # Make sure every item carries the symbol before persisting.
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")

    def _apply_adjust(
        self, symbol: str, items: List, adjust_type: AdjustType
    ) -> List:
        """Apply price adjustment (TODO: not implemented; returns *items* as-is).

        A real implementation needs adjustment factors loaded from the DB.
        """
        return items

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """List symbols with pagination (page >= 1, 1 <= size <= 100).

        Note: normalizes ``req.page`` / ``req.size`` in place. If the DB is
        empty, symbols are fetched from the adapter, saved, and re-queried.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        # Fall back to the adapter when the DB has no symbols.
        if not symbols:
            info("No symbols in DB, fetching from adapter...")
            symbols = self._fetch_symbols_from_adapter()
            if symbols:
                self._save_symbols_to_db(symbols)
                # Re-query so pagination and total come from the DB.
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock symbol list from the adapter as ``Symbol`` models.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            adapter_service = AdapterService()
            adapter = self._get_stock_adapter(adapter_service)
            if not adapter:
                return []

            symbols_info = self._run_async(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType, Exchange

            symbols = []
            for s in symbols_info:
                # Convert the exchange string to the Exchange enum; unknown
                # values fall back to Shanghai (SH).
                try:
                    exchange_enum = Exchange(s.exchange)
                except ValueError:
                    exchange_enum = Exchange.SH
                symbols.append(Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=exchange_enum,
                    name=s.name,
                    underlying=s.underlying,
                    status="active"  # required status field
                ))

            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols

        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist symbols via the repository; failures are logged, not raised."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")

    # ------------------------------------------------------------------
    # Batch
    # ------------------------------------------------------------------

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Query K-lines for several symbols; per-symbol failures are reported, not raised."""
        results = []

        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust
            )

            try:
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items)
                ))
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e)
                ))

        return BatchKLineData(results=results)

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return trading dates in [req.start, req.end], filling the DB from the adapter if empty."""
        data = self.repository.get_trading_dates(req.start, req.end)

        if not data.trading_dates:
            info(f"No trading dates in DB for {req.start}~{req.end}, fetching from adapter...")
            adapter_dates = self._fetch_trading_dates_from_adapter(req.start, req.end)
            if adapter_dates:
                self._save_trading_dates_to_db(adapter_dates)
                data = self.repository.get_trading_dates(req.start, req.end)

        return data

    def _fetch_trading_dates_from_adapter(self, start: str, end: str) -> List[str]:
        """Fetch trading days (``YYYYMMDD`` strings) from the adapter's "SH" calendar.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            adapter_service = AdapterService()
            adapter = self._get_stock_adapter(adapter_service)
            if not adapter:
                return []

            calendar_data = self._run_async(
                adapter.fetch_trading_calendar("SH", start, end)
            )

            # Keep only trading days, formatted as YYYYMMDD.
            dates = [
                cal.date.strftime("%Y%m%d")
                for cal in calendar_data
                if cal.is_trading_day
            ]

            info(f"Fetched {len(dates)} trading dates from adapter")
            return dates

        except Exception as e:
            error(f"Failed to fetch trading dates from adapter: {e}")
            return []

    def _save_trading_dates_to_db(self, dates: List[str]) -> None:
        """Insert trading dates (``YYYYMMDD``) that are not already stored.

        Existing rows are found with one range query rather than one query per
        date. Duplicate input dates are inserted at most once. On failure, the
        error is logged and the session is rolled back.
        """
        try:
            from app.repositories.models import StockTradingCalendar

            # Drop duplicates while preserving the adapter's order.
            unique_dates = list(dict.fromkeys(dates))
            if not unique_dates:
                return

            # YYYYMMDD strings sort chronologically, so a min/max range query
            # covers every candidate without a potentially huge IN (...) list.
            existing = {
                row[0]
                for row in self.db.query(StockTradingCalendar.trade_date)
                .filter(StockTradingCalendar.trade_date.between(
                    min(unique_dates), max(unique_dates)
                ))
                .all()
            }

            new_records = [
                StockTradingCalendar(
                    trade_date=date_str,
                    is_trading_day=True,
                    # ISO-style weekday: 1 = Monday ... 7 = Sunday.
                    week_day=datetime.strptime(date_str, "%Y%m%d").weekday() + 1
                )
                for date_str in unique_dates
                if date_str not in existing
            ]

            self.db.add_all(new_records)
            self.db.commit()
            info(f"Saved {len(new_records)} new trading dates to DB ({len(dates)} fetched)")
        except Exception as e:
            error(f"Failed to save trading dates to DB: {e}")
            self.db.rollback()