You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

338 lines
11 KiB

"""
Stock K-line (candlestick) service v2.2.
Supports 8 K-line periods (1m to 1month) and price adjustment (qfq/hfq).
"""
import logging
from datetime import date, datetime, timedelta
from enum import Enum
from typing import List, Optional, Dict, Any

from sqlalchemy.orm import Session

from app.models.kline import (
Frequency, AdjustType, StockKLineItem, StockKLineData,
StockKLineQuery, StockSymbolInfo
)
from app.repositories.kline.stock_repository import StockKLineRepository
from app.services.cache_service import cache_service
logger = logging.getLogger(name=__name__)

# K-line frequencies accepted by StockKLineService.query_klines; any other
# Frequency value is rejected there with a ValueError.
STOCK_FREQUENCIES = [
    Frequency.FREQ_1M, Frequency.FREQ_5M, Frequency.FREQ_15M, Frequency.FREQ_30M,
    Frequency.FREQ_1H, Frequency.FREQ_1D, Frequency.FREQ_1W, Frequency.FREQ_1MONTH,
]
class StockKLineService:
    """Stock K-line (candlestick) query service.

    Reads bars through ``StockKLineRepository``. When the repository has no
    rows for a query, it falls back to the external data adapter and persists
    whatever the adapter returns. Optionally applies price adjustment
    (qfq/hfq) using factors from the repository.
    """

    # Frequency -> period string understood by the external data adapter.
    _ADAPTER_PERIODS: Dict[Frequency, str] = {
        Frequency.FREQ_1M: "1m",
        Frequency.FREQ_5M: "5m",
        Frequency.FREQ_15M: "15m",
        Frequency.FREQ_30M: "30m",
        Frequency.FREQ_1H: "60m",
        Frequency.FREQ_1D: "1d",
        Frequency.FREQ_1W: "1w",
        Frequency.FREQ_1MONTH: "1month",
    }

    def __init__(self, db: Session):
        self.repository = StockKLineRepository(db)
        self.db = db

    async def query_klines(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime,
        adjust: AdjustType = AdjustType.NONE,
        use_cache: bool = True
    ) -> StockKLineData:
        """Query K-line data for a single stock.

        Args:
            symbol: Stock code (e.g. ``000001.SZ``).
            freq: K-line period; must be one of ``STOCK_FREQUENCIES``.
            start: Start of the query window.
            end: End of the query window.
            adjust: Price adjustment type (qfq/hfq/none).
            use_cache: Accepted for API compatibility; this method does not
                currently consult any cache.

        Returns:
            StockKLineData: the (possibly adjusted) bars plus the stock name
            (empty string when the repository has no symbol info).

        Raises:
            ValueError: if parameter validation fails or ``freq`` is not
                supported for stocks.
        """
        self._validate_params(symbol, freq, start, end)
        if freq not in STOCK_FREQUENCIES:
            raise ValueError(f"不支持的股票 K 线周期: {freq}")

        items = self.repository.get_klines(symbol, freq, start, end)

        # Repository miss: fall back to the external data source and persist
        # what it returns so later queries can be served from the database.
        if not items:
            logger.info(f"数据库无 {symbol} 数据,尝试从数据源获取")
            items = await self._fetch_from_adapter(symbol, freq, start, end)
            if items:
                self._save_klines_to_db(symbol, freq, items)

        if adjust != AdjustType.NONE:
            items = self._apply_adjustment(symbol, items, adjust)

        symbol_info = self.repository.get_symbol_info(symbol)
        return StockKLineData(
            symbol=symbol,
            name=symbol_info.name if symbol_info else "",
            freq=freq,
            adjust=adjust,
            count=len(items),
            items=items
        )

    async def query_klines_batch(
        self,
        symbols: List[str],
        freq: Frequency,
        start: datetime,
        end: datetime,
        adjust: AdjustType = AdjustType.NONE,
        max_symbols: int = 100
    ) -> Dict[str, StockKLineData]:
        """Query K-line data for several stocks, one after another.

        A failure for one symbol does not abort the batch: that symbol gets an
        empty ``StockKLineData`` whose ``error`` field holds the message.

        Args:
            symbols: Stock codes (at most ``max_symbols``).
            freq: K-line period.
            start: Start of the query window.
            end: End of the query window.
            adjust: Price adjustment type.
            max_symbols: Maximum number of symbols allowed (default 100).

        Returns:
            Dict[str, StockKLineData]: result per symbol.

        Raises:
            ValueError: if more than ``max_symbols`` symbols are requested.
        """
        if len(symbols) > max_symbols:
            raise ValueError(f"批量查询最多支持 {max_symbols} 只股票,当前: {len(symbols)}")

        results = {}
        for symbol in symbols:
            try:
                results[symbol] = await self.query_klines(symbol, freq, start, end, adjust)
            except Exception as e:
                logger.error(f"查询 {symbol} K 线失败: {e}")
                results[symbol] = StockKLineData(
                    symbol=symbol,
                    name="",
                    freq=freq,
                    adjust=adjust,
                    count=0,
                    items=[],
                    error=str(e)
                )
        return results

    async def _fetch_from_adapter(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime
    ) -> List[StockKLineItem]:
        """Fetch K-line data from the external data adapter.

        Awaits the adapter's async API on the current event loop; no new
        event loop is created.

        Returns:
            List[StockKLineItem]: converted bars, or an empty list if the
            fetch or the conversion of any row fails (the error is logged).
        """
        try:
            # Imported lazily inside the method; presumably to avoid an import
            # cycle or import-time cost -- TODO confirm.
            from app.services.amazing_data_service import amazing_data_service

            # Frequencies missing from the map fall back to daily bars.
            period = self._ADAPTER_PERIODS.get(freq, "1d")
            rows = await amazing_data_service.get_kline_data_async(
                symbol=symbol,
                period=period,
                start_date=start,
                end_date=end
            )
            result = [self._row_to_item(symbol, row) for row in rows]
            logger.info(f"从适配器获取 {symbol} {freq} K 线 {len(result)}")
            return result
        except Exception as e:
            # Best-effort fallback: callers treat [] as "no data". The
            # traceback is kept so adapter failures remain diagnosable.
            logger.exception(f"从适配器获取数据失败: {e}")
            return []

    @staticmethod
    def _row_to_item(symbol: str, row: Dict[str, Any]) -> StockKLineItem:
        """Convert one adapter row into a StockKLineItem.

        ``row["time"]`` may be an ISO-8601 string or a datetime-like object
        with a ``.date()`` method. It is parsed once and reused for both
        ``time`` and ``trade_date``.
        """
        raw_time = row["time"]
        bar_time = datetime.fromisoformat(raw_time) if isinstance(raw_time, str) else raw_time
        return StockKLineItem(
            symbol=symbol,
            time=bar_time,
            open=float(row["open"]),
            high=float(row["high"]),
            low=float(row["low"]),
            close=float(row["close"]),
            volume=int(row["volume"]),
            amount=float(row.get("amount", 0)),
            trade_date=bar_time.date(),
            is_limit_up=row.get("is_limit_up", False),
            is_limit_down=row.get("is_limit_down", False),
            total_market_cap=row.get("total_market_cap"),
            float_market_cap=row.get("float_market_cap"),
        )

    def _save_klines_to_db(
        self,
        symbol: str,
        freq: Frequency,
        items: List[StockKLineItem]
    ) -> None:
        """Persist K-line bars and commit; on failure roll back and log.

        Persistence is best-effort: errors are not propagated, so the caller
        can still return the fetched data.
        """
        if not items:
            return
        try:
            self.repository.save_klines(freq, items)
            self.db.commit()
            logger.info(f"保存 {symbol} {freq} K 线 {len(items)} 条到数据库")
        except Exception as e:
            # Roll back so the session is usable for the rest of the request.
            self.db.rollback()
            logger.exception(f"保存 K 线数据失败: {e}")

    def _apply_adjustment(
        self,
        symbol: str,
        items: List[StockKLineItem],
        adjust: AdjustType
    ) -> List[StockKLineItem]:
        """Apply price adjustment to OHLC prices.

        Volume, amount and the remaining fields are copied unchanged.

        Args:
            symbol: Stock code whose adjustment factors are looked up.
            items: Bars to adjust.
            adjust: Adjustment type (qfq/hfq).

        Returns:
            List[StockKLineItem]: new list of bars. The input is returned
            as-is when it is empty or the symbol has no factors.
        """
        if not items:
            return items

        factors = self.repository.get_adjust_factors(symbol)
        if not factors:
            logger.warning(f"股票 {symbol} 无复权因子数据,返回原始数据")
            return items

        adjusted_items = []
        for item in items:
            factor = self._find_adjust_factor(item.trade_date, factors)
            # Bars with no applicable factor -- or a zero factor, which would
            # divide by zero under hfq -- are passed through unchanged.
            if not factor:
                adjusted_items.append(item)
                continue
            adjusted_items.append(StockKLineItem(
                symbol=item.symbol,
                time=item.time,
                open=self._adjust_price(item.open, factor, adjust),
                high=self._adjust_price(item.high, factor, adjust),
                low=self._adjust_price(item.low, factor, adjust),
                close=self._adjust_price(item.close, factor, adjust),
                volume=item.volume,
                amount=item.amount,
                trade_date=item.trade_date,
                is_limit_up=item.is_limit_up,
                is_limit_down=item.is_limit_down,
                total_market_cap=item.total_market_cap,
                float_market_cap=item.float_market_cap,
            ))
        return adjusted_items

    def _find_adjust_factor(
        self,
        trade_date: date,
        factors: List[Dict[str, Any]]
    ) -> Optional[float]:
        """Return the adjustment factor in effect on ``trade_date``.

        Selects the record whose ``adj_date`` is the latest one not after
        ``trade_date``. Unlike a first-match scan, the result does not depend
        on the order in which the repository returns ``factors``. When
        several records share that date, the first one in list order wins.

        Args:
            trade_date: Trading date of the bar.
            factors: Records with ``adj_date`` and ``adj_factor`` keys.

        Returns:
            The matching ``adj_factor``, or None if every ``adj_date`` is
            after ``trade_date``.
        """
        best = max(
            (record for record in factors if record["adj_date"] <= trade_date),
            key=lambda record: record["adj_date"],
            default=None,
        )
        return None if best is None else best["adj_factor"]

    def _adjust_price(
        self,
        price: float,
        factor: float,
        adjust: AdjustType
    ) -> float:
        """Compute an adjusted price, rounded to 4 decimal places.

        Forward adjustment (qfq): price * factor.
        Backward adjustment (hfq): price / factor.
        Any other ``adjust`` value returns ``price`` unchanged.

        NOTE(review): these formulas assume ``factor`` is already expressed
        relative to the desired base; confirm against how the repository
        stores ``adj_factor``.
        """
        if adjust == AdjustType.QFQ:
            return round(price * factor, 4)
        elif adjust == AdjustType.HFQ:
            return round(price / factor, 4)
        return price

    def _validate_params(
        self,
        symbol: str,
        freq: Frequency,
        start: datetime,
        end: datetime
    ) -> None:
        """Validate query parameters.

        ``freq`` is accepted for signature symmetry but not checked here;
        the caller checks it against ``STOCK_FREQUENCIES``.

        Raises:
            ValueError: if ``symbol`` is empty, ``start`` is after ``end``,
                or the window spans more than 365 days.
        """
        if not symbol:
            raise ValueError("股票代码不能为空")
        if start > end:
            raise ValueError("开始时间不能晚于结束时间")
        # Cap the query window at one year (365 whole days).
        if (end - start).days > 365:
            raise ValueError("查询范围不能超过 1 年")
# Factory that builds the service for a given database session.
def get_stock_kline_service(db: Session) -> StockKLineService:
    """Build a StockKLineService bound to ``db``.

    Args:
        db: SQLAlchemy session passed to the service.

    Returns:
        A new StockKLineService; a fresh instance is created on every call.
    """
    service = StockKLineService(db)
    return service