You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

285 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""适配器管理服务 - 对应Go的internal/service/adapter.go"""
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Callable
from threading import RLock
from app.models import (
AdapterListData, AdapterInfo, AdapterStatus,
AdapterToggleRequest, AdapterConfigUpdateRequest
)
from app.adapters import DataSourceAdapter, AmazingDataAdapter
from app.core.logger import info, error
class AdapterService:
    """Adapter management service (singleton).

    Ensures there is only one ``AdapterService`` instance in the whole
    application, so that adapter connection state is not lost by creating
    duplicate instances.

    Connection/shutdown work triggered from the synchronous mutators
    (``toggle_adapter`` / ``update_adapter_config``) is scheduled with
    ``asyncio.create_task`` and therefore needs a running event loop in the
    calling thread.
    """

    _instance: Optional['AdapterService'] = None
    _lock = RLock()

    # Config keys (matched case-insensitively as substrings) whose values
    # must never be written to logs.
    _SENSITIVE_KEYS = ("password", "secret", "token")

    def __new__(cls) -> 'AdapterService':
        # Check, lock, re-check: concurrent first calls share one instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AdapterService() call; only the first call
        # initializes state (avoids re-initialization).
        if self._initialized:
            return
        with self._lock:
            if self._initialized:
                return
            # Instance-level lock guarding the dictionaries below.
            self.lock = RLock()
            # Registered adapter factories, keyed by adapter name.
            self.factories: Dict[str, Callable[[], DataSourceAdapter]] = {}
            # Per-adapter settings: {"enabled": bool, "config": dict}.
            self.configs: Dict[str, dict] = {}
            # Currently connected adapter instances, keyed by adapter name.
            self.active_adapters: Dict[str, DataSourceAdapter] = {}
            # Display metadata (name/type/version/description/updated_at).
            self.metadata: Dict[str, dict] = {}
            # Strong references to fire-and-forget tasks; the event loop only
            # keeps weak references, so unreferenced tasks may be collected
            # before they finish.
            self._background_tasks: set = set()
            # Register the built-in adapters.
            self._register_builtin_adapters()
            self._initialized = True

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* on the running event loop and keep a reference to it.

        Raises:
            RuntimeError: if there is no running event loop in this thread.
        """
        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            # Close the never-started coroutine to avoid a
            # "coroutine was never awaited" warning, then propagate.
            coro.close()
            raise
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _close_adapter_quietly(self, name: str, adapter: DataSourceAdapter) -> None:
        """Close *adapter*, logging (instead of raising) any error."""
        try:
            await adapter.close()
        except Exception as e:
            error(f"Error closing old adapter {name}: {e}")

    async def _close_then_connect(self, name: str, adapter: DataSourceAdapter) -> None:
        """Close *adapter* and only then open a fresh connection for *name*."""
        await self._close_adapter_quietly(name, adapter)
        await self._connect_adapter(name)

    def _register_builtin_adapters(self):
        """Register the built-in adapters with their metadata and default config."""
        # AmazingData (Galaxy Securities "Xingyao Shuzhi") SDK adapter.
        self.register_adapter("amazingdata", lambda: AmazingDataAdapter())
        self.metadata["amazingdata"] = {
            "name": "amazingdata",
            "type": "sdk",
            "version": "1.0.30",
            "description": "银河证券星耀数智量化平台需要SDK和账号",
            "updated_at": datetime.now()
        }
        # Default config: starts disabled; account credentials must be filled in.
        self.configs["amazingdata"] = {
            "enabled": False,
            "config": {
                "username": "",
                "password": "",
                "host": "",
                "port": 8080,
                "local_path": "./amazing_data_cache/",
                "use_local_cache": True
            }
        }

    def get_adapter_list(self) -> AdapterListData:
        """Return every adapter with its metadata, config and derived status.

        Status is DISABLED when not enabled, ACTIVE when enabled and
        connected, and STANDBY when enabled but not connected.
        """
        with self.lock:
            adapters = []
            for name, meta in self.metadata.items():
                cfg = self.configs.get(name, {"enabled": False, "config": {}})
                if not cfg["enabled"]:
                    status = AdapterStatus.DISABLED
                elif name in self.active_adapters:
                    status = AdapterStatus.ACTIVE
                else:
                    status = AdapterStatus.STANDBY
                adapters.append(AdapterInfo(
                    name=meta["name"],
                    type=meta["type"],
                    version=meta["version"],
                    description=meta["description"],
                    status=status,
                    config=cfg.get("config", {}),
                    updated_at=meta.get("updated_at", datetime.now())
                ))
            return AdapterListData(adapters=adapters)

    def toggle_adapter(self, req: AdapterToggleRequest) -> None:
        """Enable or disable an adapter.

        Disabling closes the live connection in the background. Enabling an
        adapter that was previously disabled and is not connected schedules a
        background connection attempt.

        Raises:
            ValueError: if ``req.name`` is not a configured adapter.
            RuntimeError: if background work is needed but no event loop is running.
        """
        with self.lock:
            if req.name not in self.configs:
                raise ValueError(f"Adapter not found: {req.name}")
            old_enabled = self.configs[req.name]["enabled"]
            self.configs[req.name]["enabled"] = req.enable
            # Disabling: drop and close the live connection.
            if not req.enable and req.name in self.active_adapters:
                adapter = self.active_adapters.pop(req.name)
                self._spawn(self._close_adapter_quietly(req.name, adapter))
            # Newly enabled: start connecting asynchronously.
            if req.enable and not old_enabled and req.name not in self.active_adapters:
                self._spawn(self._connect_adapter(req.name))
            if req.name in self.metadata:
                self.metadata[req.name]["updated_at"] = datetime.now()

    def update_adapter_config(self, req: AdapterConfigUpdateRequest) -> None:
        """Merge ``req.config`` into an adapter's config.

        If the adapter is currently connected, its connection is closed and,
        when the adapter is enabled, re-established with the new config.

        Raises:
            ValueError: if ``req.name`` is not a configured adapter.
            RuntimeError: if background work is needed but no event loop is running.
        """
        with self.lock:
            if req.name not in self.configs:
                raise ValueError(f"Adapter not found: {req.name}")
            self.configs[req.name]["config"].update(req.config)
            if req.name in self.active_adapters:
                adapter = self.active_adapters.pop(req.name)
                if self.configs[req.name]["enabled"]:
                    # One task closes first and then reconnects, so the new
                    # connection is not opened while the old one is closing.
                    self._spawn(self._close_then_connect(req.name, adapter))
                else:
                    self._spawn(self._close_adapter_quietly(req.name, adapter))
            if req.name in self.metadata:
                self.metadata[req.name]["updated_at"] = datetime.now()

    def get_active_adapter(self, asset_class: str) -> Optional[DataSourceAdapter]:
        """Return the connected adapter selected for *asset_class*, if any.

        Args:
            asset_class: asset class, e.g. 'stock' or 'futures'; anything else
                selects the adapter named "custom".

        Returns:
            The connected adapter instance, or None when the selected adapter
            is not connected. Connecting is left to the caller (which is
            expected to await ``_connect_adapter``).
        """
        from app.core.config import get_config
        with self.lock:
            # Stored on the instance as before; other code may read self.config.
            self.config = get_config()
            if asset_class == "stock":
                active_name = self.config.sources.stock.active
            elif asset_class == "futures":
                active_name = self.config.sources.futures.active
            else:
                active_name = "custom"
            # Log only the selection; the full config contains credentials.
            info(f"Getting active adapter for asset class {asset_class}: using {active_name}")
            if active_name in self.active_adapters:
                info(f"Adapter {active_name} already connected, reusing...")
                return self.active_adapters[active_name]
            if active_name in self.factories:
                # Registered but not connected: do not start async work from
                # this sync method; the caller handles the connection.
                info(f"Adapter {active_name} registered but not connected, connecting...")
            return None

    def get_available_adapters(self) -> List[str]:
        """Return "name|description" for every adapter that has both metadata and a factory."""
        with self.lock:
            return [
                f"{name}|{meta['description']}"
                for name, meta in self.metadata.items()
                if name in self.factories
            ]

    def register_adapter(self, name: str, factory: Callable[[], DataSourceAdapter]):
        """Register (or replace) the zero-argument factory for adapter *name*."""
        with self.lock:
            self.factories[name] = factory

    @classmethod
    def _redact_config(cls, config: dict) -> dict:
        """Return a shallow copy of *config* with non-empty credential values masked."""
        return {
            key: "***" if value and any(s in str(key).lower() for s in cls._SENSITIVE_KEYS) else value
            for key, value in config.items()
        }

    @staticmethod
    def _resolve_adapter_config(name: str, file_config, default_config: dict) -> dict:
        """Pick the config for *name*: stock sources, then futures sources, then the default.

        Configs taken from *file_config* are copied into a new dict.
        """
        stock_sources = file_config.sources.stock.list
        if name in stock_sources:
            source_info = stock_sources[name]
            info(f"Using stock config for {name}")
            return dict(source_info.config) if source_info else {}
        futures_sources = file_config.sources.futures.list
        if name in futures_sources:
            source_info = futures_sources[name]
            info(f"Using futures config for {name}")
            return dict(source_info.config) if source_info else {}
        info(f"Using default config for {name}")
        return default_config

    async def _connect_adapter(self, name: str):
        """Create, connect and register the adapter named *name*.

        Any existing connection for *name* is closed first. The config comes
        from config.json (stock list, then futures list), falling back to the
        in-memory default in ``self.configs``. A failed connection is logged
        and resets the adapter's ``enabled`` flag instead of raising.

        Raises:
            ValueError: if no factory or config is registered for *name*, or a
                string ``port`` value is not a valid integer.
        """
        from app.core.config import get_config

        with self.lock:
            if name not in self.factories:
                raise ValueError(f"Adapter factory not found: {name}")
            if name not in self.configs:
                raise ValueError(f"Adapter config not found: {name}")
            # Detach state under the lock but do the I/O outside it: holding a
            # threading lock across an await blocks other threads for the
            # whole wait.
            old_adapter = self.active_adapters.pop(name, None)
            factory = self.factories[name]
            # Copy so port normalisation below never mutates the stored default.
            default_config = dict(self.configs[name].get("config", {}))

        if old_adapter is not None:
            await self._close_adapter_quietly(name, old_adapter)

        # Read the latest config from config.json (kept in sync with the file).
        file_config = get_config()
        adapter_config = self._resolve_adapter_config(name, file_config, default_config)

        # Config files may store the port as a string; blank means 8600.
        port = adapter_config.get("port")
        if isinstance(port, str):
            adapter_config["port"] = int(port) if port.strip() else 8600

        try:
            info(f"Connecting to adapter: {name}, config: {self._redact_config(adapter_config)}")
            adapter = factory()
            await adapter.connect(adapter_config)
        except Exception as e:
            error(f"Failed to connect adapter {name}: {e}")
            # Reset the enabled flag so the status reflects the failure.
            with self.lock:
                if name in self.configs:
                    self.configs[name]["enabled"] = False
            return

        with self.lock:
            previous = self.active_adapters.get(name)
            self.active_adapters[name] = adapter
        # A concurrent connect may have registered another instance meanwhile;
        # close it rather than leak its connection.
        if previous is not None and previous is not adapter:
            await self._close_adapter_quietly(name, previous)
        info(f"Adapter {name} connected successfully")

    async def health_check(self, name: str) -> bool:
        """Return the adapter's own health result, or False if it is not connected or the check raises."""
        with self.lock:
            adapter = self.active_adapters.get(name)
        if adapter is None:
            return False
        try:
            return await adapter.health_check()
        except Exception as e:
            error(f"Health check failed for adapter {name}: {e}")
            return False