You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

208 lines
6.1 KiB

"""
定时任务接口 - 创建/启动/停止/删除/列表
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.database import get_db
from app.models import ScheduledTask
from app.schemas import (
CreateTaskRequest,
TaskInfo,
TaskListResponse,
)
from app.services.cache import (
create_task,
list_tasks,
get_task,
disable_task,
enable_task,
delete_task,
)
from app.services.scheduler import (
add_job,
remove_job,
is_job_running,
get_all_jobs,
)
# Module-level logger (not referenced by the handlers visible in this file).
logger = logging.getLogger(__name__)
# Every endpoint below is mounted under /tasks and grouped under the
# "定时任务" (scheduled tasks) tag in the generated OpenAPI docs.
router = APIRouter(tags=["定时任务"], prefix="/tasks")
@router.post("", response_model=TaskInfo)
def create_new_task(req: CreateTaskRequest, db: Session = Depends(get_db)):
    """Create a scheduled data-collection task and start it immediately.

    The request supplies the instrument symbol, data type, periods and
    scheduling settings. The row is persisted via ``create_task`` and then
    registered with the scheduler; the job id returned by ``add_job`` is
    stored on ``task.job_id``.
    """
    # Periods may arrive as a list; the task row stores them comma-joined
    # (``_to_task_info`` splits them back on ",").
    periods = req.periods
    if isinstance(periods, list):
        periods = ",".join(periods)

    task = create_task(
        db=db,
        symbol=req.symbol,
        data_type=req.data_type,
        periods=periods,
        interval_seconds=req.interval_seconds,
        task_type=req.task_type,
        run_time=req.run_time,
    )

    # Hand the task over to the scheduler and remember its job id.
    new_job_id = add_job(task.id, task.interval_seconds, task.task_type, task.run_time)
    task.job_id = new_job_id
    db.commit()
    db.refresh(task)
    return _to_task_info(task)
@router.get("", response_model=TaskListResponse)
def list_all_tasks(db: Session = Depends(get_db)):
    """List all tasks that have not finished yet, newest first.

    Each entry is paired with the scheduler's job info, looked up under the
    ``task_<id>`` job id, so that ``next_run`` can be reported.
    """
    unfinished = (
        db.query(ScheduledTask)
        .filter(ScheduledTask.is_finished == False)  # noqa: E712 - SQLAlchemy column expression
        .order_by(ScheduledTask.created_at.desc())
        .all()
    )
    scheduler_jobs = get_all_jobs()
    infos = [
        _to_task_info(row, scheduler_jobs.get(f"task_{row.id}"))
        for row in unfinished
    ]
    return TaskListResponse(tasks=infos, total=len(infos))
@router.get("/history", response_model=TaskListResponse)
def list_finished_tasks(db: Session = Depends(get_db)):
    """List finished (historical) tasks, most recently updated first.

    Finished tasks carry no scheduler job info, so ``next_run`` is always
    empty for them.
    """
    finished = (
        db.query(ScheduledTask)
        .filter(ScheduledTask.is_finished == True)  # noqa: E712 - SQLAlchemy column expression
        .order_by(ScheduledTask.updated_at.desc())
        .all()
    )
    infos = [_to_task_info(row, None) for row in finished]
    return TaskListResponse(tasks=infos, total=len(infos))
@router.post("/{task_id}/rerun", response_model=TaskInfo)
def rerun_task(task_id: int, db: Session = Depends(get_db)):
    """Run an already-finished task again.

    The task's run state is reset (unfinished, enabled, no last run or
    status) and committed. The task is then registered with the scheduler
    again, and the returned job id is stored on ``task.job_id``.

    Raises:
        HTTPException: 404 if the task does not exist; 400 if it has not
            finished yet.
    """
    task = get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    if not task.is_finished:
        raise HTTPException(status_code=400, detail=f"任务 {task_id} 尚未完成,无法重新执行")

    def _persist():
        db.commit()
        db.refresh(task)

    fresh_state = {
        "is_finished": False,
        "enabled": True,
        "last_run": None,
        "last_status": None,
    }
    for field_name, value in fresh_state.items():
        setattr(task, field_name, value)
    # The reset is committed before scheduling. Presumably this is so a job
    # that fires right away sees the reset state — confirm against the job body.
    _persist()

    task.job_id = add_job(task.id, task.interval_seconds, task.task_type, task.run_time)
    _persist()
    return _to_task_info(task)
@router.post("/{task_id}/stop", response_model=TaskInfo)
def stop_task(task_id: int, db: Session = Depends(get_db)):
    """Stop a task: remove its scheduler job but keep its configuration.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    if not get_task(db, task_id):
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    remove_job(task_id)
    disabled = disable_task(db, task_id)
    return _to_task_info(disabled)
