""" 定时任务接口 - 创建/启动/停止/删除/列表 """ import logging from typing import Optional from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from app.database import get_db from app.models import ScheduledTask from app.schemas import ( CreateTaskRequest, TaskInfo, TaskListResponse, ) from app.services.cache import ( create_task, list_tasks, get_task, disable_task, enable_task, delete_task, ) from app.services.scheduler import ( add_job, remove_job, is_job_running, get_all_jobs, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/tasks", tags=["定时任务"]) @router.post("", response_model=TaskInfo) def create_new_task(req: CreateTaskRequest, db: Session = Depends(get_db)): """ 创建并启动一个定时采集任务。 输入品种合约和轮询时长,自动开始定时获取数据。 """ # 将periods数组转为逗号分隔的字符串 periods_str = ",".join(req.periods) if isinstance(req.periods, list) else req.periods task = create_task( db=db, symbol=req.symbol, data_type=req.data_type, periods=periods_str, interval_seconds=req.interval_seconds, task_type=req.task_type, run_time=req.run_time, ) # 注册到调度器 job_id = add_job(task.id, task.interval_seconds, task.task_type, task.run_time) task.job_id = job_id db.commit() db.refresh(task) return _to_task_info(task) @router.get("", response_model=TaskListResponse) def list_all_tasks(db: Session = Depends(get_db)): """列出所有定时任务(未完成的)""" tasks = db.query(ScheduledTask).filter( ScheduledTask.is_finished == False ).order_by(ScheduledTask.created_at.desc()).all() job_status = get_all_jobs() task_infos = [] for t in tasks: job_id = f"task_{t.id}" job_info = job_status.get(job_id) task_infos.append(_to_task_info(t, job_info)) return TaskListResponse(tasks=task_infos, total=len(task_infos)) @router.get("/history", response_model=TaskListResponse) def list_finished_tasks(db: Session = Depends(get_db)): """列出已完成的历史任务""" tasks = db.query(ScheduledTask).filter( ScheduledTask.is_finished == True ).order_by(ScheduledTask.updated_at.desc()).all() task_infos = [] for t in tasks: task_infos.append(_to_task_info(t, None)) return TaskListResponse(tasks=task_infos, total=len(task_infos)) @router.post("/{task_id}/rerun", response_model=TaskInfo) def rerun_task(task_id: int, db: Session = Depends(get_db)): """重新执行已完成的任务""" task = get_task(db, task_id) if not task: raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") if not task.is_finished: raise HTTPException(status_code=400, detail=f"任务 {task_id} 尚未完成,无法重新执行") # 重置任务状态 task.is_finished = False task.enabled = True task.last_run = None task.last_status = None db.commit() db.refresh(task) # 重新注册到调度器 job_id = add_job(task.id, task.interval_seconds, task.task_type, task.run_time) task.job_id = job_id db.commit() db.refresh(task) return _to_task_info(task) @router.post("/{task_id}/stop", response_model=TaskInfo) def stop_task(task_id: int, db: Session = Depends(get_db)): """停止定时任务(从调度器移除,但保留配置)""" task = get_task(db, task_id) if not task: raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") remove_job(task_id) task = disable_task(db, task_id) return _to_task_info(task) @router.post("/{task_id}/start", response_model=TaskInfo) def start_task(task_id: int, db: Session = Depends(get_db)): """重新启动已停止的定时任务""" task = get_task(db, task_id) if not task: raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") enable_task(db, task_id) add_job(task.id, task.interval_seconds, task.task_type, task.run_time) db.refresh(task) return _to_task_info(task) @router.delete("/{task_id}") def delete_existing_task(task_id: int, db: Session = Depends(get_db)): """删除定时任务(同时从调度器移除)""" task = get_task(db, task_id) if not task: raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") remove_job(task_id) delete_task(db, task_id) return {"message": f"任务 {task_id} 已删除"} @router.post("/{task_id}/update-interval", response_model=TaskInfo) def update_interval( task_id: int, interval_seconds: int, db: Session = Depends(get_db), ): """更新任务的轮询间隔""" task = get_task(db, task_id) if not task: raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") task.interval_seconds = interval_seconds task.updated_at = task.updated_at.__class__.now() db.commit() db.refresh(task) # 如果任务正在运行,更新调度器 if task.enabled and is_job_running(task_id): remove_job(task_id) add_job(task.id, task.interval_seconds, task.task_type, task.run_time) return _to_task_info(task) def _to_task_info(task, job_info: Optional[dict] = None) -> TaskInfo: """ORM -> Pydantic""" next_run = None if job_info and job_info.get("next_run_time"): next_run = job_info["next_run_time"] return TaskInfo( id=task.id, symbol=task.symbol, data_type=task.data_type, periods=task.periods.split(",") if task.periods else [], interval_seconds=task.interval_seconds, task_type=task.task_type if hasattr(task, 'task_type') else 'interval', enabled=task.enabled, running=is_job_running(task.id), last_run=task.last_run.isoformat() if task.last_run else None, last_status=task.last_status, next_run=next_run, created_at=task.created_at.isoformat(), updated_at=task.updated_at.isoformat(), )