You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

363 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""股票业务服务 - 对应Go的internal/service/stock.go"""
import asyncio
from datetime import datetime, timedelta
from typing import List
from sqlalchemy.orm import Session
from app.models import (
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
TradingDatesRequest, TradingDatesData, AdjustType, Frequency, KLineItem
)
from app.repositories import StockRepository
from app.services.adapter_service import AdapterService
from app.core.logger import error, info
class StockService:
    """Stock business service (counterpart of Go's internal/service/stock.go).

    Every public query reads from the database first. When the database has
    no matching rows, the data is fetched from the active "stock" adapter,
    persisted on a best-effort basis, and then returned.
    """

    # Adapter that is connected on demand when no "stock" adapter is active.
    _DEFAULT_ADAPTER = "amazingdata"

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used by the repository and for direct calendar writes.
        """
        self.repository = StockRepository(db)
        self.db = db

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _run_async(awaitable):
        """Run *awaitable* to completion on a private event loop.

        The loop is always closed and unregistered, even when the awaitable
        raises, so no event loop (or closed-loop reference) is leaked into
        the calling thread.

        Returns:
            Whatever the awaitable resolves to.
        """
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(awaitable)
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def _get_stock_adapter(self):
        """Return the active "stock" adapter, connecting the default one if needed.

        Returns:
            The adapter object, or ``None`` (after logging an error) when no
            adapter is available even after a connection attempt.
        """
        adapter_service = AdapterService()
        adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            # No adapter yet: try to bring up the default one, then look again.
            self._run_async(adapter_service._connect_adapter(self._DEFAULT_ADAPTER))
            adapter = adapter_service.get_active_adapter("stock")
        if not adapter:
            error("No active adapter available")
            return None
        return adapter

    # ------------------------------------------------------------------
    # K-lines
    # ------------------------------------------------------------------
    def query_klines(self, req: KLineQueryRequest) -> KLineData:
        """Query K-line data for a single symbol.

        Args:
            req: Request carrying symbol, ``start``/``end`` as ``YYYYMMDD``
                strings, frequency and adjust type.

        Returns:
            KLineData with the matching items and their count.

        Raises:
            ValueError: If ``req.start`` or ``req.end`` is not ``YYYYMMDD``.
        """
        try:
            start = datetime.strptime(req.start, "%Y%m%d")
            end = datetime.strptime(req.end, "%Y%m%d")
        except ValueError as e:
            raise ValueError(f"Invalid date format: {e}") from e
        # Extend to 23:59:59 so the whole end day is included.
        end = end + timedelta(days=1) - timedelta(seconds=1)

        items = self.repository.get_klines(
            req.symbol,
            req.freq,
            start,
            end,
            req.adjust,
        )

        # Nothing stored yet: fall back to the adapter and cache the result.
        if not items:
            info(f"No data in DB for {req.symbol}, fetching from adapter...")
            items = self._fetch_from_adapter(req.symbol, req.start, req.end, req.freq)
            if items:
                self._save_klines_to_db(req.symbol, req.freq, items)

        # Price adjustment (currently a pass-through, see _apply_adjust).
        if req.adjust != AdjustType.NONE:
            items = self._apply_adjust(req.symbol, items, req.adjust)

        return KLineData(
            symbol=req.symbol,
            freq=req.freq,
            adjust=req.adjust,
            count=len(items),
            items=items,
        )

    def _fetch_from_adapter(self, symbol: str, start: str, end: str, freq: Frequency) -> List[KLineItem]:
        """Fetch K-lines for *symbol* from the adapter.

        Best-effort: any failure is logged and an empty list is returned.

        Args:
            symbol: Symbol identifier passed through to the adapter.
            start: Start date string as received in the request.
            end: End date string as received in the request.
            freq: Requested frequency; converted to the adapter's string form.

        Returns:
            Converted KLineItem list, or ``[]`` on failure / no adapter.
        """
        try:
            adapter = self._get_stock_adapter()
            if adapter is None:
                return []

            freq_str = self._convert_freq_to_str(freq)
            klines = self._run_async(
                adapter.fetch_klines(symbol, start, end, freq_str)
            )

            # Optional fields may be missing on adapter records, hence getattr.
            # NOTE(review): fromtimestamp() yields naive local time; confirm
            # this matches how stored K-line times are interpreted.
            items = [
                KLineItem(
                    symbol=symbol,
                    time=datetime.fromtimestamp(k.time),
                    open=k.open,
                    high=k.high,
                    low=k.low,
                    close=k.close,
                    volume=k.volume,
                    amount=k.amount,
                    trade_date=getattr(k, 'trade_date', None),
                    is_limit_up=getattr(k, 'is_limit_up', None),
                    is_limit_down=getattr(k, 'is_limit_down', None),
                    total_market_cap=getattr(k, 'total_market_cap', None),
                    float_market_cap=getattr(k, 'float_market_cap', None),
                    inst_holding_ratio=getattr(k, 'inst_holding_ratio', None),
                    trading_days=getattr(k, 'trading_days', None),
                    created_at=datetime.now(),
                )
                for k in klines
            ]
            info(f"Fetched {len(items)} klines from adapter for {symbol}")
            return items
        except Exception as e:
            error(f"Failed to fetch from adapter: {e}")
            return []

    def _convert_freq_to_str(self, freq: Frequency) -> str:
        """Map a Frequency enum member to the adapter's string code.

        Unknown values fall back to ``"1d"``.
        """
        mapping = {
            Frequency.FREQ_1M: "1m",
            Frequency.FREQ_5M: "5m",
            Frequency.FREQ_15M: "15m",
            Frequency.FREQ_30M: "30m",
            Frequency.FREQ_60M: "60m",
            Frequency.FREQ_1D: "1d",
            Frequency.FREQ_1W: "1w",
            Frequency.FREQ_1MONTH: "1month",
        }
        return mapping.get(freq, "1d")

    def _save_klines_to_db(self, symbol: str, freq: Frequency, items: List[KLineItem]) -> None:
        """Persist K-lines; failures are logged and rolled back, never raised.

        Args:
            symbol: Symbol stamped onto every item before saving.
            freq: Frequency forwarded to the repository.
            items: Items to store (their ``symbol`` attribute is overwritten).
        """
        try:
            for item in items:
                item.symbol = symbol
            self.repository.save_klines(freq, items)
            info(f"Saved {len(items)} klines to DB for {symbol}")
        except Exception as e:
            error(f"Failed to save klines to DB: {e}")
            # Keep the shared session usable for subsequent queries.
            self.db.rollback()

    def _apply_adjust(
        self,
        symbol: str,
        items: List,
        adjust_type: AdjustType
    ) -> List:
        """Apply price adjustment (TODO: implement the adjustment logic).

        A real implementation needs adjustment factors from the database;
        for now the items are returned unchanged.
        """
        return items

    def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
        """Query K-lines for several symbols, isolating per-symbol failures.

        Each symbol is queried via ``query_klines``; an exception for one
        symbol is logged and reported in that symbol's result entry without
        aborting the remaining symbols.
        """
        results = []
        for symbol in req.symbols:
            single_req = KLineQueryRequest(
                symbol=symbol,
                start=req.start,
                end=req.end,
                freq=req.freq,
                adjust=req.adjust,
            )
            try:
                data = self.query_klines(single_req)
            except Exception as e:
                error(f"Batch query failed for {symbol}: {e}")
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=False,
                    error=str(e),
                ))
            else:
                results.append(BatchKLineResult(
                    symbol=symbol,
                    success=True,
                    data=KLineSubData(count=data.count, items=data.items),
                ))
        return BatchKLineData(results=results)

    # ------------------------------------------------------------------
    # Symbols
    # ------------------------------------------------------------------
    def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
        """List symbols with paging.

        ``req.page`` and ``req.size`` are normalised in place: page defaults
        to 1, size defaults to 20 and is capped at 100.
        """
        if req.page <= 0:
            req.page = 1
        if req.size <= 0:
            req.size = 20
        if req.size > 100:
            req.size = 100

        symbols, total = self.repository.list_symbols(req)

        # Empty result: load the symbol universe from the adapter, store it,
        # and query again so filtering and paging are applied by the repository.
        if not symbols:
            info("No symbols in DB, fetching from adapter...")
            symbols = self._fetch_symbols_from_adapter()
            if symbols:
                self._save_symbols_to_db(symbols)
                symbols, total = self.repository.list_symbols(req)

        return SymbolListData(
            total=total,
            page=req.page,
            size=req.size,
            items=symbols,
        )

    def _fetch_symbols_from_adapter(self) -> List:
        """Fetch the stock symbol list from the adapter.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            adapter = self._get_stock_adapter()
            if adapter is None:
                return []

            symbols_info = self._run_async(adapter.fetch_symbols("stock"))

            from app.models import Symbol, SymbolType, Exchange

            symbols = []
            for s in symbols_info:
                # Adapter delivers the exchange as a raw value; unknown
                # values default to Shanghai.
                try:
                    exchange_enum = Exchange(s.exchange)
                except ValueError:
                    exchange_enum = Exchange.SH
                symbols.append(Symbol(
                    symbol_id=s.symbol_id,
                    symbol_type=SymbolType.STOCK,
                    exchange=exchange_enum,
                    name=s.name,
                    underlying=s.underlying,
                    status="active",  # required field on Symbol
                ))
            info(f"Fetched {len(symbols)} symbols from adapter")
            return symbols
        except Exception as e:
            error(f"Failed to fetch symbols from adapter: {e}")
            return []

    def _save_symbols_to_db(self, symbols: List) -> None:
        """Persist symbols; failures are logged and rolled back, never raised."""
        try:
            self.repository.save_symbols(symbols)
            info(f"Saved {len(symbols)} symbols to DB")
        except Exception as e:
            error(f"Failed to save symbols to DB: {e}")
            # Keep the shared session usable for the re-query that follows.
            self.db.rollback()

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------
    def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
        """Return the trading calendar between ``req.start`` and ``req.end``.

        Falls back to the adapter (and stores the result) when the database
        has no trading dates for the range.
        """
        data = self.repository.get_trading_dates(req.start, req.end)

        if not data.trading_dates:
            info(f"No trading dates in DB for {req.start}~{req.end}, fetching from adapter...")
            adapter_dates = self._fetch_trading_dates_from_adapter(req.start, req.end)
            if adapter_dates:
                self._save_trading_dates_to_db(adapter_dates)
                # Re-read so the response has the repository's own shape.
                data = self.repository.get_trading_dates(req.start, req.end)

        return data

    def _fetch_trading_dates_from_adapter(self, start: str, end: str) -> List[str]:
        """Fetch trading days from the adapter's "SH" calendar.

        Best-effort: any failure is logged and an empty list is returned.

        Returns:
            Trading days formatted as ``YYYYMMDD`` strings.
        """
        try:
            adapter = self._get_stock_adapter()
            if adapter is None:
                return []

            calendar_data = self._run_async(
                adapter.fetch_trading_calendar("SH", start, end)
            )

            # Keep trading days only, formatted as YYYYMMDD.
            dates = [
                cal.date.strftime("%Y%m%d")
                for cal in calendar_data
                if cal.is_trading_day
            ]
            info(f"Fetched {len(dates)} trading dates from adapter")
            return dates
        except Exception as e:
            error(f"Failed to fetch trading dates from adapter: {e}")
            return []

    def _save_trading_dates_to_db(self, dates: List[str]) -> None:
        """Insert trading dates that are not stored yet.

        Dates already present are skipped. The whole batch is committed at
        once; on any failure the error is logged and the session rolled back.

        Args:
            dates: Trading days as ``YYYYMMDD`` strings.
        """
        try:
            from app.repositories.models import StockTradingCalendar

            inserted = 0
            for date_str in dates:
                existing = self.db.query(StockTradingCalendar).filter(
                    StockTradingCalendar.trade_date == date_str
                ).first()
                if existing:
                    continue
                # weekday() is 0-based from Monday; stored as 1=Mon .. 7=Sun.
                week_day = datetime.strptime(date_str, "%Y%m%d").weekday() + 1
                self.db.add(StockTradingCalendar(
                    trade_date=date_str,
                    is_trading_day=True,
                    week_day=week_day,
                ))
                inserted += 1
            self.db.commit()
            # Report rows actually written, not the size of the input list.
            info(f"Saved {inserted} trading dates to DB")
        except Exception as e:
            error(f"Failed to save trading dates to DB: {e}")
            self.db.rollback()