# bttoxin-pipeline/backend/app/api/routes/jobs.py
"""任务管理 API表单+文件上传,使用 Pydantic 校验)"""
import logging
from pathlib import Path
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from sqlmodel import Session, select, col

from ...core.database import get_session
from ...models.job import Job, JobRead, JobUpdate, JobStatus, Step, StepRead
from ...schemas.job import JobCreateRequest, JobCreateResponse, JobStatusResponse, FileUploadInfo
from ...services.workspace_service import WorkspaceManager
from ...workers.pipeline import orchestrate_pipeline
logger = logging.getLogger(__name__)
router = APIRouter()
workspace_mgr = WorkspaceManager()
@router.post(
    "/",
    response_model=JobCreateResponse,
    status_code=status.HTTP_201_CREATED,
    summary="创建新的分析任务(表单+文件)",
)
async def create_job(
    files: List[UploadFile] = File(...),
    name: str = Form(...),
    description: Optional[str] = Form(None),
    sequence_type: str = Form("nucl"),
    scaf_suffix: Optional[str] = Form(None),
    orfs_suffix: Optional[str] = Form(None),
    prot_suffix: Optional[str] = Form(None),
    platform: Optional[str] = Form(None),
    reads1_suffix: Optional[str] = Form(None),
    reads2_suffix: Optional[str] = Form(None),
    genome_size: Optional[str] = Form(None),
    suffix_len: Optional[int] = Form(None),
    short1: Optional[str] = Form(None),
    short2: Optional[str] = Form(None),
    long: Optional[str] = Form(None),
    threads: int = Form(4),
    update_db: bool = Form(False),
    assemble_only: bool = Form(False),
    session: Session = Depends(get_session),
):
    """Create a BtToxin analysis job from form fields plus uploaded files.

    Flow: validate the form through the Pydantic schema, persist a PENDING
    Job row (to obtain its id), create the job workspace, save the uploaded
    files into ``workspace["inputs"]``, then dispatch the Celery pipeline.

    Raises:
        HTTPException: 400 when no non-empty file was uploaded,
            422 when the form fields fail schema validation.
    """
    if not files:
        raise HTTPException(status_code=400, detail="至少需要上传一个文件")

    # Funnel the raw form fields through the Pydantic request schema so all
    # validation lives in one place; surface failures as 422.
    try:
        params = JobCreateRequest(
            name=name,
            description=description,
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            orfs_suffix=orfs_suffix,
            prot_suffix=prot_suffix,
            platform=platform,
            reads1_suffix=reads1_suffix,
            reads2_suffix=reads2_suffix,
            genome_size=genome_size,
            suffix_len=suffix_len,
            short1=short1,
            short2=short2,
            long=long,
            threads=threads,
            update_db=update_db,
            assemble_only=assemble_only,
        )
    except ValueError as e:
        raise HTTPException(status_code=422, detail=f"参数验证失败: {e}")

    # Persist first so the DB assigns job.id, which names the workspace.
    job = Job(
        name=params.name,
        description=params.description,
        sequence_type=params.sequence_type.value,
        scaf_suffix=params.scaf_suffix or "",
        threads=params.threads,
        update_db=params.update_db,
        status=JobStatus.PENDING,
    )
    session.add(job)
    session.commit()
    session.refresh(job)

    workspace = workspace_mgr.create_workspace(job.id)
    job.workspace_path = str(workspace["root"])

    uploaded: List[FileUploadInfo] = []
    for f in files:
        # Strip any client-supplied directory components (e.g. "../../etc/x")
        # to prevent path traversal outside the inputs directory; also skips
        # parts that arrived without a filename.
        safe_name = Path(f.filename or "").name
        if not safe_name:
            continue
        # Probe size via seek/tell on the spooled file so empty uploads are
        # skipped without reading them into memory.
        f.file.seek(0, 2)
        size = f.file.tell()
        f.file.seek(0)
        if size == 0:
            continue
        dst = workspace["inputs"] / safe_name
        with dst.open("wb") as out:
            out.write(await f.read())
        uploaded.append(
            FileUploadInfo(filename=safe_name, size=size, content_type=f.content_type, path=str(dst))
        )

    # Every upload was empty or nameless: there is nothing to analyse, so
    # reject rather than dispatch a pipeline with no inputs.
    if not uploaded:
        raise HTTPException(status_code=400, detail="至少需要上传一个文件")

    job.input_files = [u.model_dump() for u in uploaded]
    session.add(job)
    session.commit()
    session.refresh(job)

    # Flattened pipeline configuration handed to the Celery orchestrator.
    cfg = {
        "sequence_type": params.sequence_type.value,
        "scaf_suffix": params.scaf_suffix,
        "orfs_suffix": params.orfs_suffix,
        "prot_suffix": params.prot_suffix,
        "threads": params.threads,
        "update_db": params.update_db,
        "assemble_only": params.assemble_only,
        "platform": params.platform.value if params.platform else None,
        "reads1_suffix": params.reads1_suffix,
        "reads2_suffix": params.reads2_suffix,
        "suffix_len": params.suffix_len,
        "genome_size": params.genome_size,
        "short1": params.short1,
        "short2": params.short2,
        "long": params.long,
    }
    task = orchestrate_pipeline.delay(job.id, cfg)

    # Record the Celery task id so status endpoints can correlate the job.
    job.celery_task_id = task.id
    session.add(job)
    session.commit()
    session.refresh(job)

    return JobCreateResponse(
        job_id=job.id,
        message="任务创建成功,正在执行分析",
        uploaded_files=uploaded,
        workspace_path=job.workspace_path,
        celery_task_id=task.id,
    )
