You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

163 lines
5.4 KiB

"""
AmazingData 数据服务平台 - 批量操作 API
"""
import logging
import threading
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from backend.models.database import get_db
from backend.models.schemas import (
    BaseResponse, BatchTaskRequest, BatchTaskStatus
)
from backend.models.tables import BatchTask, User
from backend.auth.dependencies import get_current_user
from backend.services.data_service import data_service
router = APIRouter()

# Failures on background worker threads cannot be reported through the HTTP
# response, so they are logged here in addition to being stored on the task.
logger = logging.getLogger(__name__)


def _fetch_batch_result(request):
    """Call the ``data_service`` batch fetcher matching ``request.task_type``.

    Args:
        request: The ``BatchTaskRequest`` received by ``execute_batch_task``.

    Returns:
        Whatever ``data_service.batch_get_stock_kline`` or
        ``data_service.batch_get_future_kline`` returns.

    Raises:
        ValueError: If ``request.task_type`` is neither ``"stock"`` nor
            ``"future"``; its message becomes the task's ``error_message``.
    """
    if request.task_type == "stock":
        return data_service.batch_get_stock_kline(
            codes=request.codes,
            trading_days=request.trading_days,
            save_path=request.save_path,
            batch_size=request.batch_size,
        )
    if request.task_type == "future":
        # NOTE(review): batch_size is not forwarded for futures (same as the
        # original call); confirm whether batch_get_future_kline accepts it.
        return data_service.batch_get_future_kline(
            underlying_codes=request.codes,
            use_main_contract=request.use_main_contract,
            trading_days=request.trading_days,
            save_path=request.save_path,
        )
    raise ValueError(f"Unknown task type: {request.task_type}")


def _run_batch_task(bind, task_id, request):
    """Run batch task ``task_id`` on the calling (background) thread.

    Marks the task ``"running"``, fetches the data, then records either
    ``"completed"`` (with ``success_count``) or ``"error"`` (with
    ``error_message``). ``completed_at`` is set once a result was obtained.

    Args:
        bind: Engine/connection taken from the request's session. A new,
            thread-private ``Session`` is opened on it and always closed.
        task_id: Primary key of the ``BatchTask`` row to update.
        request: The originating ``BatchTaskRequest``.
    """
    worker_db = Session(bind=bind)
    try:
        task = worker_db.query(BatchTask).filter(BatchTask.id == task_id).first()
        if task is None:
            logger.error("Batch task %s not found; nothing to run", task_id)
            return
        try:
            task.status = "running"
            task.started_at = datetime.utcnow()
            worker_db.commit()

            result = _fetch_batch_result(request)

            # NOTE(review): the shape of ``result`` is defined by data_service;
            # an "error" entry marks failure, otherwise len(result) is stored
            # as the success count -- confirm against data_service.
            if "error" in result:
                task.status = "error"
                task.error_message = result["error"]
            else:
                task.status = "completed"
                task.success_count = len(result)
            task.completed_at = datetime.utcnow()
            worker_db.commit()
        except Exception as exc:
            logger.exception("Batch task %s failed", task_id)
            # A failed flush/commit leaves the session unusable until it is
            # rolled back, so roll back before recording the failure.
            worker_db.rollback()
            task.status = "error"
            task.error_message = str(exc)
            worker_db.commit()
    except Exception:
        # Last resort at the thread boundary: the task row itself could not be
        # read or updated (e.g. the database is unreachable).
        worker_db.rollback()
        logger.exception("Could not update status of batch task %s", task_id)
    finally:
        worker_db.close()


