"""
Stock K-line (candlestick) service v2.2.

Supports 8 bar frequencies (1m through 1month) and price adjustment
(forward-adjusted "qfq" / backward-adjusted "hfq").
"""
import logging
from datetime import date, datetime, timedelta
from typing import List, Optional, Dict, Any
from enum import Enum

from sqlalchemy.orm import Session

from app.models.kline import (
    Frequency, AdjustType, StockKLineItem, StockKLineData,
    StockKLineQuery, StockSymbolInfo
)
from app.repositories.kline.stock_repository import StockKLineRepository
from app.services.cache_service import cache_service

logger = logging.getLogger(__name__)

# K-line frequencies supported for stocks.
STOCK_FREQUENCIES = [
    Frequency.FREQ_1M,
    Frequency.FREQ_5M,
    Frequency.FREQ_15M,
    Frequency.FREQ_30M,
    Frequency.FREQ_1H,
    Frequency.FREQ_1D,
    Frequency.FREQ_1W,
    Frequency.FREQ_1MONTH,
]

# Frequency -> period string passed to the data adapter. Built once at import
# time instead of on every adapter call.
_ADAPTER_PERIODS: Dict[Frequency, str] = {
    Frequency.FREQ_1M: "1m",
    Frequency.FREQ_5M: "5m",
    Frequency.FREQ_15M: "15m",
    Frequency.FREQ_30M: "30m",
    Frequency.FREQ_1H: "60m",
    Frequency.FREQ_1D: "1d",
    Frequency.FREQ_1W: "1w",
    Frequency.FREQ_1MONTH: "1month",
}


