You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
5.4 KiB
163 lines
5.4 KiB
"""
|
|
AmazingData data service platform - batch operations API
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from typing import Optional, List
|
|
from datetime import datetime
|
|
from backend.models.database import get_db
|
|
from backend.models.schemas import (
|
|
BaseResponse, BatchTaskRequest, BatchTaskStatus
|
|
)
|
|
from backend.models.tables import BatchTask, User
|
|
from backend.auth.dependencies import get_current_user
|
|
from backend.services.data_service import data_service
|
|
|
|
# Router instance that the handlers in this module attach to through the
# @router.post / @router.get decorators.
router = APIRouter()
|
|
|
|
|
|
@router.post("/execute", response_model=BaseResponse)
async def execute_batch_task(
    request: BatchTaskRequest,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Create a batch task record and run the task on a background thread.

    A ``BatchTask`` row is inserted with status ``"pending"`` and the handler
    returns right after starting a daemon thread. The thread performs the
    batch download through ``data_service`` and records the outcome
    (``running`` -> ``completed`` / ``error``) on that row.

    Args:
        request: Task type, codes and options for the batch run.
        db: Request-scoped session; used here only to insert the task row and
            to discover which engine/connection the worker should use.
        current_user: Authenticated user or ``None``; ``None`` is recorded as
            ``"anonymous"`` in ``created_by``.

    Returns:
        ``BaseResponse`` whose ``data`` holds ``task_id``, ``status`` and
        ``message``.
    """
    import threading

    # Create the task record.
    task = BatchTask(
        task_type=request.task_type,
        task_params={
            "codes": request.codes,
            "use_main_contract": request.use_main_contract,
            "trading_days": request.trading_days,
            "batch_size": request.batch_size
        },
        status="pending",
        output_path=request.save_path,
        created_by=current_user.username if current_user else "anonymous"
    )
    db.add(task)
    db.commit()
    db.refresh(task)

    # Read everything needed from the request-scoped session *before* the
    # worker starts. A Session (and the instances attached to it) must not be
    # used from two threads at once, and the request session is presumably
    # closed once the request finishes (get_db is not visible here - confirm).
    task_id = task.id
    bind = db.get_bind()

    def batch_worker():
        """Run the batch job and persist its outcome with a worker-owned session."""
        worker_db = Session(bind=bind)
        try:
            job = worker_db.query(BatchTask).filter(BatchTask.id == task_id).first()
            if job is None:
                # The row vanished before the worker ran; nothing to update.
                return

            try:
                job.status = "running"
                job.started_at = datetime.utcnow()
                worker_db.commit()

                if request.task_type == "stock":
                    result = data_service.batch_get_stock_kline(
                        codes=request.codes,
                        trading_days=request.trading_days,
                        save_path=request.save_path,
                        batch_size=request.batch_size
                    )
                elif request.task_type == "future":
                    result = data_service.batch_get_future_kline(
                        underlying_codes=request.codes,
                        use_main_contract=request.use_main_contract,
                        trading_days=request.trading_days,
                        save_path=request.save_path
                    )
                else:
                    job.status = "error"
                    job.error_message = f"Unknown task type: {request.task_type}"
                    worker_db.commit()
                    return

                # The service signals failure through an "error" key in its result.
                if "error" in result:
                    job.status = "error"
                    job.error_message = result["error"]
                else:
                    job.status = "completed"
                    job.success_count = len(result)
                    job.completed_at = datetime.utcnow()

                worker_db.commit()

            except Exception as e:
                # Discard any half-finished transaction first; otherwise a
                # failure raised by a flush/commit would make this commit
                # fail too and the error would never be recorded.
                worker_db.rollback()
                job.status = "error"
                job.error_message = str(e)
                worker_db.commit()
        finally:
            worker_db.close()

    thread = threading.Thread(target=batch_worker, daemon=True)
    thread.start()

    return BaseResponse(
        data={
            "task_id": task_id,
            "status": "pending",
            "message": "Batch task queued"
        }
    )
|
|
|
|
|
|
@router.get("/tasks", response_model=BaseResponse)
async def list_batch_tasks(
    status: Optional[str] = None,
    task_type: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return batch tasks ordered by ``created_at`` descending.

    Args:
        status: When truthy, only tasks with this status are returned.
        task_type: When truthy, only tasks of this type are returned.
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by ``get_current_user``; not read in the body.

    Returns:
        ``BaseResponse`` whose ``data`` contains ``tasks`` (one summary dict
        per task) and ``total`` (the number of tasks returned).
    """
    def iso_or_none(moment):
        # Timestamps are rendered as ISO-8601 strings; unset ones stay None.
        return moment.isoformat() if moment else None

    selection = db.query(BatchTask)
    # Apply each optional equality filter only when a value was supplied.
    for column, wanted in (
        (BatchTask.status, status),
        (BatchTask.task_type, task_type),
    ):
        if wanted:
            selection = selection.filter(column == wanted)

    rows = selection.order_by(BatchTask.created_at.desc()).all()

    summaries = []
    for row in rows:
        summaries.append({
            "id": row.id,
            "task_type": row.task_type,
            "total_count": row.total_count,
            "processed_count": row.processed_count,
            "success_count": row.success_count,
            "failed_count": row.failed_count,
            "status": row.status,
            "output_path": row.output_path,
            "error_message": row.error_message,
            "started_at": iso_or_none(row.started_at),
            "completed_at": iso_or_none(row.completed_at),
            "created_at": iso_or_none(row.created_at),
        })

    return BaseResponse(data={"tasks": summaries, "total": len(rows)})
|
|
|
|
|
|
@router.get("/tasks/{task_id}", response_model=BaseResponse)
async def get_batch_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return the full detail of a single batch task.

    Args:
        task_id: Primary key of the ``BatchTask`` row to fetch.
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by ``get_current_user``; not read in the body.

    Returns:
        ``BaseResponse`` whose ``data`` is a dict of the task's fields,
        including ``task_params``; timestamps are ISO-8601 strings or ``None``.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists.
    """
    record = db.query(BatchTask).filter(BatchTask.id == task_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Task not found")

    detail = {
        "id": record.id,
        "task_type": record.task_type,
        "task_params": record.task_params,
        "total_count": record.total_count,
        "processed_count": record.processed_count,
        "success_count": record.success_count,
        "failed_count": record.failed_count,
        "status": record.status,
        "output_path": record.output_path,
        "error_message": record.error_message,
    }
    # Timestamps go last, in the same key order as before, rendered as
    # ISO-8601 strings when set.
    for field in ("started_at", "completed_at", "created_at"):
        moment = getattr(record, field)
        detail[field] = moment.isoformat() if moment else None

    return BaseResponse(data=detail)