@router.post("/execute", response_model=BaseResponse)
async def execute_batch_task(
    request: BatchTaskRequest,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Create a batch task and start executing it on a background thread.

    The ``BatchTask`` row is committed with status ``"pending"`` before the
    worker starts, so the returned ``task_id`` can be polled right away.

    Args:
        request: Task type (``"stock"`` or ``"future"``), codes and options.
        db: Request-scoped session; used here only to create the task row.
        current_user: Authenticated user, or ``None`` (stored as "anonymous").

    Returns:
        ``BaseResponse`` whose ``data`` holds ``task_id``, ``status``
        (``"pending"``) and a confirmation message.
    """
    task = BatchTask(
        task_type=request.task_type,
        task_params={
            "codes": request.codes,
            "use_main_contract": request.use_main_contract,
            "trading_days": request.trading_days,
            "batch_size": request.batch_size,
        },
        status="pending",
        output_path=request.save_path,
        created_by=current_user.username if current_user else "anonymous",
    )
    db.add(task)
    db.commit()
    db.refresh(task)
    task_id = task.id

    # SQLAlchemy sessions must not be shared across threads, and the request
    # session is presumably closed by get_db once this handler returns
    # (confirm in backend.models.database). The worker therefore gets only the
    # task id and the bind, and opens its own session.
    # daemon=True: an in-flight task is abandoned if the process exits.
    worker = threading.Thread(
        target=_run_batch_task,
        args=(db.get_bind(), task_id, request),
        name=f"batch-task-{task_id}",
        daemon=True,
    )
    worker.start()

    return BaseResponse(
        data={
            "task_id": task_id,
            "status": "pending",
            "message": "Batch task queued",
        }
    )
@router.get("/tasks", response_model=BaseResponse)
async def list_batch_tasks(
    status: Optional[str] = None,
    task_type: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """List batch tasks, most recently created first.

    Args:
        status: Keep only tasks with this ``status``; a falsy value means no
            filter.
        task_type: Keep only tasks with this ``task_type``; a falsy value
            means no filter.
        db: Database session supplied by ``get_db``.
        current_user: Resolved by ``get_current_user``; not otherwise used.

    Returns:
        ``BaseResponse`` whose ``data`` is ``{"tasks": [...], "total": n}``.
        Task summaries omit ``task_params``; timestamps are ISO-8601 strings,
        or ``None`` when unset.
    """
    def iso(moment):
        return None if not moment else moment.isoformat()

    query = db.query(BatchTask)
    for column, wanted in (
        (BatchTask.status, status),
        (BatchTask.task_type, task_type),
    ):
        if wanted:
            query = query.filter(column == wanted)
    rows = query.order_by(BatchTask.created_at.desc()).all()

    summaries = []
    for row in rows:
        summaries.append({
            "id": row.id,
            "task_type": row.task_type,
            "total_count": row.total_count,
            "processed_count": row.processed_count,
            "success_count": row.success_count,
            "failed_count": row.failed_count,
            "status": row.status,
            "output_path": row.output_path,
            "error_message": row.error_message,
            "started_at": iso(row.started_at),
            "completed_at": iso(row.completed_at),
            "created_at": iso(row.created_at),
        })

    return BaseResponse(data={"tasks": summaries, "total": len(rows)})
@router.get("/tasks/{task_id}", response_model=BaseResponse)
async def get_batch_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Return the full record of a single batch task.

    Args:
        task_id: Primary key of the ``BatchTask`` to fetch.
        db: Database session supplied by ``get_db``.
        current_user: Resolved by ``get_current_user``; not otherwise used.

    Returns:
        ``BaseResponse`` whose ``data`` contains every task column, including
        ``task_params``; timestamps are ISO-8601 strings or ``None``.

    Raises:
        HTTPException: 404 when no task has this id.
    """
    record = db.query(BatchTask).filter(BatchTask.id == task_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Task not found")

    plain_fields = (
        "id",
        "task_type",
        "task_params",
        "total_count",
        "processed_count",
        "success_count",
        "failed_count",
        "status",
        "output_path",
        "error_message",
    )
    payload = {field: getattr(record, field) for field in plain_fields}

    # Timestamps go last and are serialised to ISO-8601 (None stays None).
    for field in ("started_at", "completed_at", "created_at"):
        moment = getattr(record, field)
        payload[field] = moment.isoformat() if moment else None

    return BaseResponse(data=payload)