@router.post("/{task_id}/start", response_model=TaskInfo)
def start_task(task_id: int, db: Session = Depends(get_db)):
    """Restart a stopped task: re-enable it and register it with the scheduler.

    Like ``create_new_task`` and ``rerun_task``, this stores the job id
    returned by ``add_job`` on ``task.job_id``. Previously the id was
    discarded.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    task = get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    enable_task(db, task_id)
    # If a job for this task is already registered (e.g. "start" called twice),
    # drop it first so we re-register instead of adding a duplicate. This is
    # the same remove-then-add approach update_interval uses.
    if is_job_running(task_id):
        remove_job(task_id)
    task.job_id = add_job(task.id, task.interval_seconds, task.task_type, task.run_time)
    db.commit()
    db.refresh(task)
    return _to_task_info(task)
@router.delete("/{task_id}")
def delete_existing_task(task_id: int, db: Session = Depends(get_db)):
    """Delete a task, removing its scheduler job first.

    Returns:
        A ``{"message": ...}`` confirmation dict.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    existing = get_task(db, task_id)
    if not existing:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    remove_job(task_id)
    delete_task(db, task_id)
    confirmation = f"任务 {task_id} 已删除"
    return {"message": confirmation}
@router.post("/{task_id}/update-interval", response_model=TaskInfo)
def update_interval(
    task_id: int,
    interval_seconds: int,
    db: Session = Depends(get_db),
):
    """Change a task's polling interval and reschedule it if it is active.

    Args:
        task_id: Id of the task to update.
        interval_seconds: New polling interval in seconds; must be > 0.

    Raises:
        HTTPException: 404 if the task does not exist; 400 if
            ``interval_seconds`` is not positive.
    """
    from datetime import datetime  # local import: not among the module's imports

    task = get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    # A zero or negative interval makes no sense for a polling job;
    # reject it before it reaches the DB or the scheduler.
    if interval_seconds <= 0:
        raise HTTPException(status_code=400, detail="轮询间隔必须大于 0 秒")

    task.interval_seconds = interval_seconds
    # Previously `task.updated_at.__class__.now()`, which raised
    # AttributeError when updated_at was NULL. For a datetime column the
    # result is the same naive local timestamp.
    task.updated_at = datetime.now()
    db.commit()
    db.refresh(task)

    # Only an enabled task with a live job needs rescheduling; re-add it
    # so the scheduler picks up the new interval.
    if task.enabled and is_job_running(task_id):
        remove_job(task_id)
        add_job(task.id, task.interval_seconds, task.task_type, task.run_time)
    return _to_task_info(task)
def _to_task_info(task, job_info: Optional[dict] = None) -> TaskInfo:
    """Convert a ``ScheduledTask`` ORM row into the ``TaskInfo`` response model.

    Args:
        task: The ORM task row.
        job_info: Optional scheduler job info. Only its ``next_run_time``
            key is read; a missing or falsy value yields ``next_run=None``.

    Returns:
        The populated ``TaskInfo``. ``periods`` is split back into a list, and
        timestamps are rendered as ISO-8601 strings.
    """
    next_run = None
    if job_info:
        next_run = job_info.get("next_run_time") or None

    # Fall back to 'interval' both when the attribute is missing and when the
    # column exists but holds NULL/empty (e.g. rows written before
    # task_type was added). Previously only the former was handled.
    task_type = getattr(task, "task_type", None) or "interval"

    return TaskInfo(
        id=task.id,
        symbol=task.symbol,
        data_type=task.data_type,
        periods=task.periods.split(",") if task.periods else [],
        interval_seconds=task.interval_seconds,
        task_type=task_type,
        enabled=task.enabled,
        running=is_job_running(task.id),
        last_run=task.last_run.isoformat() if task.last_run else None,
        last_status=task.last_status,
        next_run=next_run,
        created_at=task.created_at.isoformat(),
        updated_at=task.updated_at.isoformat(),
    )