You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

243 lines
9.4 KiB

"""适配器管理服务 - 对应Go的internal/service/adapter.go"""
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Callable
from threading import RLock
from app.models import (
AdapterListData, AdapterInfo, AdapterStatus,
AdapterToggleRequest, AdapterConfigUpdateRequest
)
from app.adapters import DataSourceAdapter, AmazingDataAdapter
from app.core.logger import info, error
class AdapterService:
    """Adapter management service.

    Tracks registered adapter factories, per-adapter settings and metadata,
    and the adapter instances that are currently connected. Connecting and
    closing adapters is asynchronous: the synchronous entry points
    (``toggle_adapter``, ``update_adapter_config``, ``get_active_adapter``)
    schedule that work as background tasks on the event loop running in the
    calling thread (``asyncio.create_task`` semantics, so they raise
    ``RuntimeError`` when no loop is running there).

    ``self.lock`` is a ``threading.RLock`` guarding the shared dictionaries.
    It is never held across an ``await``: a thread lock held over a
    suspension point would block every other thread needing it for the whole
    duration of the awaited I/O.
    """

    # Config keys (case-insensitive substring match) whose values are masked
    # before a config is written to the logs.
    _SENSITIVE_KEYS = ("password", "secret", "token")

    def __init__(self):
        self.lock = RLock()
        # Registered adapter factories, keyed by adapter name.
        self.factories: Dict[str, Callable[[], DataSourceAdapter]] = {}
        # Per-adapter settings: {"enabled": bool, "config": dict}.
        self.configs: Dict[str, dict] = {}
        # Adapter instances that connected successfully, keyed by name.
        self.active_adapters: Dict[str, DataSourceAdapter] = {}
        # Display metadata: name / type / version / description / updated_at.
        self.metadata: Dict[str, dict] = {}
        # Strong references to background tasks. The event loop keeps only
        # weak references, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()
        # Most recent connect task per adapter name; used so repeated
        # get_active_adapter() calls do not pile up duplicate connections.
        self._connect_tasks: Dict[str, "asyncio.Task"] = {}
        # Register the built-in adapters.
        self._register_builtin_adapters()

    def _register_builtin_adapters(self):
        """Register the built-in adapters with their metadata and default settings."""
        # Galaxy Securities "AmazingData" (星耀数智) quant-platform adapter.
        self.register_adapter("amazingdata", AmazingDataAdapter)
        self.metadata["amazingdata"] = {
            "name": "amazingdata",
            "type": "sdk",
            "version": "1.0.30",
            "description": "银河证券星耀数智量化平台需要SDK和账号",
            "updated_at": datetime.now()
        }
        # Disabled by default: it needs an account (username/password) first.
        self.configs["amazingdata"] = {
            "enabled": False,
            "config": {
                "username": "",
                "password": "",
                "host": "",
                "port": 8080,
                "local_path": "./amazing_data_cache/",
                "use_local_cache": True
            }
        }

    def get_adapter_list(self) -> AdapterListData:
        """Return every adapter that has metadata, with its status and config.

        Status is DISABLED when the adapter is not enabled, ACTIVE when it has
        a connected instance, and STANDBY otherwise (enabled but not connected).
        """
        with self.lock:
            adapters = []
            for name, meta in self.metadata.items():
                cfg = self.configs.get(name, {"enabled": False, "config": {}})
                if not cfg.get("enabled", False):
                    status = AdapterStatus.DISABLED
                elif name in self.active_adapters:
                    status = AdapterStatus.ACTIVE
                else:
                    status = AdapterStatus.STANDBY
                adapters.append(AdapterInfo(
                    name=meta["name"],
                    type=meta["type"],
                    version=meta["version"],
                    description=meta["description"],
                    status=status,
                    config=cfg.get("config", {}),
                    updated_at=meta.get("updated_at", datetime.now())
                ))
            return AdapterListData(adapters=adapters)

    def toggle_adapter(self, req: AdapterToggleRequest) -> None:
        """Enable or disable an adapter.

        Disabling detaches the connected instance and closes it in the
        background. Enabling a previously disabled, unconnected adapter starts
        a background connection attempt.

        Raises:
            ValueError: if no adapter named ``req.name`` is configured.
            RuntimeError: if background work is needed but no asyncio event
                loop is running in the calling thread.
        """
        with self.lock:
            if req.name not in self.configs:
                raise ValueError(f"Adapter not found: {req.name}")
            old_enabled = self.configs[req.name]["enabled"]
            self.configs[req.name]["enabled"] = req.enable
            # Disabling: drop the live connection.
            if not req.enable and req.name in self.active_adapters:
                adapter = self.active_adapters.pop(req.name)
                self._spawn(self._close_adapter(req.name, adapter))
            # Enabling from disabled: connect.
            if req.enable and not old_enabled and req.name not in self.active_adapters:
                self._schedule_connect(req.name)
            if req.name in self.metadata:
                self.metadata[req.name]["updated_at"] = datetime.now()

    def update_adapter_config(self, req: AdapterConfigUpdateRequest) -> None:
        """Merge ``req.config`` into an adapter's settings.

        If the adapter is currently connected, it is detached. When the adapter
        is enabled, the old instance is closed and a new connection is then
        opened, in that order, in a single background task. Otherwise the old
        instance is just closed.

        Raises:
            ValueError: if no adapter named ``req.name`` is configured.
            RuntimeError: if background work is needed but no asyncio event
                loop is running in the calling thread.
        """
        with self.lock:
            if req.name not in self.configs:
                raise ValueError(f"Adapter not found: {req.name}")
            self.configs[req.name]["config"].update(req.config)
            old_adapter = self.active_adapters.pop(req.name, None)
            if old_adapter is not None:
                if self.configs[req.name]["enabled"]:
                    # Close-then-connect in one task so the old connection is
                    # gone before the new one is opened (no close/connect race).
                    self._schedule_connect(req.name, old_adapter=old_adapter)
                else:
                    self._spawn(self._close_adapter(req.name, old_adapter))
            if req.name in self.metadata:
                self.metadata[req.name]["updated_at"] = datetime.now()

    def get_active_adapter(self, asset_class: str) -> Optional[DataSourceAdapter]:
        """Return the connected adapter configured for *asset_class*, if any.

        The adapter name is read from the file configuration
        (``sources.stock.active`` / ``sources.futures.active``). Any other
        asset class maps to ``"custom"``. If that adapter is enabled but not
        connected, a background connection attempt is started (unless one is
        already in flight), and ``None`` is returned for this call.

        Args:
            asset_class: asset class, e.g. ``"stock"`` or ``"futures"``.

        Raises:
            RuntimeError: if a connection must be started but no asyncio event
                loop is running in the calling thread.
        """
        from app.core.config import get_config

        # Local variable only: the loaded config (which may hold credentials)
        # is neither cached on the service nor printed.
        config = get_config()
        if asset_class == "stock":
            active_name = config.sources.stock.active
        elif asset_class == "futures":
            active_name = config.sources.futures.active
        else:
            active_name = "custom"  # default

        with self.lock:
            adapter = self.active_adapters.get(active_name)
            if adapter is not None:
                return adapter
            cfg = self.configs.get(active_name)
            if cfg is not None and cfg.get("enabled") and not self._connect_pending(active_name):
                self._schedule_connect(active_name)
            return None

    def get_available_adapters(self) -> List[str]:
        """Return ``"name|description"`` for every adapter with metadata and a factory."""
        with self.lock:
            return [
                f"{name}|{meta['description']}"
                for name, meta in self.metadata.items()
                if name in self.factories
            ]

    def register_adapter(self, name: str, factory: Callable[[], DataSourceAdapter]):
        """Register (or replace) the zero-argument factory used to build adapter *name*."""
        with self.lock:
            self.factories[name] = factory

    async def _connect_adapter(self, name: str):
        """Create a fresh adapter instance for *name* and connect it.

        Any adapter currently active under *name* is detached and closed first.
        For ``"amazingdata"`` the settings come from the file configuration
        (kept in sync with config.json). Other adapters use
        ``self.configs[name]["config"]``. On success the instance is stored in
        ``active_adapters``. On failure, including a malformed config, the
        error is logged and the adapter's ``enabled`` flag is reset to False.

        Raises:
            ValueError: if *name* has no registered factory or no config entry.
        """
        from app.core.config import get_config

        with self.lock:
            if name not in self.factories:
                raise ValueError(f"Adapter factory not found: {name}")
            if name not in self.configs:
                raise ValueError(f"Adapter config not found: {name}")
            factory = self.factories[name]
            old_adapter = self.active_adapters.pop(name, None)

        # Close outside the lock: never hold a threading lock across an await.
        if old_adapter is not None:
            try:
                await old_adapter.close()
            except Exception as e:
                error(f"Error closing old adapter {name}: {e}")

        try:
            if name == "amazingdata":
                adapter_config = self._load_amazingdata_config(get_config())
            else:
                with self.lock:
                    adapter_config = self.configs[name].get("config", {})
            with self.lock:
                enabled = self.configs[name].get("enabled", False)
            cfg = {"enabled": enabled, "config": adapter_config}
            # Credentials are masked before logging.
            info(f"Connecting to adapter: {name}, config: {self._redact(cfg)}")
            adapter = factory()
            await adapter.connect(cfg["config"])
        except Exception as e:
            error(f"Failed to connect adapter {name}: {e}")
            # Reset the enabled flag so the UI shows the adapter as disabled.
            with self.lock:
                if name in self.configs:
                    self.configs[name]["enabled"] = False
            return

        with self.lock:
            # A concurrent connect may have stored an instance meanwhile; keep
            # the newest one and close the replaced one instead of leaking it.
            replaced = self.active_adapters.get(name)
            self.active_adapters[name] = adapter
        if replaced is not None and replaced is not adapter:
            await self._close_adapter(name, replaced)
        info(f"Adapter {name} connected successfully")

    async def health_check(self, name: str) -> bool:
        """Return the connected adapter's health-check result, or False if not connected."""
        with self.lock:
            adapter = self.active_adapters.get(name)
        if adapter is None:
            return False
        # Awaited outside the lock (see class docstring).
        return await adapter.health_check()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _load_amazingdata_config(file_config) -> dict:
        """Build AmazingData connection settings from the file configuration.

        Reads ``file_config.sources.stock.list["amazingdata"].config``. A
        missing entry or missing config yields ``{}``. A string ``port`` is
        converted to ``int``, and a blank port string falls back to 8600. A
        non-numeric port raises ``ValueError``.
        """
        sources = file_config.sources.stock.list
        try:
            source_info = sources["amazingdata"] if sources is not None else None
        except KeyError:
            source_info = None
        adapter_config = dict(source_info.config or {}) if source_info else {}
        port = adapter_config.get("port")
        if isinstance(port, str):
            adapter_config["port"] = int(port) if port.strip() else 8600
        return adapter_config

    @classmethod
    def _redact(cls, value):
        """Return a copy of *value* with non-empty sensitive dict values masked (recursively)."""
        if not isinstance(value, dict):
            return value
        masked = {}
        for key, item in value.items():
            lowered = str(key).lower()
            if item and any(marker in lowered for marker in cls._SENSITIVE_KEYS):
                masked[key] = "***"
            else:
                masked[key] = cls._redact(item)
        return masked

    def _spawn(self, coro) -> "asyncio.Task":
        """Run *coro* as a background task on the calling thread's running loop.

        A strong reference is kept until the task finishes, and an exception
        it ends with is logged instead of being silently dropped.

        Raises:
            RuntimeError: if no event loop is running in this thread (the same
                error ``asyncio.create_task`` raises). The coroutine is closed
                first so it does not trigger a "never awaited" warning.
        """
        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            coro.close()
            raise
        self._background_tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task: "asyncio.Task") -> None:
        """Release a finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            error(f"Background adapter task failed: {exc!r}")

    def _schedule_connect(self, name: str, old_adapter: Optional[DataSourceAdapter] = None) -> "asyncio.Task":
        """Start a background (re)connection for *name*, closing *old_adapter* first if given."""
        if old_adapter is None:
            coro = self._connect_adapter(name)
        else:
            coro = self._reconnect_adapter(name, old_adapter)
        task = self._spawn(coro)
        self._connect_tasks[name] = task
        return task

    def _connect_pending(self, name: str) -> bool:
        """Return True while a connect task for *name* is still running."""
        task = self._connect_tasks.get(name)
        return task is not None and not task.done()

    async def _close_adapter(self, name: str, adapter: DataSourceAdapter) -> None:
        """Close a detached adapter instance, logging (not raising) any failure."""
        try:
            await adapter.close()
        except Exception as e:
            error(f"Error closing adapter {name}: {e}")

    async def _reconnect_adapter(self, name: str, old_adapter: DataSourceAdapter) -> None:
        """Close *old_adapter*, then open a fresh connection for *name*."""
        await self._close_adapter(name, old_adapter)
        await self._connect_adapter(name)