You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

268 lines
10 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""适配器管理服务 - 对应Go的internal/service/adapter.go"""
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Callable
from threading import RLock
from app.models import (
AdapterListData, AdapterInfo, AdapterStatus,
AdapterToggleRequest, AdapterConfigUpdateRequest
)
from app.adapters import DataSourceAdapter, AKShareAdapter, AmazingDataAdapter
from app.core.logger import info, error
class AdapterService:
    """Adapter management service (Python port of Go's internal/service/adapter.go).

    Tracks the registered data-source adapters: a factory per adapter name,
    its configuration (``{"enabled": bool, "config": dict}``), display
    metadata, and the instances that are currently connected. The shared
    dictionaries are guarded by ``self.lock``.

    Connecting and closing adapters is asynchronous. The synchronous methods
    below schedule that work with ``asyncio.create_task``, so they must be
    called from a thread with a running event loop (otherwise RuntimeError).
    """

    def __init__(self):
        self.lock = RLock()
        # Registered adapter factories: name -> zero-argument constructor.
        self.factories: Dict[str, Callable[[], DataSourceAdapter]] = {}
        # Adapter configuration: name -> {"enabled": bool, "config": dict}.
        self.configs: Dict[str, dict] = {}
        # Currently connected adapter instances, keyed by adapter name.
        self.active_adapters: Dict[str, DataSourceAdapter] = {}
        # Adapter metadata (name / type / version / description / updated_at).
        self.metadata: Dict[str, dict] = {}
        # Strong references to fire-and-forget tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()
        # Adapter names whose scheduled connection is still in flight; used to
        # avoid starting duplicate lazy connections from get_active_adapter.
        self._connecting: set = set()
        # Register the built-in adapters.
        self._register_builtin_adapters()

    def _register_builtin_adapters(self):
        """Register the built-in adapters with their metadata and default config."""
        # AKShare adapter
        self.register_adapter("akshare", lambda: AKShareAdapter())
        self.metadata["akshare"] = {
            "name": "akshare",
            "type": "http",
            "version": "1.0.0",
            "description": "AKShare 开源金融数据接口无需Token",
            "updated_at": datetime.now()
        }
        # AKShare default config: enabled, no token required.
        self.configs["akshare"] = {
            "enabled": True,
            "config": {
                "timeout": 30
            }
        }
        # AmazingData adapter (Galaxy Securities platform, SDK based)
        self.register_adapter("amazingdata", lambda: AmazingDataAdapter())
        self.metadata["amazingdata"] = {
            "name": "amazingdata",
            "type": "sdk",
            "version": "1.0.30",
            "description": "银河证券星耀数智量化平台需要SDK和账号",
            "updated_at": datetime.now()
        }
        # AmazingData default config: disabled until account details are provided.
        self.configs["amazingdata"] = {
            "enabled": False,
            "config": {
                "username": "",
                "password": "",
                "host": "",
                "port": 8080,
                "local_path": "./amazing_data_cache/",
                "use_local_cache": True
            }
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _spawn(self, coro):
        """Schedule ``coro`` on the running loop and keep a strong reference to the task.

        Raises:
            RuntimeError: if there is no running event loop in this thread.
        """
        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            # No running loop: close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning, then surface the error.
            coro.close()
            raise
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _schedule_connect(self, name: str) -> None:
        """Start ``_connect_adapter(name)`` in the background and mark it in flight."""
        with self.lock:
            self._spawn(self._connect_adapter(name))
            self._connecting.add(name)

    async def _close_adapter(self, name: str, adapter: DataSourceAdapter) -> None:
        """Close ``adapter`` (registered as ``name``), logging instead of raising on error."""
        try:
            await adapter.close()
        except Exception as e:
            error(f"Error closing adapter {name}: {e}")

    def _touch(self, name: str) -> None:
        """Refresh ``updated_at`` in the adapter's metadata, if it has any."""
        meta = self.metadata.get(name)
        if meta is not None:
            meta["updated_at"] = datetime.now()

    def _build_connect_config(self, name: str) -> dict:
        """Return the config dict passed to ``adapter.connect`` for adapter ``name``.

        ``amazingdata`` reads its settings from config.json (via ``get_config``)
        so it stays in sync with the file. Every other adapter uses a copy of
        the in-memory config held in ``self.configs``.
        """
        if name == "amazingdata":
            from app.core.config import get_config
            # Prefer the amazingdata entry listed under the stock sources.
            source_info = get_config().sources.stock.list.get("amazingdata")
            adapter_config = dict(source_info.config) if source_info else {}
            # The port may be stored as a string in config.json.
            port = adapter_config.get("port")
            if isinstance(port, str):
                # NOTE(review): the fallback 8600 differs from the in-memory
                # default port 8080 above -- confirm which one is intended.
                adapter_config["port"] = int(port) if port.strip() else 8600
            return adapter_config
        # akshare (whose config.json branch was disabled) and any other adapter
        # use the in-memory config.
        with self.lock:
            return dict(self.configs[name].get("config", {}))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_adapter_list(self) -> AdapterListData:
        """Return every adapter with metadata, derived status and current config."""
        with self.lock:
            adapters = []
            for name, meta in self.metadata.items():
                cfg = self.configs.get(name, {"enabled": False, "config": {}})
                # Derive the status: disabled > connected (active) > standby.
                if not cfg["enabled"]:
                    status = AdapterStatus.DISABLED
                elif name in self.active_adapters:
                    status = AdapterStatus.ACTIVE
                else:
                    status = AdapterStatus.STANDBY
                adapters.append(AdapterInfo(
                    name=meta["name"],
                    type=meta["type"],
                    version=meta["version"],
                    description=meta["description"],
                    status=status,
                    config=cfg.get("config", {}),
                    updated_at=meta.get("updated_at", datetime.now())
                ))
            return AdapterListData(adapters=adapters)

    def toggle_adapter(self, req: AdapterToggleRequest) -> None:
        """Enable or disable an adapter.

        Disabling closes the connected instance in the background. Enabling a
        previously disabled adapter schedules a background connection.

        Raises:
            ValueError: if ``req.name`` is not a known adapter.
        """
        with self.lock:
            cfg = self.configs.get(req.name)
            if cfg is None:
                raise ValueError(f"Adapter not found: {req.name}")
            was_enabled = cfg["enabled"]
            cfg["enabled"] = req.enable
            # Disabling: detach and close the live connection.
            if not req.enable and req.name in self.active_adapters:
                adapter = self.active_adapters.pop(req.name)
                self._spawn(self._close_adapter(req.name, adapter))
            # Enabling from a disabled state: connect.
            if req.enable and not was_enabled and req.name not in self.active_adapters:
                self._schedule_connect(req.name)
            self._touch(req.name)

    def update_adapter_config(self, req: AdapterConfigUpdateRequest) -> None:
        """Merge ``req.config`` into the adapter's config and reconnect if it is live.

        NOTE: ``amazingdata`` connects using config.json rather than this
        in-memory config, so the merged values do not affect its connection.

        Raises:
            ValueError: if ``req.name`` is not a known adapter.
        """
        with self.lock:
            if req.name not in self.configs:
                raise ValueError(f"Adapter not found: {req.name}")
            self.configs[req.name]["config"].update(req.config)
            # If the adapter is connected, drop the old connection and, while
            # still enabled, reconnect with the new settings.
            if req.name in self.active_adapters:
                adapter = self.active_adapters.pop(req.name)
                self._spawn(self._close_adapter(req.name, adapter))
                if self.configs[req.name]["enabled"]:
                    self._schedule_connect(req.name)
            self._touch(req.name)

    def get_active_adapter(self, asset_class: str) -> Optional[DataSourceAdapter]:
        """Return the connected adapter configured for ``asset_class``, or None.

        The adapter name comes from the application config (``sources.stock.active``
        or ``sources.futures.active``; ``"akshare"`` for any other class). If that
        adapter is enabled but not connected, a background connection is started
        (at most one at a time) and None is returned for this call.

        Args:
            asset_class: asset class, e.g. ``'stock'`` or ``'futures'``.
        """
        from app.core.config import get_config
        config = get_config()
        if asset_class == "stock":
            active_name = config.sources.stock.active
        elif asset_class == "futures":
            active_name = config.sources.futures.active
        else:
            active_name = "akshare"  # default
        with self.lock:
            adapter = self.active_adapters.get(active_name)
            if adapter is not None:
                return adapter
            enabled = self.configs.get(active_name, {}).get("enabled")
            if enabled and active_name not in self._connecting:
                self._schedule_connect(active_name)
            return None

    def get_available_adapters(self) -> List[str]:
        """Return ``"name|description"`` for every adapter that has a factory."""
        with self.lock:
            return [
                f"{name}|{meta['description']}"
                for name, meta in self.metadata.items()
                if name in self.factories
            ]

    def register_adapter(self, name: str, factory: Callable[[], DataSourceAdapter]):
        """Register (or replace) the factory used to build adapter ``name``."""
        with self.lock:
            self.factories[name] = factory

    async def _connect_adapter(self, name: str):
        """Connect a fresh instance of adapter ``name`` and register it as active.

        Any instance already registered under ``name`` is closed first. On
        success the new instance is stored in ``active_adapters``, unless the
        adapter was disabled while connecting (then it is closed again). On
        connection failure the error is logged and the adapter's ``enabled``
        flag is reset to False.

        Raises:
            ValueError: if ``name`` has no registered factory or config entry.
        """
        try:
            with self.lock:
                if name not in self.factories:
                    raise ValueError(f"Adapter factory not found: {name}")
                if name not in self.configs:
                    raise ValueError(f"Adapter config not found: {name}")
                factory = self.factories[name]
                # Detach a previous instance under the lock but close it
                # outside: a threading lock must not be held across an await.
                old_adapter = self.active_adapters.pop(name, None)
            if old_adapter is not None:
                await self._close_adapter(name, old_adapter)

            try:
                adapter_config = self._build_connect_config(name)
                # Log only the name: the config may carry credentials.
                info(f"Connecting to adapter: {name}")
                adapter = factory()
                await adapter.connect(adapter_config)
            except Exception as e:
                error(f"Failed to connect adapter {name}: {e}")
                # Reset the enabled flag so the adapter is reported as disabled.
                with self.lock:
                    if name in self.configs:
                        self.configs[name]["enabled"] = False
                return

            with self.lock:
                keep = bool(self.configs.get(name, {}).get("enabled"))
                if keep:
                    # Another connection may have completed meanwhile; replace
                    # it (and close it below) instead of leaking it.
                    stale = self.active_adapters.get(name)
                    self.active_adapters[name] = adapter
                else:
                    # Disabled while connecting: discard the new instance.
                    stale = adapter
            if stale is not None:
                await self._close_adapter(name, stale)
            if keep:
                info(f"Adapter {name} connected successfully")
        finally:
            with self.lock:
                self._connecting.discard(name)

    async def health_check(self, name: str) -> bool:
        """Run the adapter's health check; False if ``name`` is not connected."""
        with self.lock:
            adapter = self.active_adapters.get(name)
        # Await outside the lock: a threading lock must not be held across an await.
        if adapter is None:
            return False
        return await adapter.health_check()