You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
3.4 KiB
98 lines
3.4 KiB
import json
import logging
from typing import Dict, List, Optional

from app.database import SessionLocal
from app.models import DataSourceConfig
from app.services.datasource.akshare import AkshareSource
from app.services.datasource.base import DataSourceBase
from app.services.datasource.tushare import TushareSource
|
|
|
|
|
|
class DataSourceManager:
    """Data-source manager: registers, switches between and serves data sources.

    All state lives in two class-level dicts shared by every caller:

    * ``_source_map`` -- source name -> implementation class; extended via
      :meth:`register`.
    * ``_sources`` -- source name -> instantiated source; rebuilt by
      :meth:`load_enabled_sources` from the enabled ``DataSourceConfig`` rows.
    """

    _logger = logging.getLogger(__name__)

    # Source name -> live source instance; rebuilt by load_enabled_sources().
    _sources: Dict[str, DataSourceBase] = {}

    # Source name -> implementation class; extend at runtime via register().
    _source_map = {
        "tushare": TushareSource,
        "akshare": AkshareSource,
        # "ctp": CtpSource,  # planned future extension
    }

    @classmethod
    def register(cls, name: str, source_class):
        """Register (or replace) the implementation class for ``name``.

        Only the name -> class mapping is updated; an instance is created the
        next time :meth:`load_enabled_sources` runs with ``name`` enabled.
        """
        cls._source_map[name] = source_class

    @classmethod
    def get_source(cls, name: str) -> Optional[DataSourceBase]:
        """Return the already-loaded source instance ``name``, or ``None``.

        This does not trigger loading; see :meth:`load_enabled_sources`.
        """
        return cls._sources.get(name)

    @staticmethod
    def _enabled_configs(db) -> list:
        """Return enabled config rows ordered by ascending ``priority`` value."""
        return (
            db.query(DataSourceConfig)
            # `== True` builds a SQL comparison expression; `is True` would not.
            .filter(DataSourceConfig.is_enabled == True)  # noqa: E712
            .order_by(DataSourceConfig.priority)
            .all()
        )

    @classmethod
    def _parse_config(
        cls,
        raw: Optional[str],
        source_name: Optional[str] = None,
        *,
        warn: bool = False,
    ):
        """Decode a ``config_json`` column value.

        Empty/``None`` values and malformed JSON both yield ``{}``. With
        ``warn=True`` a malformed value is logged, so falling back to an empty
        config is visible instead of silent.
        """
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError as exc:
            if warn:
                cls._logger.warning(
                    "Invalid config_json for data source %r; using empty config: %s",
                    source_name,
                    exc,
                )
            return {}

    @classmethod
    def load_enabled_sources(cls):
        """Instantiate every enabled, registered data source from the database.

        Enabled rows are read in priority order. A row whose ``source_name``
        has no registered class is skipped with a warning. Each remaining
        source class is called with its decoded ``config_json``.

        ``_sources`` is replaced only after every instance has been built:
        sources disabled since the last load are dropped, and an exception
        raised by a source constructor propagates while leaving the previous
        registry intact.
        """
        loaded: Dict[str, DataSourceBase] = {}
        db = SessionLocal()
        try:
            for cfg in cls._enabled_configs(db):
                source_class = cls._source_map.get(cfg.source_name)
                if source_class is None:
                    cls._logger.warning(
                        "Data source %r is enabled but not registered; skipping",
                        cfg.source_name,
                    )
                    continue
                config = cls._parse_config(
                    cfg.config_json, cfg.source_name, warn=True
                )
                loaded[cfg.source_name] = source_class(config)
        finally:
            db.close()

        # Mutate in place (rather than rebinding) so any existing reference to
        # the shared dict keeps seeing the current set of sources.
        cls._sources.clear()
        cls._sources.update(loaded)

    @classmethod
    def get_primary_source(cls) -> Optional[DataSourceBase]:
        """Return the highest-priority enabled source that is actually loaded.

        Sources are loaded lazily if none are loaded yet. Priorities are re-read
        from the database on every call. Enabled rows are walked in priority
        order and the first one with a loaded instance wins, so an enabled row
        for an unregistered source no longer hides the sources behind it.
        Returns ``None`` when no enabled source is loaded.
        """
        if not cls._sources:
            cls.load_enabled_sources()

        db = SessionLocal()
        try:
            for cfg in cls._enabled_configs(db):
                source = cls._sources.get(cfg.source_name)
                if source is not None:
                    return source
            return None
        finally:
            db.close()

    @classmethod
    def get_all_sources_status(cls) -> List[dict]:
        """Return a status dict for every configured data source.

        Disabled rows are included. ``config_json`` is returned decoded, or as
        ``{}`` when empty or malformed.
        """
        db = SessionLocal()
        try:
            return [
                {
                    "source_name": cfg.source_name,
                    "display_name": cfg.display_name,
                    "is_enabled": cfg.is_enabled,
                    "priority": cfg.priority,
                    "status": cfg.status,
                    "error_msg": cfg.error_msg,
                    "last_sync_time": cfg.last_sync_time,
                    "config_json": cls._parse_config(cfg.config_json),
                }
                for cfg in db.query(DataSourceConfig).all()
            ]
        finally:
            db.close()
|