You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

513 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""股票业务服务 - 对应Go的internal/service/stock.go"""
import asyncio
from datetime import datetime, timedelta
from typing import List
from sqlalchemy.orm import Session
from app.models import (
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info
class StockService:
    """Stock business service.

    Reads K-lines, symbols and trading dates from the database through
    ``StockRepository``. When the database has nothing, it falls back to the
    active "stock" data adapter and persists what the adapter returns.
    """

    def __init__(self, db: Session):
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_async(coro):
        """Run an awaitable to completion on a private event loop and return its result.

        The loop is always unregistered and closed afterwards, even when the
        awaitable raises. Repeated calls therefore neither leak loops nor leave
        a closed loop installed as the thread's current loop. Exceptions raised
        by the awaitable propagate unchanged.
        """
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def _get_stock_adapter(self):
        """Return the active "stock" adapter, connecting the configured one if needed.

        Returns the result of ``AdapterService.get_active_adapter("stock")``
        after the connection attempt. That result may be falsy when nothing
        could be activated. Connection errors propagate to the caller.
        """
        adapter_service = AdapterService()
        adapter = adapter_service.get_active_adapter("stock")
        if adapter:
            return adapter
        # Nothing active yet: connect the source named in the configuration.
        from app.core.config import get_config
        active_source = get_config().sources.stock.active
        info(f"Connecting to configured adapter: {active_source}")
        StockService._run_async(adapter_service._connect_adapter(active_source))
        return adapter_service.get_active_adapter("stock")

    @staticmethod
    def _date_key(value):
        """Normalise a date-like value to a ``YYYYMMDD`` string, or ``None``.

        Values with ``strftime`` (date, datetime, pandas Timestamp) are
        formatted directly. Other values are stringified, ``-`` and ``/``
        separators are removed, and the first eight characters are kept. So
        ``"2024-01-02"``, ``"20240102"`` and ``"2024-01-02 00:00:00"`` all map
        to ``"20240102"``.
        """
        if value is None:
            return None
        if hasattr(value, "strftime"):
            return value.strftime("%Y%m%d")
        compact = str(value).strip().replace("-", "").replace("/", "")[:8]
        return compact or None

    @staticmethod
    def _clean_factor(value) -> float:
        """Coerce an adjustment factor to float; fall back to 1.0 if unusable.

        ``None``, non-numeric, NaN, infinite, zero and negative values all
        become 1.0, which means "no adjustment".
        """
        import math
        try:
            factor = float(value)
        except (TypeError, ValueError):
            return 1.0
        return factor if math.isfinite(factor) and factor > 0 else 1.0

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line data for one symbol.

        Args:
            req: Carries ``symbol``, ``freq`` and ``adjust``, plus inclusive
                ``start`` and ``end`` dates as ``YYYYMMDD`` strings.

        Returns:
            ``KLineData`` holding the (optionally price-adjusted) items.

        Raises:
            ValueError: If ``start`` or ``end`` is not in ``YYYYMMDD`` format.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Move the end bound to 23:59:59 so the whole end day is included.
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(req.symbol, req.freq, start, end, req.adjust)

        # Cache miss: fetch from the adapter and persist what comes back.
        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items,
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-lines for *symbol* from the active adapter.

        ``start`` and ``end`` are forwarded unchanged to ``adapter.fetch_klines``.
        This is best-effort: any failure is logged and an empty list returned.
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_async(adapter.fetch_klines(symbol, start, end, freq_str))

            # One timestamp for the whole batch instead of one per bar.
            created_at = datetime.now()
            items = [self._to_kline_item(symbol, k, created_at) for k in klines or []]
            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            error(f"Failed to fetch from adapter: {e}")
            return []

    @staticmethod
    def _to_kline_item(symbol: str, k, created_at: datetime) -> KLineItem:
        """Convert one adapter K-line record into a ``KLineItem``."""
        # A numeric ``time`` is treated as a POSIX timestamp (seconds).
        # A datetime (or pandas Timestamp) is passed through unchanged.
        bar_time = k.time if isinstance(k.time, datetime) else datetime.fromtimestamp(k.time)
        return KLineItem(
            symbol=symbol,
            time=bar_time,
            open=k.open,
            high=k.high,
            low=k.low,
            close=k.close,
            volume=k.volume,
            amount=k.amount,
            trade_date=getattr(k, 'trade_date', None),
            is_limit_up=getattr(k, 'is_limit_up', None),
            is_limit_down=getattr(k, 'is_limit_down', None),
            total_market_cap=getattr(k, 'total_market_cap', None),
            float_market_cap=getattr(k, 'float_market_cap', None),
            inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
            trading_days=getattr(k, 'trading_days', None),
            created_at=created_at,
        )

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a ``Frequency`` enum to the adapter's string code; unknown values map to "1d"."""
        mapping = {
            Frequency.FREQ_1M: "1m",
            Frequency.FREQ_5M: "5m",
            Frequency.FREQ_15M: "15m",
            Frequency.FREQ_30M: "30m",
            Frequency.FREQ_60M: "60m",
            Frequency.FREQ_1D: "1d",
            Frequency.FREQ_1W: "1w",
            Frequency.FREQ_1MONTH: "1month",
        }
        return mapping.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist *items* via the repository; failures are logged, not raised."""
        try:
            # Make sure every item carries the symbol it is stored under.
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")

    # ------------------------------------------------------------------
    # Price adjustment
    # ------------------------------------------------------------------

    def _apply_adjust(
        self,
        symbol: str,
        items: List[KLineItem],
        adjust_type: AdjustType
    ) -> List[KLineItem]:
        """Return *items* with prices scaled by each day's adjustment factor.

        - QFQ (forward adjust): history is scaled relative to the latest price.
        - HFQ (backward adjust): later prices are scaled relative to the earliest.

        Factors come from the repository. When it has none, they are fetched
        from the adapter and cached. Days without a usable factor use 1.0. On
        any error the unadjusted *items* are returned.
        """
        if not items or adjust_type == AdjustType.NONE:
            return items

        try:
            # Use min/max so the range is right even if items are not sorted ascending.
            times = [item.time for item in items]
            start_date = min(times).strftime("%Y%m%d")
            end_date = max(times).strftime("%Y%m%d")

            factors = self.repository.get_adjust_factors(symbol, start_date, end_date)
            if not factors:
                factors = self._fetch_adjust_factors_from_adapter(symbol, start_date, end_date)
                if factors:
                    self.repository.save_adjust_factors(symbol, factors)

            # Key by normalised YYYYMMDD so "2024-01-02", "20240102" and date
            # objects all match each other.
            factor_map = {self._date_key(f["trade_date"]): f for f in factors or []}
            factor_field = "qfq_factor" if adjust_type == AdjustType.QFQ else "hfq_factor"

            adjusted_items = []
            for item in items:
                day = (
                    self._date_key(getattr(item, 'trade_date', None))
                    or self._date_key(getattr(item, 'time', None))
                )
                adj_factor = self._clean_factor(factor_map.get(day, {}).get(factor_field, 1.0))
                adjusted_items.append(KLineItem(
                    symbol=item.symbol,
                    time=item.time,
                    open=round(item.open * adj_factor, 4),
                    high=round(item.high * adj_factor, 4),
                    low=round(item.low * adj_factor, 4),
                    close=round(item.close * adj_factor, 4),
                    volume=item.volume,
                    # NOTE(review): turnover is scaled like prices here; confirm intended.
                    amount=round(item.amount * adj_factor, 4) if item.amount else item.amount,
                    trade_date=getattr(item, 'trade_date', None),
                    is_limit_up=getattr(item, 'is_limit_up', None),
                    is_limit_down=getattr(item, 'is_limit_down', None),
                    total_market_cap=getattr(item, 'total_market_cap', None),
                    float_market_cap=getattr(item, 'float_market_cap', None),
                    inst_holding_ratio=getattr(item, 'inst_holding_ratio', None),
                    trading_days=getattr(item, 'trading_days', None),
                    adj_factor=adj_factor,
                ))
            return adjusted_items
        except Exception as e:
            error(f"Failed to apply adjust factor for {symbol}: {e}")
            return items

    def _fetch_adjust_factors_from_adapter(
        self,
        symbol: str,
        start_date: str,
        end_date: str
    ) -> List[dict]:
        """Fetch forward and backward adjustment factors for *symbol* from the adapter.

        Args:
            symbol: Instrument code; also used as the factor frames' column name.
            start_date: Inclusive lower bound, ``YYYYMMDD``.
            end_date: Inclusive upper bound, ``YYYYMMDD``.

        Returns:
            A list of ``{"trade_date": "YYYY-MM-DD", "qfq_factor": float,
            "hfq_factor": float}`` dicts. The list is empty when the adapter
            cannot supply both factor kinds, or on any error (which is logged).
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available for fetching adjust factors")
                return []
            # Both factor kinds are required; refuse rather than persist fake 1.0s.
            if not (hasattr(adapter, 'get_adj_factor') and hasattr(adapter, 'get_backward_factor')):
                return []

            async def _load_factors():
                # Both requests run on the same loop, as before.
                qfq = await adapter.get_adj_factor([symbol])
                hfq = await adapter.get_backward_factor([symbol])
                return qfq, hfq

            qfq_df, hfq_df = self._run_async(_load_factors())
            if qfq_df is None or hfq_df is None:
                return []

            factors = []
            for idx in qfq_df.index:
                date_obj = idx if hasattr(idx, 'strftime') else datetime.strptime(str(idx), "%Y%m%d")
                # Keep only rows inside the requested range (YYYYMMDD compares lexicographically).
                if not (start_date <= date_obj.strftime("%Y%m%d") <= end_date):
                    continue
                qfq_raw = qfq_df.loc[idx, symbol] if symbol in qfq_df.columns else 1.0
                hfq_raw = hfq_df.loc[idx, symbol] if symbol in hfq_df.columns else 1.0
                factors.append({
                    "trade_date": date_obj.strftime("%Y-%m-%d"),
                    "qfq_factor": self._clean_factor(qfq_raw),
                    "hfq_factor": self._clean_factor(hfq_raw),
                })

            info(f"Fetched {len(factors)} adjust factors from adapter for {symbol}")
            return factors
        except Exception as e:
            error(f"Failed to fetch adjust factors from adapter: {e}")
            return []

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """List symbols page by page, seeding the DB from the adapter when it is empty.

        ``req.page`` is clamped to at least 1 and ``req.size`` to 1..100
        (non-positive sizes become 20). Both are written back onto *req*.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        # Empty DB: pull from the adapter, persist, then re-run the paged query.
        if not symbols:
            info("No symbols in DB, fetching from adapter...")
            symbols = self._fetch_symbols_from_adapter()
            if symbols:
                self._save_symbols_to_db(symbols)
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols,
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock list from the adapter as ``Symbol`` models; ``[]`` on failure."""
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            symbols_info = self._run_async(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType, Exchange
            symbols = []
            for s in symbols_info or []:
                # Unknown exchange strings fall back to Shanghai.
                try:
                    exchange_enum = Exchange(s.exchange)
                except ValueError:
                    exchange_enum = Exchange.SH
                symbols.append(Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=exchange_enum,
                    name=s.name,
                    underlying=s.underlying,
                    status="active",  # required field on Symbol
                ))

            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist *symbols* via the repository; failures are logged, not raised."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")

    # ------------------------------------------------------------------
    # Batch queries
    # ------------------------------------------------------------------

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Run ``query_klines`` for each symbol in *req*, recording failures per symbol.

        A failing symbol, including one whose request fails validation, yields
        ``success=False`` with the error text and does not abort the batch.
        """
        results = []
        for symbol in req.symbols:
            try:
                single_req = KLineQueryRequest(
                    symbol=symbol,
                    start=req.start,
                    end=req.end,
                    freq=req.freq,
                    adjust=req.adjust,
                )
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items),
                ))
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e),
                ))
        return BatchKLineData(results=results)

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return trading dates in ``[req.start, req.end]``, seeding the DB from the adapter if empty."""
        data = self.repository.get_trading_dates(req.start, req.end)

        if not data.trading_dates:
            info(f"No trading dates in DB for {req.start}~{req.end}, fetching from adapter...")
            adapter_dates = self._fetch_trading_dates_from_adapter(req.start, req.end)
            if adapter_dates:
                self._save_trading_dates_to_db(adapter_dates)
                data = self.repository.get_trading_dates(req.start, req.end)

        return data

    def _fetch_trading_dates_from_adapter(self, start: str, end: str) -> List[str]:
        """Fetch Shanghai ("SH") trading days from the adapter as ``YYYYMMDD`` strings.

        Only entries flagged ``is_trading_day`` are kept; ``[]`` on failure.
        """
        try:
            adapter = self._get_stock_adapter()
            if not adapter:
                error("No active adapter available")
                return []

            calendar_data = self._run_async(adapter.fetch_trading_calendar("SH", start, end))

            dates = [
                cal.date.strftime("%Y%m%d")
                for cal in calendar_data or []
                if cal.is_trading_day
            ]
            info(f"Fetched {len(dates)} trading dates from adapter")
            return dates
        except Exception as e:
            error(f"Failed to fetch trading dates from adapter: {e}")
            return []

    def _save_trading_dates_to_db(self, dates: List[str]) -> None:
        """Insert the given ``YYYYMMDD`` trading dates that are not yet stored.

        Existing rows are found with a single range query rather than one
        query per date. ``week_day`` is ``weekday() + 1``, so Monday is 1 and
        Sunday is 7. On error the failure is logged and the session is rolled
        back.
        """
        if not dates:
            return
        try:
            from app.repositories.models import StockTradingCalendar

            wanted = sorted(set(dates))
            # YYYYMMDD strings sort lexicographically in date order, so BETWEEN is exact.
            rows = (
                self.db.query(StockTradingCalendar.trade_date)
                .filter(StockTradingCalendar.trade_date.between(wanted[0], wanted[-1]))
                .all()
            )
            existing = {self._date_key(row[0]) for row in rows}
            new_dates = [d for d in wanted if d not in existing]

            for date_str in new_dates:
                week_day = datetime.strptime(date_str, "%Y%m%d").weekday() + 1
                self.db.add(StockTradingCalendar(
                    trade_date=date_str,
                    is_trading_day=True,
                    week_day=week_day,
                ))
            self.db.commit()
            info(f"Saved {len(new_dates)} trading dates to DB")
        except Exception as e:
            error(f"Failed to save trading dates to DB: {e}")
            self.db.rollback()