You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

121 lines
3.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
基础数据服务
"""
from typing import List, Optional, Dict
from datetime import date
from sqlalchemy.orm import Session
import pandas as pd
from app.services.sdk_manager import sdk_manager
from app.models.config import SDKConfig
class BaseDataService:
    """Basic market-data service.

    The SDK-backed query methods obtain the default SDK connection via
    ``_get_adapter()``, raise ``RuntimeError`` when no connection is available,
    and otherwise delegate to the adapter method of the same name.
    ``get_security_type`` is a pure helper that classifies a code locally.
    """

    # First three digits of a security code, grouped by category, used by
    # get_security_type(). "000" is intentionally absent from the index set:
    # its meaning depends on the exchange suffix (handled in that method).
    _STOCK_PREFIXES = frozenset({
        "000", "001", "002", "003", "300", "301",
        "600", "601", "603", "605", "688",
    })
    _ETF_PREFIXES = frozenset({
        "159",  # Shenzhen-listed ETFs (e.g. 159919.SZ)
        "510", "511", "512", "513", "515", "516", "518", "560", "563", "588",
    })
    _INDEX_PREFIXES = frozenset({
        "399",  # Shenzhen indices (e.g. 399001.SZ)
        "880",
    })

    def __init__(self, db: Session):
        # Database session; stored on the instance but not read by any
        # method of this class.
        self.db = db

    def _get_adapter(self):
        """Return the default SDK connection from the connection manager.

        The result may be falsy when no connection is available; callers
        should go through ``_require_adapter`` to get a checked adapter.
        """
        return sdk_manager.get_default_connection()

    def _require_adapter(self):
        """Return the SDK adapter, raising if no usable connection exists.

        Raises:
            RuntimeError: if ``_get_adapter()`` returns a falsy value.
        """
        adapter = self._get_adapter()
        if not adapter:
            raise RuntimeError("SDK连接失败请先测试连接")
        return adapter

    def get_code_list(self, security_type: str) -> List[str]:
        """Return the list of codes for a security type.

        Args:
            security_type: security type, one of
                - EXTRA_STOCK_A: Shanghai/Shenzhen A-shares
                - EXTRA_FUTURE: futures
                - EXTRA_ETF: ETFs
                - EXTRA_INDEX_A: indices

        Returns:
            The code list returned by the adapter.

        Raises:
            RuntimeError: if no SDK connection is available.
        """
        return self._require_adapter().get_code_list(security_type)

    def get_code_info(self, security_type: str) -> pd.DataFrame:
        """Return security information for ``security_type`` from the adapter.

        Raises:
            RuntimeError: if no SDK connection is available.
        """
        return self._require_adapter().get_code_info(security_type)

    def get_trading_calendar(
        self,
        market: str,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> List[date]:
        """Return the trading days of ``market``, optionally clipped to a range.

        Args:
            market: market code (SH, SZ, CFE).
            start_date: if given, drop days earlier than this (inclusive bound).
            end_date: if given, drop days later than this (inclusive bound).

        Returns:
            Trading days as ``date`` objects, in the order the adapter
            returned them.

        Raises:
            RuntimeError: if no SDK connection is available.
        """
        adapter = self._require_adapter()
        # The adapter returns integer-encoded dates (presumably YYYYMMDD —
        # TODO confirm against int_to_date).
        calendar_ints = adapter.get_trading_calendar(market)
        # Function-scoped import kept as in the original (possibly to avoid
        # an import cycle — verify before hoisting).
        from app.utils.date_utils import int_to_date
        return [
            d
            for d in map(int_to_date, calendar_ints)
            if (start_date is None or d >= start_date)
            and (end_date is None or d <= end_date)
        ]

    def get_adj_factor(self, codes: List[str]) -> pd.DataFrame:
        """Return price-adjustment factors for ``codes`` from the adapter.

        Raises:
            RuntimeError: if no SDK connection is available.
        """
        return self._require_adapter().get_adj_factor(codes)

    def get_backward_factor(self, codes: List[str]) -> pd.DataFrame:
        """Return backward-adjustment factors for ``codes`` from the adapter.

        Raises:
            RuntimeError: if no SDK connection is available.
        """
        return self._require_adapter().get_backward_factor(codes)

    def get_security_type(self, code: str) -> str:
        """Infer the security category from an exchange-suffixed code.

        Args:
            code: security code such as ``600000.SH``, ``000001.SZ`` or
                ``IF2401.CFE``.

        Returns:
            One of ``"stock"``, ``"future"``, ``"index"``, ``"etf"`` or
            ``"unknown"``.
        """
        if code.endswith(".CFE"):
            return "future"
        if not code.endswith((".SH", ".SZ", ".BJ")):
            return "unknown"
        prefix = code[:3]
        # 000xxx is exchange-dependent: on Shanghai it is the index range
        # (000001.SH SSE Composite, 000300.SH CSI 300), on Shenzhen it is
        # A-share stocks (000001.SZ). Resolve it before the generic tables,
        # otherwise the stock table would swallow Shanghai indices.
        if prefix == "000" and code.endswith(".SH"):
            return "index"
        if prefix in self._STOCK_PREFIXES:
            return "stock"
        if prefix in self._ETF_PREFIXES:
            return "etf"
        if prefix in self._INDEX_PREFIXES:
            return "index"
        return "unknown"