class StockKLineService:
    """Stock K-line service.

    Reads bars from the database via ``StockKLineRepository``, falls back to the
    external data adapter when the database has nothing for the request, persists
    adapter results, and optionally applies price adjustment.
    """

    def __init__(self, db: Session):
        self.repository = StockKLineRepository(db)
        self.db = db

    async def query_klines(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime,
        adjust: AdjustType = AdjustType.NONE,
        use_cache: bool = True
    ) -> StockKLineData:
        """
        Query K-line data for a single stock.

        Args:
            symbol: Stock code (e.g. 000001.SZ).
            freq: Bar frequency; must be one of ``STOCK_FREQUENCIES``.
            start: Start time.
            end: End time.
            adjust: Price adjustment type (qfq / hfq / none).
            use_cache: Whether to use the cache.

        Returns:
            StockKLineData: The K-line response (empty ``items`` if nothing found).

        Raises:
            ValueError: If parameter validation fails or ``freq`` is unsupported.
        """
        self._validate_params(symbol, freq, start, end)

        if freq not in STOCK_FREQUENCIES:
            raise ValueError(f"不支持的股票 K 线周期: {freq}")

        # NOTE(review): ``use_cache`` is currently ignored and the imported
        # ``cache_service`` is never used -- confirm whether caching was intended.

        items = self.repository.get_klines(symbol, freq, start, end)

        # Database miss: fetch from the adapter and persist what we got so the
        # next query is served from the database.
        if not items:
            logger.info(f"数据库无 {symbol} 数据,尝试从数据源获取")
            items = await self._fetch_from_adapter(symbol, freq, start, end)
            if items:
                self._save_klines_to_db(symbol, freq, items)

        if adjust != AdjustType.NONE:
            items = self._apply_adjustment(symbol, items, adjust)

        symbol_info = self.repository.get_symbol_info(symbol)

        return StockKLineData(
            symbol=symbol,
            name=symbol_info.name if symbol_info else "",
            freq=freq,
            adjust=adjust,
            count=len(items),
            items=items
        )

    async def query_klines_batch(
        self,
        symbols: List[str],
        freq: Frequency,
        start: datetime,
        end: datetime,
        adjust: AdjustType = AdjustType.NONE,
        max_symbols: int = 100
    ) -> Dict[str, StockKLineData]:
        """
        Query K-line data for several stocks, one after another.

        A failure for one symbol does not abort the batch: that symbol gets an
        empty ``StockKLineData`` whose ``error`` field holds the message.

        Args:
            symbols: Stock codes (at most ``max_symbols``, 100 by default).
            freq: Bar frequency.
            start: Start time.
            end: End time.
            adjust: Price adjustment type.
            max_symbols: Maximum number of symbols accepted.

        Returns:
            Dict[str, StockKLineData]: K-line data keyed by symbol.

        Raises:
            ValueError: If more than ``max_symbols`` symbols are requested.
        """
        if len(symbols) > max_symbols:
            raise ValueError(f"批量查询最多支持 {max_symbols} 只股票,当前: {len(symbols)}")

        results = {}
        for symbol in symbols:
            try:
                data = await self.query_klines(symbol, freq, start, end, adjust)
                results[symbol] = data
            except Exception as e:
                # Per-symbol isolation: record the error and keep going.
                logger.error(f"查询 {symbol} K 线失败: {e}")
                results[symbol] = StockKLineData(
                    symbol=symbol,
                    name="",
                    freq=freq,
                    adjust=adjust,
                    count=0,
                    items=[],
                    error=str(e)
                )

        return results

    async def _fetch_from_adapter(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime
    ) -> List[StockKLineItem]:
        """
        Fetch K-line data from the external data adapter.

        Best-effort: any failure is logged and an empty list is returned.

        Note: the old ``asyncio.new_event_loop()`` approach was removed; the
        adapter's async API is awaited directly on the current event loop.
        """
        try:
            # Imported lazily, as in the original code (presumably to avoid an
            # import cycle or import-time cost -- not confirmed from here).
            from app.services.amazing_data_service import amazing_data_service

            # Unknown frequencies fall back to daily bars.
            period = _ADAPTER_PERIODS.get(freq, "1d")

            raw_items = await amazing_data_service.get_kline_data_async(
                symbol=symbol,
                period=period,
                start_date=start,
                end_date=end
            )

            result = [self._to_kline_item(symbol, raw) for raw in (raw_items or [])]

            logger.info(f"从适配器获取 {symbol} {freq} K 线 {len(result)} 条")
            return result

        except Exception as e:
            # logger.exception keeps the traceback for this best-effort path.
            logger.exception(f"从适配器获取数据失败: {e}")
            return []

    @staticmethod
    def _to_kline_item(symbol: str, raw: Dict[str, Any]) -> StockKLineItem:
        """Convert one adapter record (a dict) into a ``StockKLineItem``.

        ``raw["time"]`` may be an ISO-8601 string or a datetime-like object; it
        is parsed once and reused for both ``time`` and ``trade_date``.
        """
        raw_time = raw["time"]
        bar_time = datetime.fromisoformat(raw_time) if isinstance(raw_time, str) else raw_time

        return StockKLineItem(
            symbol=symbol,
            time=bar_time,
            open=float(raw["open"]),
            high=float(raw["high"]),
            low=float(raw["low"]),
            close=float(raw["close"]),
            volume=int(raw["volume"]),
            amount=float(raw.get("amount", 0)),
            trade_date=bar_time.date(),
            is_limit_up=raw.get("is_limit_up", False),
            is_limit_down=raw.get("is_limit_down", False),
            total_market_cap=raw.get("total_market_cap"),
            float_market_cap=raw.get("float_market_cap"),
        )

    def _save_klines_to_db(
        self,
        symbol: str,
        freq: Frequency,
        items: List[StockKLineItem]
    ) -> None:
        """Persist K-line items and commit; roll back and log on failure (best-effort)."""
        if not items:
            return

        try:
            self.repository.save_klines(freq, items)
            self.db.commit()
            logger.info(f"保存 {symbol} {freq} K 线 {len(items)} 条到数据库")
        except Exception as e:
            self.db.rollback()
            logger.exception(f"保存 K 线数据失败: {e}")

    def _apply_adjustment(
        self,
        symbol: str,
        items: List[StockKLineItem],
        adjust: AdjustType
    ) -> List[StockKLineItem]:
        """
        Apply price adjustment to K-line items.

        Args:
            symbol: Stock code.
            items: K-line items.
            adjust: Adjustment type (qfq / hfq).

        Returns:
            List[StockKLineItem]: New adjusted items. Items with no applicable
            factor are passed through unchanged. If the stock has no factors at
            all, the input list is returned as-is.
        """
        if not items:
            return items

        factors = self.repository.get_adjust_factors(symbol)
        if not factors:
            logger.warning(f"股票 {symbol} 无复权因子数据,返回原始数据")
            return items

        adjusted_items = []
        for item in items:
            factor = self._find_adjust_factor(item.trade_date, factors)

            # Truthiness check also skips a 0 factor, which would otherwise
            # zero the prices (qfq) or divide by zero (hfq).
            if factor:
                adjusted_item = StockKLineItem(
                    symbol=item.symbol,
                    time=item.time,
                    open=self._adjust_price(item.open, factor, adjust),
                    high=self._adjust_price(item.high, factor, adjust),
                    low=self._adjust_price(item.low, factor, adjust),
                    close=self._adjust_price(item.close, factor, adjust),
                    volume=item.volume,
                    amount=item.amount,
                    trade_date=item.trade_date,
                    is_limit_up=item.is_limit_up,
                    is_limit_down=item.is_limit_down,
                    total_market_cap=item.total_market_cap,
                    float_market_cap=item.float_market_cap,
                )
                adjusted_items.append(adjusted_item)
            else:
                adjusted_items.append(item)

        return adjusted_items

    def _find_adjust_factor(
        self,
        trade_date: date,
        factors: List[Dict]
    ) -> Optional[float]:
        """Return the adjustment factor in effect on ``trade_date``.

        That is the factor with the latest ``adj_date`` that is not after
        ``trade_date``. The factor list may be in any order. Ties on
        ``adj_date`` resolve to the first such entry.

        Returns:
            The ``adj_factor`` value, or ``None`` if ``trade_date`` is ``None``
            or no factor applies yet.
        """
        if trade_date is None:
            return None

        applicable = [f for f in factors if f["adj_date"] <= trade_date]
        if not applicable:
            return None

        # The previous "first match" scan was only correct if the repository
        # returned factors newest-first; choosing the max adj_date removes that
        # hidden ordering dependency.
        return max(applicable, key=lambda f: f["adj_date"])["adj_factor"]

    def _adjust_price(
        self,
        price: float,
        factor: float,
        adjust: AdjustType
    ) -> float:
        """
        Compute an adjusted price, rounded to 4 decimals.

        Forward adjustment (qfq): ``price * factor``.
        Backward adjustment (hfq): ``price / factor``.
        Any other ``adjust`` value returns ``price`` unchanged.

        NOTE(review): whether multiply/divide is correct depends on how the
        repository defines ``adj_factor`` -- confirm against the factor source.
        """
        if adjust == AdjustType.QFQ:
            return round(price * factor, 4)
        elif adjust == AdjustType.HFQ:
            return round(price / factor, 4)
        return price

    def _validate_params(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime
    ) -> None:
        """Validate query parameters.

        Raises:
            ValueError: If the symbol is empty, ``start > end``, or the range
                exceeds 365 days.
        """
        if not symbol:
            raise ValueError("股票代码不能为空")

        if start > end:
            raise ValueError("开始时间不能晚于结束时间")

        # Limit the query span to at most one year.
        if (end - start).days > 365:
            raise ValueError("查询范围不能超过 1 年")


# Service instance factory.
def get_stock_kline_service(db: Session) -> StockKLineService:
    """Create a stock K-line service bound to the given DB session."""
    return StockKLineService(db)