You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

275 lines
9.7 KiB

"""股票业务服务 - 对应Go的internal/service/stock.go"""
import asyncio
from datetime import datetime, timedelta
from typing import List
from sqlalchemy.orm import Session
from app.models import (
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info
class StockService:
    """Stock business service (port of the Go ``internal/service/stock.go``).

    Serves K-line, symbol-list and trading-calendar queries through
    ``StockRepository``. When the database has no K-lines for a request, or no
    symbols at all, the data is fetched from the active "stock" adapter
    (connecting the "amazingdata" adapter on demand) and written back to the
    database.
    """

    # Frequency enum -> frequency string passed to adapter.fetch_klines.
    # Built once at class creation rather than on every call.
    _FREQ_TO_STR = {
        Frequency.FREQ_1M: "1m",
        Frequency.FREQ_5M: "5m",
        Frequency.FREQ_15M: "15m",
        Frequency.FREQ_30M: "30m",
        Frequency.FREQ_60M: "60m",
        Frequency.FREQ_1D: "1d",
        Frequency.FREQ_1W: "1w",
        Frequency.FREQ_1MONTH: "1month",
    }

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: Session handed to ``StockRepository`` for all reads and writes.
        """
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Async bridging helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_coroutine(coro):
        """Run *coro* to completion from synchronous code and return its result.

        ``asyncio.run`` creates a fresh event loop and always closes it, even
        when the coroutine raises. The previous manual ``new_event_loop`` /
        ``set_event_loop`` / ``close`` sequence skipped ``close`` on error. It
        also left a closed loop installed as the thread's current loop.

        ``asyncio.run`` refuses to start while another loop is running in this
        thread, for example when this sync service is called from an
        ``async def`` handler. In that case the coroutine runs on a fresh loop
        in a worker thread, and the caller blocks until it finishes.

        Raises:
            Whatever the coroutine raises.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the normal synchronous path.
            return asyncio.run(coro)
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _get_stock_adapter(self, adapter_service: AdapterService):
        """Return the active "stock" adapter, connecting "amazingdata" if needed.

        Args:
            adapter_service: Service used to look up and connect adapters.

        Returns:
            The active adapter. Returns ``None`` (after logging an error) when
            no adapter is active even after the connection attempt.

        Raises:
            Whatever ``_connect_adapter`` raises. Callers wrap this in their own
            ``try``.
        """
        adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            # NOTE(review): the connection is made on one event loop and later
            # adapter calls run on a different, fresh loop. If the adapter keeps
            # loop-bound resources (client sessions, etc.), confirm this works.
            self._run_coroutine(adapter_service._connect_adapter("amazingdata"))
            adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            error("No active adapter available")
        return adapter

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line data for a single symbol.

        The database is read first. If it holds nothing for the range, the
        data is fetched from the adapter and saved to the database.

        Args:
            req: Request carrying ``symbol``, ``freq``, ``adjust`` and
                ``start``/``end`` dates formatted as ``YYYYMMDD``. Both dates
                are inclusive.

        Returns:
            ``KLineData`` with the items and their count.

        Raises:
            ValueError: If ``start`` or ``end`` is not a ``YYYYMMDD`` date.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Move the end bound to 23:59:59 so the whole end date is included.
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(
            req.symbol,
            req.freq,
            start,
            end,
            req.adjust
        )

        # Nothing stored for this range: fall back to the adapter and cache
        # whatever it returns.
        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        # Price adjustment is currently a pass-through; see _apply_adjust.
        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-lines for *symbol* from the active stock adapter.

        Args:
            symbol: Symbol passed to ``adapter.fetch_klines``.
            start: Start date passed through unchanged (``YYYYMMDD`` from the
                request).
            end: End date passed through unchanged (``YYYYMMDD`` from the
                request).
            freq: Frequency, converted with ``_convert_freq_to_str``.

        Returns:
            The converted ``KLineItem`` list. Returns ``[]`` when no adapter is
            available or when fetching or conversion fails; failures are logged.
        """
        try:
            adapter = self._get_stock_adapter(AdapterService())
            if not adapter:
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_coroutine(
                adapter.fetch_klines(symbol, start, end, freq_str)
            )

            # One creation timestamp for the whole batch.
            created_at = datetime.now()
            items = [
                KLineItem(
                    symbol=symbol,
                    # NOTE(review): assumes k.time is a Unix timestamp in
                    # seconds; it is converted to naive local time.
                    time=datetime.fromtimestamp(k.time),
                    open=k.open,
                    high=k.high,
                    low=k.low,
                    close=k.close,
                    volume=k.volume,
                    amount=k.amount,
                    # Optional fields: not every adapter provides them.
                    trade_date=getattr(k, 'trade_date', None),
                    is_limit_up=getattr(k, 'is_limit_up', None),
                    is_limit_down=getattr(k, 'is_limit_down', None),
                    total_market_cap=getattr(k, 'total_market_cap', None),
                    float_market_cap=getattr(k, 'float_market_cap', None),
                    inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
                    trading_days=getattr(k, 'trading_days', None),
                    created_at=created_at,
                )
                for k in klines
            ]

            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            # Best effort: an adapter failure yields an empty result, not an error.
            error(f"Failed to fetch from adapter: {e}")
            return []

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a ``Frequency`` enum value to the adapter's frequency string.

        Unmapped values fall back to ``"1d"``.

        NOTE(review): ``query_klines`` saves fetched bars under ``req.freq``.
        An unmapped frequency would therefore store daily bars under the wrong
        frequency. Consider raising instead of falling back.
        """
        return self._FREQ_TO_STR.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist K-line items for *symbol* at *freq* (best effort).

        Sets ``item.symbol`` on every item before saving. Failures are logged
        and not raised.
        """
        try:
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")

    def _apply_adjust(
        self,
        symbol: str,
        items: List,
        adjust_type: AdjustType
    ) -> List:
        """Apply price adjustment (TODO: not implemented).

        A real implementation needs adjustment factors from the database. For
        now the items are returned unchanged.
        """
        return items

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """Return one page of symbols.

        Paging is normalised in place on *req*: ``page`` defaults to 1, and
        ``size`` defaults to 20 and is capped at 100.

        If the database has no matching symbols at all, the stock symbol list
        is fetched from the adapter, saved, and the query is repeated.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        # Only refetch when nothing matched at all. An empty page with
        # total > 0 just means the caller paged past the end, and must not
        # trigger a full adapter download and re-save.
        # NOTE(review): assumes `total` is the overall match count for the
        # query, independent of page/size.
        if not symbols and not total:
            info("No symbols in DB, fetching from adapter...")
            fetched = self._fetch_symbols_from_adapter()
            if fetched:
                self._save_symbols_to_db(fetched)
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock symbol list from the active adapter.

        Returns:
            ``Symbol`` models with ``symbol_type=SymbolType.STOCK``. Returns
            ``[]`` when no adapter is available or when anything fails;
            failures are logged.
        """
        try:
            adapter = self._get_stock_adapter(AdapterService())
            if not adapter:
                return []

            symbols_info = self._run_coroutine(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType
            symbols = [
                Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=s.exchange,
                    name=s.name,
                    underlying=s.underlying
                )
                for s in symbols_info
            ]

            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist symbols (best effort). Failures are logged and not raised."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")

    # ------------------------------------------------------------------
    # Batch / calendar
    # ------------------------------------------------------------------

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Query K-lines for several symbols, isolating per-symbol failures.

        Each symbol is queried with the shared ``start``/``end``/``freq``/
        ``adjust``. A failing symbol is logged and recorded with
        ``success=False`` and the error text, and does not abort the batch.
        """
        results = []
        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust
            )
            try:
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items)
                ))
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e)
                ))

        return BatchKLineData(results=results)

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return the trading calendar between ``req.start`` and ``req.end``.

        Delegates directly to the repository.
        """
        return self.repository.get_trading_dates(req.start, req.end)