You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

266 lines
9.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""股票业务服务 - 对应Go的internal/service/stock.go"""
import asyncio
from datetime import datetime, timedelta
from typing import List
from sqlalchemy.orm import Session
from app.models import (
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info
class StockService:
    """Stock business service.

    Serves K-line (candlestick) and symbol queries from the database through
    ``StockRepository``. When the database has nothing to return, the data is
    fetched from the active "stock" adapter and written back to the database
    on a best-effort basis.
    """

    def __init__(self, db: Session):
        """Create the service on top of an open SQLAlchemy session.

        Args:
            db: Session handed to ``StockRepository`` and kept on ``self.db``.
        """
        self.repository = StockRepository(db)
        self.db = db

    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line data for a single symbol.

        Args:
            req: Query request; ``req.start`` and ``req.end`` must be
                ``YYYYMMDD`` strings.

        Returns:
            A ``KLineData`` carrying the symbol, frequency, adjust type, item
            count and the items themselves.

        Raises:
            ValueError: If ``req.start`` or ``req.end`` is not a valid
                ``YYYYMMDD`` date. The parsing error is attached as the cause.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Extend the end bound to 23:59:59 so the whole end date is included.
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(
            req.symbol,
            req.freq,
            start,
            end,
            req.adjust,
        )

        # Nothing stored for this range: fall back to the adapter and persist
        # whatever it returns so the next query can be served from the DB.
        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        # Price adjustment is currently a pass-through (see _apply_adjust).
        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items,
        )

    def _get_stock_adapter(self):
        """Return the active "stock" adapter, connecting "akshare" if needed.

        Returns:
            The adapter object, or ``None`` (after logging an error) when no
            adapter is active even after the connection attempt.
        """
        adapter_service = AdapterService()
        adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            # asyncio.run creates a fresh loop and always closes it, even when
            # the connection attempt raises.
            asyncio.run(adapter_service._connect_adapter("akshare"))
            adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            error("No active adapter available")
            return None
        return adapter

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-line data from the adapter.

        Args:
            symbol: Symbol to fetch.
            start: Start date string, passed to the adapter unchanged.
            end: End date string, passed to the adapter unchanged.
            freq: Frequency enum; converted to the adapter's string form.

        Returns:
            The fetched items, or an empty list when no adapter is available
            or any step fails (the failure is logged, never raised).
        """
        try:
            adapter = self._get_stock_adapter()
            if adapter is None:
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = asyncio.run(
                adapter.fetch_klines(symbol, start, end, freq_str)
            )

            # NOTE(review): k.time is passed to datetime.fromtimestamp, so it
            # is presumably an epoch timestamp; the result is a naive
            # local-time datetime -- confirm against the adapter contract.
            items = [
                KLineItem(
                    time=datetime.fromtimestamp(k.time),
                    open=k.open,
                    high=k.high,
                    low=k.low,
                    close=k.close,
                    volume=k.volume,
                    amount=k.amount,
                )
                for k in klines
            ]
            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            # Best-effort fallback: callers treat an empty list as "no data".
            error(f"Failed to fetch from adapter: {e}")
            return []

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Convert a ``Frequency`` enum member to its string form.

        Unknown values fall back to ``"1d"``.
        """
        mapping = {
            Frequency.FREQ_1M: "1m",
            Frequency.FREQ_5M: "5m",
            Frequency.FREQ_15M: "15m",
            Frequency.FREQ_30M: "30m",
            Frequency.FREQ_60M: "60m",
            Frequency.FREQ_1D: "1d",
            Frequency.FREQ_1W: "1w",
            Frequency.FREQ_1MONTH: "1month",
        }
        return mapping.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist K-line items to the database (best effort).

        Each item gets a ``symbol`` attribute before being handed to the
        repository. Any failure is logged and swallowed so that a failed
        write never breaks the query that triggered it.
        """
        try:
            # NOTE(review): assumes KLineItem permits assigning ``symbol``;
            # if it does not, the error is caught and logged below.
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")

    def _apply_adjust(
        self,
        symbol: str,
        items: List,
        adjust_type: AdjustType
    ) -> List:
        """Apply price adjustment to the items.

        TODO: implement the adjustment logic. A real implementation needs
        adjustment factors from the database; for now the input list is
        returned unchanged.
        """
        return items

    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """Query a page of the symbol list.

        Paging values on ``req`` are normalised in place: ``page`` defaults
        to 1 and ``size`` defaults to 20 and is capped at 100.

        Returns:
            A ``SymbolListData`` with the total, the effective page and size,
            and the symbols on that page.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        # Fall back to the adapter only when the query matched nothing at
        # all. An empty page with a non-zero total (presumably the count of
        # matching rows) is just an out-of-range page, and refetching the
        # whole symbol list would not change it.
        if not symbols and not total:
            info("No symbols in DB, fetching from adapter...")
            symbols = self._fetch_symbols_from_adapter()
            if symbols:
                self._save_symbols_to_db(symbols)
                # Re-run the query so filtering and paging are applied.
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols,
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock symbol list from the adapter.

        Returns:
            A list of ``Symbol`` models, or an empty list when no adapter is
            available or any step fails (the failure is logged, never raised).
        """
        try:
            adapter = self._get_stock_adapter()
            if adapter is None:
                return []

            symbols_info = asyncio.run(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType
            symbols = [
                Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=s.exchange,
                    name=s.name,
                    underlying=s.underlying,
                )
                for s in symbols_info
            ]
            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist the symbol list to the database (best effort).

        Any failure is logged and swallowed.
        """
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Query K-lines for several symbols in one call.

        Each symbol is queried independently through ``query_klines``. A
        failure for one symbol is logged and recorded in its result entry
        (``success=False`` plus the error text) instead of aborting the batch.
        """
        results = []
        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust,
            )
            try:
                data = self.query_klines(single_req)
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items),
                ))
            except Exception as e:
                # Deliberately broad: one bad symbol must not fail the batch.
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e),
                ))
        return BatchKLineData(results=results)

    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return the trading calendar between ``req.start`` and ``req.end``.

        The lookup is delegated to the repository unchanged.
        """
        return self.repository.get_trading_dates(req.start, req.end)