You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.4 KiB
68 lines
2.4 KiB
|
2 months ago
|
from typing import Dict, Optional
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from adapters.base import BaseDataAdapter
|
||
|
|
from adapters.eastmoney import EastmoneyAdapter
|
||
|
|
from adapters.ths import THSAdapter
|
||
|
|
from adapters.xueqiu import XueqiuAdapter
|
||
|
|
from adapters.tencent import TencentAdapter
|
||
|
|
from models import DataSourceType, DataSource
|
||
|
|
|
||
|
|
|
||
|
|
class DataAdapterFactory:
    """Registry that lazily creates and caches one adapter per data source.

    Adapters are stored on the class itself, so every caller in the process
    shares the same instances until :meth:`close_all` is called.
    """

    # Cache of constructed adapters, keyed by the data source they serve.
    _adapters: Dict[DataSourceType, BaseDataAdapter] = {}
    # Adapter handed out by get_default_adapter(); populated on first use.
    _default_adapter: Optional[BaseDataAdapter] = None

    @classmethod
    def get_adapter(cls, source: DataSourceType = DataSourceType.EASTMONEY) -> BaseDataAdapter:
        """Return the cached adapter for ``source``, creating it on first use.

        Args:
            source: Data source to get an adapter for. Defaults to Eastmoney.

        Returns:
            The shared adapter instance for ``source``. A source without a
            dedicated adapter class is served by the shared Eastmoney adapter.
        """
        if source in cls._adapters:
            return cls._adapters[source]

        if source == DataSourceType.EASTMONEY:
            adapter = EastmoneyAdapter()
        elif source == DataSourceType.THS:
            adapter = THSAdapter()
        elif source == DataSourceType.XUEQIU:
            adapter = XueqiuAdapter()
        elif source == DataSourceType.TENCENT:
            adapter = TencentAdapter()
        else:
            # Reuse the single Eastmoney instance rather than constructing (and
            # later having to close) a separate copy for every unknown source.
            logger.warning(
                "No dedicated adapter for data source {!r}; falling back to Eastmoney",
                source,
            )
            adapter = cls.get_adapter(DataSourceType.EASTMONEY)

        # Cache under the requested key so the fallback warning fires only once.
        cls._adapters[source] = adapter
        return adapter

    @classmethod
    def get_default_adapter(cls) -> BaseDataAdapter:
        """Return the default (Eastmoney) adapter, caching it on first call."""
        if cls._default_adapter is None:
            cls._default_adapter = cls.get_adapter(DataSourceType.EASTMONEY)
        return cls._default_adapter

    @classmethod
    async def get_available_sources(cls) -> list:
        """Describe every data source and whether it currently reports available.

        Returns:
            A list of ``DataSource`` entries, one per ``DataSourceType`` member,
            in enum iteration order. A source whose availability check raises
            is reported with ``is_available=False`` instead of failing the
            whole listing.
        """
        sources = []
        for source_type in DataSourceType:
            adapter = cls.get_adapter(source_type)
            try:
                is_available = await adapter.is_available()
            except Exception as exc:
                # One broken upstream should not hide the remaining sources.
                logger.warning(
                    "Availability check failed for data source {}: {!r}",
                    source_type.value,
                    exc,
                )
                is_available = False
            sources.append(DataSource(
                id=source_type.value,
                name=adapter.name,
                icon=cls._get_icon(source_type),
                is_available=is_available,
            ))
        return sources

    @classmethod
    def _get_icon(cls, source: DataSourceType) -> str:
        """Return the UI icon name for ``source``, or ``"Database"`` if unmapped."""
        icons = {
            DataSourceType.EASTMONEY: "TrendingUp",
            DataSourceType.THS: "BarChart3",
            DataSourceType.XUEQIU: "Activity",
            DataSourceType.TENCENT: "LineChart",
        }
        return icons.get(source, "Database")

    @classmethod
    async def close_all(cls):
        """Close every cached adapter and reset the factory.

        Adapters without a ``close`` attribute are skipped. ``close`` may be a
        coroutine function or a plain callable. A failure while closing one
        adapter is logged and does not prevent the others from being closed.
        The adapter cache and the default adapter are always reset, even if
        closing is interrupted.
        """
        import inspect

        closed_ids = set()
        try:
            # Iterate over a snapshot in case an adapter's close() touches the registry.
            for adapter in list(cls._adapters.values()):
                # The same instance may be cached under several keys (the
                # Eastmoney fallback in get_adapter); close it only once.
                if id(adapter) in closed_ids:
                    continue
                closed_ids.add(id(adapter))

                close = getattr(adapter, "close", None)
                if close is None:
                    continue
                try:
                    result = close()
                    if inspect.isawaitable(result):
                        await result
                except Exception as exc:
                    logger.warning(
                        "Failed to close adapter {}: {!r}",
                        type(adapter).__name__,
                        exc,
                    )
        finally:
            cls._adapters.clear()
            cls._default_adapter = None
|