You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
5.3 KiB
178 lines
5.3 KiB
|
2 weeks ago
|
"""
|
||
|
|
AI模型配置接口 - 管理AI分析模型的配置
|
||
|
|
"""
|
||
|
|
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
|
||
|
|
|
||
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under "/ai-config" and carries the
# "AI模型配置" tag in the generated API docs.
router = APIRouter(tags=["AI模型配置"], prefix="/ai-config")
|
||
|
|
|
||
|
|
|
||
|
|
class AIModelConfig(BaseModel):
    """AI model configuration.

    Request body of ``POST /ai-config/test``. That endpoint uses ``api_key``,
    ``api_base`` and ``model_id`` to send a minimal chat-completion request.
    """

    # Name of this model entry; not read by the connection test in this module.
    model_name: str
    # Credential sent as ``Authorization: Bearer <api_key>`` by the connection test.
    api_key: str
    # Base URL of the API; the connection test appends "/chat/completions" to it.
    api_base: str = "https://api.openai.com/v1"
    # Sent as the "model" field of the chat-completion request.
    model_id: str = "gpt-4"
    # NOTE(review): not read in this module and not range-validated; presumably
    # consumed by the analysis code elsewhere — confirm.
    temperature: float = 0.7
    # NOTE(review): the connection test always sends max_tokens=10, not this value.
    max_tokens: int = 2000
    # NOTE(review): not read anywhere in this module.
    enabled: bool = True
|
||
|
|
|
||
|
|
|
||
|
|
class AIConfigResponse(BaseModel):
    """AI configuration response.

    ``response_model`` of ``GET /ai-config``, ``POST /ai-config`` and
    ``POST /ai-config/test``.
    """

    # Whether the requested operation succeeded.
    success: bool
    # Payload (e.g. the loaded configuration for GET); defaults to None.
    data: Optional[dict] = None
    # Status or error text; empty by default.
    message: str = ""
|
||
|
|
|
||
|
|
|
||
|
|
class SaveAIConfigRequest(BaseModel):
    """Request body for saving the AI configuration (``POST /ai-config``)."""

    # Model entries. The bare ``list`` annotation means items are NOT validated
    # against AIModelConfig; they are stored exactly as received.
    models: list = []
    # Presumably identifies one of ``models`` — not validated here; confirm with the frontend.
    active_model: Optional[str] = None
    # Free-form analysis options; the save endpoint persists None as {}.
    analysis_settings: Optional[dict] = None
|
||
|
|
|
||
|
|
|
||
|
|
# Resolve this file's absolute path, walk up three parent directories and
# descend into "config/" — that is where persisted settings are kept.
CONFIG_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("config")

# JSON document holding models, the active model and analysis settings.
AI_CONFIG_FILE = CONFIG_DIR.joinpath("ai_config.json")
|
||
|
|
|
||
|
|
|
||
|
|
def _ensure_config_dir():
    """Create CONFIG_DIR, including missing parents; no error if it already exists."""
    target = CONFIG_DIR
    target.mkdir(exist_ok=True, parents=True)
|
||
|
|
|
||
|
|
|
||
|
|
def _default_ai_config() -> dict:
    """Return a newly built dict holding the default AI configuration.

    A fresh dict is created on every call so callers may mutate the result
    without affecting later defaults.
    """
    return {
        "models": [],
        "active_model": None,
        "analysis_settings": {
            "enable_technical_analysis": True,
            "enable_fundamental_analysis": False,
            "enable_sentiment_analysis": False,
            "risk_tolerance": "medium",
            "max_position_pct": 10,
        },
    }