@router.get("/", response_model=List[JobStatusResponse])
async def list_jobs(
skip: int = 0,
limit: int = 50,
status: Optional[JobStatus] = None,
session: Session = Depends(get_session),
):
"""获取任务列表"""
statement = select(Job)
if status:
statement = statement.where(Job.status == status)
statement = statement.offset(skip).limit(limit).order_by(col(Job.created_at).desc())
jobs = session.exec(statement).all()
return [
JobStatusResponse(
job_id=j.id,
name=j.name,
status=j.status.value if hasattr(j.status, "value") else j.status,
progress=j.progress,
current_step=j.current_step,
error_message=j.error_message,
created_at=j.created_at.isoformat(),
started_at=j.started_at.isoformat() if j.started_at else None,
completed_at=j.completed_at.isoformat() if j.completed_at else None,
)
for j in jobs
]
@router.get("/{job_id}", response_model=JobStatusResponse)
async def get_job(job_id: str, session: Session = Depends(get_session)):
"""获取任务详情"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return JobStatusResponse(
job_id=job.id,
name=job.name,
status=job.status.value if hasattr(job.status, "value") else job.status,
progress=job.progress,
current_step=job.current_step,
error_message=job.error_message,
created_at=job.created_at.isoformat(),
started_at=job.started_at.isoformat() if job.started_at else None,
completed_at=job.completed_at.isoformat() if job.completed_at else None,
)
@router.patch("/{job_id}", response_model=JobRead)
async def update_job(
job_id: str, job_update: JobUpdate, session: Session = Depends(get_session)
):
"""更新任务"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
update_data = job_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(job, key, value)
session.add(job)
session.commit()
session.refresh(job)
return job
@router.delete("/{job_id}")
async def delete_job(job_id: str, session: Session = Depends(get_session)):
"""删除任务"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
workspace_mgr.cleanup_workspace(job_id, keep_results=False)
session.delete(job)
session.commit()
return {"message": "Job deleted successfully"}
@router.get("/{job_id}/steps", response_model=List[StepRead])
async def get_job_steps(job_id: str, session: Session = Depends(get_session)):
"""获取任务步骤列表"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
statement = select(Step).where(Step.job_id == job_id).order_by(Step.step_order)
steps = session.exec(statement).all()
return steps
@router.get("/{job_id}/progress")
async def get_job_progress(job_id: str, session: Session = Depends(get_session)):
"""获取任务进度"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
statement = select(Step).where(Step.job_id == job_id)
steps = session.exec(statement).all()
total_steps = len(steps)
completed_steps = sum(1 for s in steps if s.status == "completed")
failed_steps = sum(1 for s in steps if s.status == "failed")
return {
"job_id": job_id,
"status": job.status,
"progress": job.progress,
"current_step": job.current_step,
"total_steps": total_steps,
"completed_steps": completed_steps,
"failed_steps": failed_steps,
"steps": [
{"name": s.step_name, "status": s.status, "order": s.step_order}
for s in sorted(steps, key=lambda x: x.step_order)
],
}