You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

513 lines
20 KiB

"""股票业务服务 - 对应Go的internal/service/stock.go"""
import asyncio
from datetime import datetime, timedelta
from typing import List
from sqlalchemy.orm import Session
from app.models import (
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info
class StockService:
    """Stock business service.

    Serves K-lines, symbol lists and trading calendars from the database via
    ``StockRepository``. When the database has nothing for a request, the
    service falls back to the active "stock" data-source adapter and persists
    what it fetched.
    """

    # Frequency enum -> adapter frequency string; unknown values map to "1d".
    _FREQ_TO_STR = {
        Frequency.FREQ_1M: "1m",
        Frequency.FREQ_5M: "5m",
        Frequency.FREQ_15M: "15m",
        Frequency.FREQ_30M: "30m",
        Frequency.FREQ_60M: "60m",
        Frequency.FREQ_1D: "1d",
        Frequency.FREQ_1W: "1w",
        Frequency.FREQ_1MONTH: "1month",
    }

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used by the repository and for direct
                trading-calendar writes.
        """
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Adapter plumbing
    # ------------------------------------------------------------------

    @staticmethod
    def _run_async(coro):
        """Run an adapter coroutine to completion from synchronous code.

        ``asyncio.run`` creates a fresh event loop and always closes it, even
        when the coroutine raises. It also leaves no closed loop installed as
        the thread's current loop. A manual ``new_event_loop``/``close``
        sequence without ``finally`` does neither.

        Raises:
            RuntimeError: if an event loop is already running in this thread.
        """
        return asyncio.run(coro)

    def _get_stock_adapter(self):
        """Return the active "stock" adapter, connecting the configured one if needed.

        Returns:
            The adapter, or None if no adapter is active even after trying to
            connect the source named by ``config.sources.stock.active``.

        Raises:
            Whatever ``AdapterService._connect_adapter`` raises; callers catch it.
        """
        adapter_service = AdapterService()
        adapter = adapter_service.get_active_adapter("stock")
        if adapter:
            return adapter
        # Read the configured active source name lazily.
        from app.core.config import get_config
        active_source = get_config().sources.stock.active
        info(f"Connecting to configured adapter: {active_source}")
        self._run_async(adapter_service._connect_adapter(active_source))
        return adapter_service.get_active_adapter("stock")

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line bars for one symbol.

        Reads from the database first. If nothing is stored for the range, the
        bars are fetched from the adapter and saved. Price adjustment is then
        applied unless ``req.adjust`` is ``AdjustType.NONE``.

        Args:
            req: ``start``/``end`` are ``YYYYMMDD`` strings; the end date is
                inclusive.

        Returns:
            KLineData holding the (possibly adjusted) items.

        Raises:
            ValueError: if ``start`` or ``end`` is not a valid ``YYYYMMDD`` date.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Push the end bound to 23:59:59 so the whole end day is included.
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(req.symbol, req.freq, start, end, req.adjust)

        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-line bars for *symbol* from the active stock adapter.

        Args:
            symbol: Instrument identifier, passed through to the adapter.
            start: Range start as given in the request (``YYYYMMDD``).
            end: Range end as given in the request (``YYYYMMDD``).
            freq: Bar frequency, converted by ``_convert_freq_to_str``.

        Returns:
            A list of ``KLineItem``. Any failure is logged and yields ``[]``.
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_async(adapter.fetch_klines(symbol, start, end, freq_str))

            fetched_at = datetime.now()
            # Optional attributes are not provided by every adapter, so read them with getattr.
            items = [
                KLineItem(
                    symbol=symbol,
                    time=datetime.fromtimestamp(k.time),
                    open=k.open,
                    high=k.high,
                    low=k.low,
                    close=k.close,
                    volume=k.volume,
                    amount=k.amount,
                    trade_date=getattr(k, 'trade_date', None),
                    is_limit_up=getattr(k, 'is_limit_up', None),
                    is_limit_down=getattr(k, 'is_limit_down', None),
                    total_market_cap=getattr(k, 'total_market_cap', None),
                    float_market_cap=getattr(k, 'float_market_cap', None),
                    inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
                    trading_days=getattr(k, 'trading_days', None),
                    created_at=fetched_at
                )
                for k in klines
            ]
            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            error(f"Failed to fetch from adapter: {e}")
            return []

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a ``Frequency`` enum to the adapter's frequency string ("1d" if unknown)."""
        return self._FREQ_TO_STR.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist *items* via the repository (best effort: errors are logged, not raised)."""
        try:
            # Make sure every item carries the symbol before persisting.
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")

    # ------------------------------------------------------------------
    # Price adjustment
    # ------------------------------------------------------------------

    def _apply_adjust(
        self,
        symbol: str,
        items: List[KLineItem],
        adjust_type: AdjustType
    ) -> List[KLineItem]:
        """Return copies of *items* with price-adjustment factors applied.

        - Forward adjustment (qfq): uses each day's ``qfq_factor``.
        - Backward adjustment (hfq): uses each day's ``hfq_factor``.

        Only prices (open/high/low/close) are multiplied by the factor. Volume
        and turnover (``amount``, the money actually traded) are left
        unadjusted. Factors are read from the database, or fetched from the
        adapter and saved if none are stored. A day without a factor uses 1.0.

        On any error the original *items* are returned unchanged (logged).
        """
        if not items or adjust_type == AdjustType.NONE:
            return items

        try:
            # Do not rely on items being sorted by time.
            start_date = min(item.time for item in items).strftime("%Y%m%d")
            end_date = max(item.time for item in items).strftime("%Y%m%d")

            factors = self.repository.get_adjust_factors(symbol, start_date, end_date)
            if not factors:
                factors = self._fetch_adjust_factors_from_adapter(symbol, start_date, end_date)
                if factors:
                    self.repository.save_adjust_factors(symbol, factors)

            factor_map = {f["trade_date"]: f for f in factors}
            factor_key = "qfq_factor" if adjust_type == AdjustType.QFQ else "hfq_factor"
            neutral = {"qfq_factor": 1.0, "hfq_factor": 1.0}

            adjusted_items = []
            for item in items:
                trade_date = getattr(item, 'trade_date', None)
                if not trade_date and hasattr(item, 'time'):
                    # Same "%Y-%m-%d" format as the adapter-sourced factor keys.
                    trade_date = item.time.strftime("%Y-%m-%d")
                adj_factor = factor_map.get(trade_date, neutral).get(factor_key, 1.0)

                adjusted_items.append(KLineItem(
                    symbol=item.symbol,
                    time=item.time,
                    open=round(item.open * adj_factor, 4),
                    high=round(item.high * adj_factor, 4),
                    low=round(item.low * adj_factor, 4),
                    close=round(item.close * adj_factor, 4),
                    volume=item.volume,
                    # Turnover is money actually traded, not a price: keep it as-is.
                    amount=item.amount,
                    trade_date=getattr(item, 'trade_date', None),
                    is_limit_up=getattr(item, 'is_limit_up', None),
                    is_limit_down=getattr(item, 'is_limit_down', None),
                    total_market_cap=getattr(item, 'total_market_cap', None),
                    float_market_cap=getattr(item, 'float_market_cap', None),
                    inst_holding_ratio=getattr(item, 'inst_holding_ratio', None),
                    trading_days=getattr(item, 'trading_days', None),
                    adj_factor=adj_factor
                ))
            return adjusted_items
        except Exception as e:
            error(f"Failed to apply adjust factor for {symbol}: {e}")
            # Fall back to the unadjusted data.
            return items

    @staticmethod
    def _factor_at(frame, idx, symbol: str) -> float:
        """Read one adjust factor from a factor table, defaulting to 1.0.

        A missing column or row, a non-finite value (NaN/inf), or a
        non-positive value all yield 1.0, i.e. no adjustment for that day.
        """
        import math
        if symbol not in frame.columns or idx not in frame.index:
            return 1.0
        value = float(frame.loc[idx, symbol])
        return value if math.isfinite(value) and value > 0 else 1.0

    def _fetch_adjust_factors_from_adapter(
        self,
        symbol: str,
        start_date: str,
        end_date: str
    ) -> List[dict]:
        """Fetch forward/backward adjust factors for *symbol* from the adapter.

        Args:
            symbol: Instrument identifier; also the column looked up in the
                adapter's factor tables.
            start_date: Inclusive lower bound, ``YYYYMMDD``.
            end_date: Inclusive upper bound, ``YYYYMMDD``.

        Returns:
            A list of ``{"trade_date": "YYYY-MM-DD", "qfq_factor": float,
            "hfq_factor": float}`` dicts. ``[]`` is returned when no adapter is
            available, the adapter lacks factor support, or any error occurs
            (logged).
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available for fetching adjust factors")
                return []

            # Both factor kinds are required; don't persist half-filled rows.
            if not (hasattr(adapter, 'get_adj_factor') and hasattr(adapter, 'get_backward_factor')):
                return []

            async def _fetch_both():
                # Both requests run on the same event loop.
                qfq = await adapter.get_adj_factor([symbol])
                hfq = await adapter.get_backward_factor([symbol])
                return qfq, hfq

            qfq_df, hfq_df = self._run_async(_fetch_both())

            factors = []
            for idx in qfq_df.index:
                # Index entries may be datetime-like or YYYYMMDD-like values.
                date_obj = idx if hasattr(idx, 'strftime') else datetime.strptime(str(idx), "%Y%m%d")
                date_key = date_obj.strftime("%Y%m%d")
                # YYYYMMDD strings compare chronologically.
                if not (start_date <= date_key <= end_date):
                    continue
                factors.append({
                    "trade_date": date_obj.strftime("%Y-%m-%d"),
                    "qfq_factor": self._factor_at(qfq_df, idx, symbol),
                    "hfq_factor": self._factor_at(hfq_df, idx, symbol)
                })

            info(f"Fetched {len(factors)} adjust factors from adapter for {symbol}")
            return factors
        except Exception as e:
            error(f"Failed to fetch adjust factors from adapter: {e}")
            return []

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """List symbols with pagination, seeding the database from the adapter if empty.

        Normalizes ``req`` in place: ``page`` defaults to 1, and ``size``
        defaults to 20 and is capped at 100.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        elif req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        if not symbols:
            info("No symbols in DB, fetching from adapter...")
            fetched = self._fetch_symbols_from_adapter()
            if fetched:
                self._save_symbols_to_db(fetched)
                # Re-query so pagination and total come from the database.
                symbols, total = self.repository.list_symbols(req)
            else:
                symbols = fetched

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock symbol list from the adapter as ``Symbol`` models.

        Returns:
            A list of ``Symbol``. Any failure is logged and yields ``[]``.
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            symbols_info = self._run_async(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType, Exchange
            symbols = []
            for s in symbols_info:
                try:
                    exchange_enum = Exchange(s.exchange)
                except ValueError:
                    # Unrecognized exchange codes fall back to Shanghai.
                    exchange_enum = Exchange.SH
                symbols.append(Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=exchange_enum,
                    name=s.name,
                    underlying=s.underlying,
                    status="active"  # required by the Symbol model
                ))
            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist *symbols* via the repository (best effort: errors are logged)."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")

    # ------------------------------------------------------------------
    # Batch queries
    # ------------------------------------------------------------------

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Run ``query_klines`` for each symbol, recording per-symbol success or error."""
        results = []
        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust
            )
            try:
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items)
                ))
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e)
                ))
        return BatchKLineData(results=results)

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return trading dates in ``[req.start, req.end]``, seeding the DB from the adapter if empty."""
        data = self.repository.get_trading_dates(req.start, req.end)

        if not data.trading_dates:
            info(f"No trading dates in DB for {req.start}~{req.end}, fetching from adapter...")
            adapter_dates = self._fetch_trading_dates_from_adapter(req.start, req.end)
            if adapter_dates:
                self._save_trading_dates_to_db(adapter_dates)
                data = self.repository.get_trading_dates(req.start, req.end)

        return data

    def _fetch_trading_dates_from_adapter(self, start: str, end: str) -> List[str]:
        """Fetch the SH-exchange trading calendar and return trading days as ``YYYYMMDD`` strings.

        Returns:
            The trading days. Any failure is logged and yields ``[]``.
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            calendar_data = self._run_async(adapter.fetch_trading_calendar("SH", start, end))

            dates = [cal.date.strftime("%Y%m%d") for cal in calendar_data if cal.is_trading_day]
            info(f"Fetched {len(dates)} trading dates from adapter")
            return dates
        except Exception as e:
            error(f"Failed to fetch trading dates from adapter: {e}")
            return []

    def _save_trading_dates_to_db(self, dates: List[str]) -> None:
        """Insert ``YYYYMMDD`` trading dates that are not yet stored.

        Existing rows are found with a single range query rather than one
        query per date, and duplicate input dates are ignored. Errors are
        logged and the session is rolled back.
        """
        try:
            from app.repositories.models import StockTradingCalendar

            unique_dates = sorted(set(dates))
            if not unique_dates:
                return

            # YYYYMMDD strings sort chronologically, so one BETWEEN covers all candidates.
            rows = self.db.query(StockTradingCalendar.trade_date).filter(
                StockTradingCalendar.trade_date.between(unique_dates[0], unique_dates[-1])
            ).all()
            existing = {row[0] for row in rows}

            new_records = [
                StockTradingCalendar(
                    trade_date=date_str,
                    is_trading_day=True,
                    # 1 = Monday ... 7 = Sunday
                    week_day=datetime.strptime(date_str, "%Y%m%d").weekday() + 1
                )
                for date_str in unique_dates
                if date_str not in existing
            ]
            self.db.add_all(new_records)
            self.db.commit()
            info(f"Saved {len(new_records)} trading dates to DB")
        except Exception as e:
            error(f"Failed to save trading dates to DB: {e}")
            self.db.rollback()