You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

562 lines
17 KiB

# -*- coding: utf-8 -*-
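"""
===================================
Stock analysis API
===================================
Responsibilities:
1. POST /api/v1/analysis/analyze: trigger an analysis.
2. GET /api/v1/analysis/status/{task_id}: query the status of a task.
3. GET /api/v1/analysis/tasks: list tasks.
4. GET /api/v1/analysis/tasks/stream: push task updates in real time over SSE.
Features:
- Asynchronous task queue: analysis tasks run asynchronously and do not block the request.
- Duplicate-submission guard: returns 409 while the same stock code is already being analysed.
- SSE push: task status changes are pushed to the front end in real time.
"""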
import asyncio
import json
import logging
from datetime import datetime
from typing import Optional, Union, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import JSONResponse, StreamingResponse
from api.deps import get_config_dep
from api.v1.schemas.analysis import (
AnalyzeRequest,
AnalysisResultResponse,
TaskAccepted,
TaskStatus,
TaskInfo,
TaskListResponse,
DuplicateTaskErrorResponse,
)
from api.v1.schemas.common import ErrorResponse
from api.v1.schemas.history import (
AnalysisReport,
ReportMeta,
ReportSummary,
ReportStrategy,
ReportDetails,
)
from src.config import Config
from src.services.task_queue import (
get_task_queue,
DuplicateTaskError,
TaskStatus as TaskStatusEnum,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which every endpoint defined below registers itself.
router = APIRouter()
# ============================================================
# POST /analyze - trigger a stock analysis
# ============================================================
@router.post(
    "/analyze",
    response_model=AnalysisResultResponse,
    responses={
        200: {"description": "分析完成(同步模式)", "model": AnalysisResultResponse},
        202: {"description": "分析任务已接受(异步模式)", "model": TaskAccepted},
        400: {"description": "请求参数错误", "model": ErrorResponse},
        409: {"description": "股票正在分析中,拒绝重复提交", "model": DuplicateTaskErrorResponse},
        500: {"description": "分析失败", "model": ErrorResponse},
    },
    summary="触发股票分析",
    description="启动 AI 智能分析任务,支持同步和异步模式。异步模式下相同股票代码不允许重复提交。"
)
def trigger_analysis(
    request: AnalyzeRequest,
    config: Config = Depends(get_config_dep)
) -> Union[AnalysisResultResponse, JSONResponse]:
    """
    Trigger a stock analysis.

    The request may carry a single ``stock_code``, a list ``stock_codes``,
    or both. The codes are merged (single code first) and de-duplicated in
    order; only the FIRST resulting code is analysed.

    Args:
        request: Analysis request parameters.
        config: Injected configuration dependency (not read in this function).

    Returns:
        Whatever the selected handler returns: ``_handle_async_analysis``
        when ``request.async_mode`` is truthy, ``_handle_sync_analysis``
        otherwise.

    Raises:
        HTTPException: 400 when neither ``stock_code`` nor ``stock_codes``
            yields any code.
    """
    # Collect every requested code, the single-code field taking precedence.
    candidates = [request.stock_code] if request.stock_code else []
    candidates.extend(request.stock_codes or ())

    if not candidates:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": "必须提供 stock_code 或 stock_codes 参数"
            }
        )

    # Order-preserving de-duplication; only the first code is handled for now.
    primary_code = next(iter(dict.fromkeys(candidates)))

    # Queue-based (async) or inline (sync) execution.
    handler = _handle_async_analysis if request.async_mode else _handle_sync_analysis
    return handler(primary_code, request)
def _handle_async_analysis(
    stock_code: str,
    request: AnalyzeRequest
) -> JSONResponse:
    """
    Handle an asynchronous analysis request.

    Submits the task to the task queue and answers immediately.

    Args:
        stock_code: Code of the stock to analyse.
        request: Original request; ``report_type`` and ``force_refresh``
            are forwarded to the queue.

    Returns:
        JSONResponse: 202 with a ``TaskAccepted`` payload on success, or
        409 with a ``DuplicateTaskErrorResponse`` payload when the queue
        raises ``DuplicateTaskError``.
    """
    queue = get_task_queue()
    try:
        submitted = queue.submit_task(
            stock_code=stock_code,
            stock_name=None,  # the name is left unset at submission time
            report_type=request.report_type,
            force_refresh=request.force_refresh,
        )
    except DuplicateTaskError as dup:
        # The queue rejected the submission as a duplicate -> 409 Conflict.
        conflict = DuplicateTaskErrorResponse(
            error="duplicate_task",
            message=str(dup),
            stock_code=dup.stock_code,
            existing_task_id=dup.existing_task_id,
        )
        return JSONResponse(status_code=409, content=conflict.model_dump())
    else:
        # Task queued -> 202 Accepted.
        accepted = TaskAccepted(
            task_id=submitted.task_id,
            status="pending",
            message=f"分析任务已加入队列: {stock_code}"
        )
        return JSONResponse(status_code=202, content=accepted.model_dump())
def _handle_sync_analysis(
    stock_code: str,
    request: AnalyzeRequest
) -> AnalysisResultResponse:
    """
    Handle a synchronous analysis request.

    Runs the analysis inline and returns once it has finished.

    Args:
        stock_code: Code of the stock to analyse.
        request: Original request; ``report_type`` and ``force_refresh``
            are forwarded to the analysis service.

    Returns:
        AnalysisResultResponse: The result, tagged with a freshly
        generated ``query_id``.

    Raises:
        HTTPException: 500 ``analysis_failed`` when the service returns
            ``None``; 500 ``internal_error`` for any other exception (the
            original exception is chained as the cause).
    """
    import uuid
    from src.services.analysis_service import AnalysisService

    query_id = uuid.uuid4().hex
    try:
        service = AnalysisService()
        result = service.analyze_stock(
            stock_code=stock_code,
            report_type=request.report_type,
            force_refresh=request.force_refresh,
            query_id=query_id
        )
        if result is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "analysis_failed",
                    "message": f"分析股票 {stock_code} 失败"
                }
            )

        # ``or {}`` also covers a "report" key that is present but None,
        # which ``.get("report", {})`` would pass through and crash on.
        report_data = result.get("report") or {}
        report = _build_analysis_report(
            report_data, query_id, stock_code, result.get("stock_name")
        )
        return AnalysisResultResponse(
            query_id=query_id,
            stock_code=result.get("stock_code", stock_code),
            stock_name=result.get("stock_name"),
            report=report.model_dump() if report else None,
            created_at=datetime.now().isoformat()
        )
    except HTTPException:
        # Already a well-formed API error; pass it through unchanged.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("分析失败: %s", e)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "internal_error",
                "message": f"分析过程发生错误: {str(e)}"
            }
        ) from e
# ============================================================
# GET /tasks - list analysis tasks
# ============================================================
@router.get(
    "/tasks",
    response_model=TaskListResponse,
    responses={
        200: {"description": "任务列表"},
    },
    summary="获取分析任务列表",
    description="获取当前所有分析任务,可按状态筛选"
)
def get_task_list(
    status: Optional[str] = Query(
        None,
        description="筛选状态pending, processing, completed, failed支持逗号分隔多个"
    ),
    limit: int = Query(20, description="返回数量限制", ge=1, le=100),
) -> TaskListResponse:
    """
    List analysis tasks.

    Args:
        status: Optional comma-separated status filter; tokens are
            trimmed and lower-cased before matching.
        limit: Maximum number of tasks requested from the queue.

    Returns:
        TaskListResponse: The (optionally filtered) tasks together with
        the ``total`` / ``pending`` / ``processing`` counters reported by
        the queue.

    Note:
        The status filter is applied AFTER ``limit`` has been passed to
        the queue, so fewer than ``limit`` tasks may be returned even if
        more matching tasks exist.
    """
    queue = get_task_queue()
    tasks = queue.list_all_tasks(limit=limit)

    # Optional status filter.
    if status:
        wanted = {token.strip().lower() for token in status.split(",")}
        tasks = [task for task in tasks if task.status.value in wanted]

    # Counters come from the queue, not from the filtered list.
    counters = queue.get_task_stats()

    def _to_schema(task) -> TaskInfo:
        """Convert one queue task object into its API schema."""
        started = task.started_at.isoformat() if task.started_at else None
        completed = task.completed_at.isoformat() if task.completed_at else None
        return TaskInfo(
            task_id=task.task_id,
            stock_code=task.stock_code,
            stock_name=task.stock_name,
            status=task.status.value,
            progress=task.progress,
            message=task.message,
            report_type=task.report_type,
            created_at=task.created_at.isoformat(),
            started_at=started,
            completed_at=completed,
            error=task.error,
        )

    return TaskListResponse(
        total=counters["total"],
        pending=counters["pending"],
        processing=counters["processing"],
        tasks=[_to_schema(task) for task in tasks],
    )
# ============================================================
# GET /tasks/stream - real-time push over SSE
# ============================================================
@router.get(
    "/tasks/stream",
    responses={
        200: {"description": "SSE 事件流", "content": {"text/event-stream": {}}},
    },
    summary="任务状态 SSE 流",
    description="通过 Server-Sent Events 实时推送任务状态变化"
)
async def task_stream():
    """
    Stream task status changes as Server-Sent Events.

    Events emitted by this function itself:
    - connected: sent once, first.
    - task_created: sent once for every task returned by
      ``list_pending_tasks()`` at connection time.
    - heartbeat: sent whenever no queue event arrived within 30 seconds.

    Every other event is forwarded from the task queue using the
    ``type`` / ``data`` fields of the queued item.

    Returns:
        StreamingResponse: The SSE event stream.
    """
    async def event_generator():
        task_queue = get_task_queue()
        event_queue: asyncio.Queue = asyncio.Queue()

        # Subscribe BEFORE taking the pending-task snapshot so that events
        # emitted while the snapshot is being sent are buffered, not lost.
        # A task created in that window may be announced twice
        # (NOTE(review): presumably clients key on task_id - confirm).
        task_queue.subscribe(event_queue)
        try:
            yield _format_sse_event("connected", {"message": "Connected to task stream"})

            # Replay the tasks that are currently pending.
            for task in task_queue.list_pending_tasks():
                yield _format_sse_event("task_created", task.to_dict())

            while True:
                try:
                    # Wait for the next event; on timeout send a heartbeat.
                    event = await asyncio.wait_for(event_queue.get(), timeout=30)
                except asyncio.TimeoutError:
                    yield _format_sse_event("heartbeat", {
                        "timestamp": datetime.now().isoformat()
                    })
                    continue
                yield _format_sse_event(event["type"], event["data"])
        finally:
            # Runs on client disconnect / cancellation as well. Cancellation
            # is deliberately NOT swallowed: it propagates after cleanup.
            task_queue.unsubscribe(event_queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
        }
    )
def _format_sse_event(event_type: str, data: Dict[str, Any]) -> str:
"""
格式化 SSE 事件
Args:
event_type: 事件类型
data: 事件数据
Returns:
SSE 格式字符串
"""
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
# ============================================================
# GET /status/{task_id} - query the status of a single task
# ============================================================
@router.get(
    "/status/{task_id}",
    response_model=TaskStatus,
    responses={
        200: {"description": "任务状态"},
        404: {"description": "任务不存在", "model": ErrorResponse},
    },
    summary="查询分析任务状态",
    description="根据 task_id 查询单个任务的状态"
)
def get_analysis_status(task_id: str) -> TaskStatus:
    """
    Query the status of an analysis task.

    The task queue is consulted first; if it does not know the task, the
    analysis history in the database is searched with ``task_id`` used as
    the ``query_id``.

    Args:
        task_id: Task ID.

    Returns:
        TaskStatus: For a queued task, its current status, progress and
        error with ``result=None``. For a history record, a ``completed``
        status at progress 100 with a report rebuilt from the record.

    Raises:
        HTTPException: 500 when the database lookup fails (the original
            exception is chained as the cause); 404 when the task is found
            in neither place.
    """
    # 1. Task queue first.
    task_queue = get_task_queue()
    task = task_queue.get_task(task_id)
    if task:
        return TaskStatus(
            task_id=task.task_id,
            status=task.status.value,
            progress=task.progress,
            result=None,  # no result is attached for tasks served from the queue
            error=task.error,
        )

    # 2. Fall back to the stored analysis history.
    try:
        from src.storage import DatabaseManager
        db = DatabaseManager.get_instance()
        records = db.get_analysis_history(query_id=task_id, limit=1)
        if records:
            record = records[0]

            def _optional_str(attr_name: str) -> Optional[str]:
                """Return the record attribute as ``str``; ``None`` if missing or ``None``."""
                value = getattr(record, attr_name, None)
                return str(value) if value is not None else None

            created_at = record.created_at.isoformat() if record.created_at else None

            # Build the report from the DB record so completed tasks return real data.
            report_dict = AnalysisReport(
                meta=ReportMeta(
                    query_id=task_id,
                    stock_code=record.code,
                    stock_name=record.name,
                    report_type=getattr(record, 'report_type', None),
                    created_at=created_at,
                ),
                summary=ReportSummary(
                    sentiment_score=record.sentiment_score,
                    operation_advice=record.operation_advice,
                    trend_prediction=record.trend_prediction,
                    analysis_summary=record.analysis_summary,
                ),
                strategy=ReportStrategy(
                    ideal_buy=_optional_str('ideal_buy'),
                    secondary_buy=_optional_str('secondary_buy'),
                    stop_loss=_optional_str('stop_loss'),
                    take_profit=_optional_str('take_profit'),
                ),
            ).model_dump()
            return TaskStatus(
                task_id=task_id,
                status="completed",
                progress=100,
                result=AnalysisResultResponse(
                    query_id=task_id,
                    stock_code=record.code,
                    stock_name=record.name,
                    report=report_dict,
                    # The response falls back to "now" when the record has no timestamp.
                    created_at=created_at or datetime.now().isoformat()
                ),
                error=None
            )
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("查询任务状态失败: %s", e)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "internal_error",
                "message": f"查询任务状态失败: {str(e)}"
            }
        ) from e

    # 3. Unknown task.
    raise HTTPException(
        status_code=404,
        detail={
            "error": "not_found",
            "message": f"任务 {task_id} 不存在或已过期"
        }
    )
# ============================================================
# Helper functions
# ============================================================
def _build_analysis_report(
    report_data: Dict[str, Any],
    query_id: str,
    stock_code: str,
    stock_name: Optional[str] = None
) -> AnalysisReport:
    """
    Build an ``AnalysisReport`` from raw report data.

    The raw data is read through its ``meta``, ``summary``, ``strategy``
    and ``details`` sections. A section that is missing OR present with a
    ``None`` / empty value is treated as empty, so a null section no
    longer raises ``AttributeError``.

    Args:
        report_data: Raw report data; ``None`` is tolerated and treated
            as an empty report.
        query_id: Used when ``meta`` carries no ``query_id``.
        stock_code: Used when ``meta`` carries no ``stock_code``.
        stock_name: Used when ``meta`` carries no ``stock_name``.

    Returns:
        AnalysisReport: The structured report. ``strategy`` and
        ``details`` are ``None`` when their sections are empty.
    """
    report_data = report_data or {}
    meta_data = report_data.get("meta") or {}
    summary_data = report_data.get("summary") or {}
    strategy_data = report_data.get("strategy") or {}
    details_data = report_data.get("details") or {}

    meta = ReportMeta(
        query_id=meta_data.get("query_id", query_id),
        stock_code=meta_data.get("stock_code", stock_code),
        stock_name=meta_data.get("stock_name", stock_name),
        report_type=meta_data.get("report_type", "detailed"),
        created_at=meta_data.get("created_at", datetime.now().isoformat()),
        current_price=meta_data.get("current_price"),
        change_pct=meta_data.get("change_pct"),
    )
    summary = ReportSummary(
        analysis_summary=summary_data.get("analysis_summary"),
        operation_advice=summary_data.get("operation_advice"),
        trend_prediction=summary_data.get("trend_prediction"),
        sentiment_score=summary_data.get("sentiment_score"),
        sentiment_label=summary_data.get("sentiment_label")
    )

    strategy = None
    if strategy_data:
        strategy = ReportStrategy(
            ideal_buy=strategy_data.get("ideal_buy"),
            secondary_buy=strategy_data.get("secondary_buy"),
            stop_loss=strategy_data.get("stop_loss"),
            take_profit=strategy_data.get("take_profit")
        )

    details = None
    if details_data:
        details = ReportDetails(
            # "news_summary" wins over "news_content" when both are present.
            news_content=details_data.get("news_summary") or details_data.get("news_content"),
            raw_result=details_data,
            context_snapshot=None
        )

    return AnalysisReport(
        meta=meta,
        summary=summary,
        strategy=strategy,
        details=details
    )