You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

98 lines
3.4 KiB

import json
import logging
from typing import Dict, Optional, List

from app.services.datasource.base import DataSourceBase
from app.services.datasource.tushare import TushareSource
from app.services.datasource.akshare import AkshareSource
from app.database import SessionLocal
from app.models import DataSourceConfig


class DataSourceManager:
    """Data source manager: registers, switches between and dispatches to data sources.

    All state lives on the class (every method is a classmethod/staticmethod),
    so every caller shares one type registry and one set of live instances.
    """

    # Live data-source instances keyed by ``source_name``; populated by
    # load_enabled_sources().
    _sources: Dict[str, DataSourceBase] = {}

    # ``source_name`` -> class used to build an instance from a
    # DataSourceConfig row. Extend at runtime with register().
    _source_map: Dict[str, type] = {
        "tushare": TushareSource,
        "akshare": AkshareSource,
        # TODO: add "ctp": CtpSource once it is implemented.
    }

    @staticmethod
    def _parse_config_json(raw: Optional[str], source_name: str) -> dict:
        """Decode a ``config_json`` column value into a dict.

        Empty/NULL values yield ``{}``. Malformed JSON, or JSON that is not an
        object, also yields ``{}`` but is logged so a broken config row does not
        go unnoticed.
        """
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logging.getLogger(__name__).warning(
                "Invalid config_json for data source %r; using empty config",
                source_name,
            )
            return {}
        if not isinstance(parsed, dict):
            logging.getLogger(__name__).warning(
                "config_json for data source %r is not a JSON object; "
                "using empty config",
                source_name,
            )
            return {}
        return parsed

    @classmethod
    def register(cls, name: str, source_class) -> None:
        """Register (or replace) the class used for data source type ``name``.

        Only the type registry is updated; already-loaded instances are not
        touched, so the new class takes effect on the next
        load_enabled_sources() call.
        """
        cls._source_map[name] = source_class

    @classmethod
    def get_source(cls, name: str) -> Optional[DataSourceBase]:
        """Return the loaded instance for ``name``, or None if it is not loaded."""
        return cls._sources.get(name)

    @classmethod
    def load_enabled_sources(cls) -> None:
        """(Re)load instances for every enabled data source in the database.

        Enabled ``DataSourceConfig`` rows are read in ascending ``priority``
        order. Each row whose ``source_name`` has a registered class is
        instantiated with its parsed ``config_json``. Rows with an unregistered
        name are skipped.

        The live registry is replaced only after every instance has been built.
        As a result, sources that were disabled since the last load are
        dropped. If a source constructor raises, the previous registry is left
        unchanged and the exception propagates.
        """
        db = SessionLocal()
        try:
            configs = (
                db.query(DataSourceConfig)
                .filter(DataSourceConfig.is_enabled == True)  # noqa: E712 (SQLAlchemy expression)
                .order_by(DataSourceConfig.priority)
                .all()
            )
            loaded: Dict[str, DataSourceBase] = {}
            for cfg in configs:
                source_class = cls._source_map.get(cfg.source_name)
                if source_class is None:
                    continue  # no implementation registered for this type
                config = cls._parse_config_json(cfg.config_json, cfg.source_name)
                loaded[cfg.source_name] = source_class(config)
        finally:
            db.close()
        # Mutate the shared dict in place rather than rebinding cls._sources,
        # so a call through a subclass still updates the one shared registry.
        cls._sources.clear()
        cls._sources.update(loaded)

    @classmethod
    def get_primary_source(cls) -> Optional[DataSourceBase]:
        """Return the loaded instance of the highest-priority enabled data source.

        Sources are loaded from the database first if none are loaded yet.
        Enabled rows are checked in ascending ``priority`` order, and the first
        one that has a loaded instance wins. An enabled row with no registered
        class therefore does not hide lower-priority sources.

        Returns:
            The primary data source instance, or None if no enabled source is
            loaded.
        """
        if not cls._sources:
            cls.load_enabled_sources()
        db = SessionLocal()
        try:
            enabled_rows = (
                db.query(DataSourceConfig.source_name)
                .filter(DataSourceConfig.is_enabled == True)  # noqa: E712 (SQLAlchemy expression)
                .order_by(DataSourceConfig.priority)
                .all()
            )
        finally:
            db.close()
        for (name,) in enabled_rows:
            source = cls._sources.get(name)
            if source is not None:
                return source
        return None

    @classmethod
    def get_all_sources_status(cls) -> List[dict]:
        """Return a status summary for every configured data source, enabled or not.

        Rows are returned in ascending ``priority`` order. ``config_json`` is
        decoded into a dict (``{}`` when empty or invalid).
        """
        db = SessionLocal()
        try:
            configs = (
                db.query(DataSourceConfig)
                .order_by(DataSourceConfig.priority)
                .all()
            )
            # NOTE(review): config_json may hold credentials (e.g. an API
            # token); confirm every caller of this status view may see it.
            return [
                {
                    "source_name": cfg.source_name,
                    "display_name": cfg.display_name,
                    "is_enabled": cfg.is_enabled,
                    "priority": cfg.priority,
                    "status": cfg.status,
                    "error_msg": cfg.error_msg,
                    "last_sync_time": cfg.last_sync_time,
                    "config_json": cls._parse_config_json(
                        cfg.config_json, cfg.source_name
                    ),
                }
                for cfg in configs
            ]
        finally:
            db.close()