def _load_ai_config() -> dict:
    """Load the persisted AI configuration.

    If ``AI_CONFIG_FILE`` does not exist, the default configuration is
    returned. Otherwise the stored JSON is read. When it is a JSON object:

    * any missing top-level key is taken from the defaults, and
    * ``analysis_settings`` is merged key-by-key with the default settings,
      with stored values taking precedence.

    This keeps readers working when the save endpoint persisted an empty
    ``analysis_settings`` (it stores ``None`` as ``{}``) or when an older file
    lacks newer keys.

    Returns:
        The configuration dict. If the file holds a non-object JSON value, that
        value is returned unchanged.

    Raises:
        OSError: if the file exists but cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    _ensure_config_dir()
    defaults = _default_ai_config()

    # EAFP instead of exists()+open(): avoids a race if the file disappears in between.
    try:
        with open(AI_CONFIG_FILE, "r", encoding="utf-8") as f:
            stored = json.load(f)
    except FileNotFoundError:
        return defaults

    if not isinstance(stored, dict):
        # Unexpected layout: hand it back untouched rather than guess a merge.
        return stored

    merged = {**defaults, **stored}
    settings = merged.get("analysis_settings")
    if settings is None:
        merged["analysis_settings"] = defaults["analysis_settings"]
    elif isinstance(settings, dict):
        merged["analysis_settings"] = {**defaults["analysis_settings"], **settings}
    return merged
|
||
|
|
|
||
|
|
|
||
|
|
def _save_ai_config(config: dict):
    """Persist ``config`` to ``AI_CONFIG_FILE`` as indented UTF-8 JSON.

    The write is made crash-safe:

    1. The JSON text is produced before any file is touched.
    2. It is written to a temporary file in the same directory and fsynced.
    3. That file is moved over the target with ``os.replace``.

    A serialization error or a crash mid-write therefore leaves the previous
    configuration intact instead of a truncated file.

    ``tempfile.mkstemp`` creates the temporary file readable and writable only
    by the creating user, so the saved file ends up owner-only. That is
    appropriate since the stored models may contain API keys.

    Raises:
        TypeError, ValueError: if ``config`` is not JSON-serializable.
        OSError: if the file cannot be written or replaced.
    """
    _ensure_config_dir()
    payload = json.dumps(config, ensure_ascii=False, indent=4)

    # Same directory as the target so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(
        dir=AI_CONFIG_FILE.parent, prefix=".ai_config.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, AI_CONFIG_FILE)
    except BaseException:
        # Don't leave an orphaned temp file behind on failure.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("", response_model=AIConfigResponse)
def get_ai_config():
    """Return the current AI model configuration.

    On success responds with ``{"success": True, "data": <config>}``. If loading
    fails, the error and its traceback are logged, and
    ``{"success": False, "message": <error text>}`` is returned rather than
    an HTTP error status.
    """
    try:
        config = _load_ai_config()
    except Exception as e:
        # API boundary: report the failure in the body; logger.exception keeps the traceback.
        logger.exception("加载AI配置失败: %s", e)
        return {"success": False, "message": str(e)}
    return {"success": True, "data": config}
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("", response_model=AIConfigResponse)
def save_ai_config(config: SaveAIConfigRequest):
    """Save the AI model configuration.

    The stored document always has the keys ``models``, ``active_model`` and
    ``analysis_settings``. A missing ``analysis_settings`` is stored as ``{}``.

    Returns ``{"success": True, "message": ...}`` on success. On failure the
    error and its traceback are logged, and ``{"success": False, "message":
    <error text>}`` is returned rather than an HTTP error status.
    """
    config_dict = {
        "models": config.models,
        "active_model": config.active_model,
        "analysis_settings": config.analysis_settings or {},
    }
    try:
        _save_ai_config(config_dict)
    except Exception as e:
        # API boundary: report the failure in the body; logger.exception keeps the traceback.
        logger.exception("保存AI配置失败: %s", e)
        return {"success": False, "message": str(e)}
    return {"success": True, "message": "AI配置保存成功"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/test", response_model=AIConfigResponse)
def test_ai_connection(model_config: AIModelConfig):
    """Test connectivity and credentials for an AI model endpoint.

    Sends a minimal OpenAI-style chat-completion request to
    ``{api_base}/chat/completions``: one "Hello" message with
    ``max_tokens=10``, authenticated with a Bearer token and a 30-second
    timeout.

    Outcomes:

    * HTTP 200 counts as success.
    * Any other status is reported together with the (truncated) response body.
    * Exceptions (e.g. network errors, or httpx not being installed) are
      logged with their traceback and reported as a failure, not raised.

    NOTE(review): ``api_base`` comes from the request body, and the server sends
    ``api_key`` to that URL. Confirm this route is only reachable by trusted
    (admin) callers; otherwise it is an SSRF / credential-exfiltration vector.
    """
    # Cap on how much of an upstream error body is echoed back (error pages can be large HTML).
    max_body_chars = 500

    try:
        import httpx

        # Strip trailing slashes so "https://host/v1/" does not yield ".../v1//chat/completions".
        url = f"{model_config.api_base.rstrip('/')}/chat/completions"

        headers = {
            "Authorization": f"Bearer {model_config.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "model": model_config.model_id,
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        }

        with httpx.Client(timeout=30) as client:
            response = client.post(url, headers=headers, json=data)

        if response.status_code == 200:
            return {"success": True, "message": "连接测试成功"}

        body = response.text
        if len(body) > max_body_chars:
            body = body[:max_body_chars] + "..."
        return {"success": False, "message": f"连接失败: {response.status_code} - {body}"}

    except Exception as e:
        logger.exception("AI连接测试失败: %s", e)
        return {"success": False, "message": f"连接测试失败: {str(e)}"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/providers")
def get_ai_providers():
    """List the AI providers offered to the UI.

    Each entry carries an ``id``, a display ``name``, a default ``api_base``
    and the suggested ``models``. The response shape is
    ``{"success": True, "data": [...]}``.

    NOTE(review): the connection test posts to ``{api_base}/chat/completions``
    with a Bearer token. Confirm each base URL below actually serves that
    OpenAI-compatible route and that the model IDs are still current.
    """
    # (id, display name, default API base URL, suggested model IDs)
    catalog = (
        ("openai", "OpenAI", "https://api.openai.com/v1",
         ("gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo")),
        ("anthropic", "Anthropic Claude", "https://api.anthropic.com/v1",
         ("claude-3-opus", "claude-3-sonnet", "claude-3-haiku")),
        ("google", "Google Gemini", "https://generativelanguage.googleapis.com/v1beta",
         ("gemini-pro", "gemini-pro-vision")),
        ("aliyun", "阿里云通义千问", "https://dashscope.aliyuncs.com/compatible-mode/v1",
         ("qwen-max", "qwen-plus", "qwen-turbo")),
        ("baidu", "百度文心一言", "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop",
         ("ernie-4.0", "ernie-3.5", "ernie-speed")),
        ("zhipu", "智谱清言", "https://open.bigmodel.cn/api/paas/v4",
         ("glm-4", "glm-3-turbo")),
    )

    providers = [
        {"id": pid, "name": label, "api_base": base_url, "models": list(model_ids)}
        for pid, label, base_url, model_ids in catalog
    ]
    return {"success": True, "data": providers}
|