"""任务管理 API(表单+文件上传,使用 Pydantic 校验)"""

import logging
from pathlib import Path
from typing import List, Optional

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlmodel import Session, col, select

from ...core.database import get_session
from ...models.job import Job, JobStatus, JobUpdate, Step, StepRead
from ...schemas.job import FileUploadInfo, JobCreateRequest, JobCreateResponse, JobStatusResponse
from ...services.workspace_service import WorkspaceManager
from ...workers.pipeline import orchestrate_pipeline

# Module-level singletons: one logger per module, the API router mounted by
# the application, and a shared manager for on-disk job workspaces.
logger = logging.getLogger(__name__)
router = APIRouter()
workspace_mgr = WorkspaceManager()


@router.post(
    "/",
    response_model=JobCreateResponse,
    status_code=status.HTTP_201_CREATED,
    summary="创建新的分析任务(表单+文件)",
)
async def create_job(
    files: List[UploadFile] = File(...),
    name: str = Form(...),
    description: Optional[str] = Form(None),
    sequence_type: str = Form("nucl"),
    scaf_suffix: Optional[str] = Form(None),
    orfs_suffix: Optional[str] = Form(None),
    prot_suffix: Optional[str] = Form(None),
    platform: Optional[str] = Form(None),
    reads1_suffix: Optional[str] = Form(None),
    reads2_suffix: Optional[str] = Form(None),
    genome_size: Optional[str] = Form(None),
    suffix_len: Optional[int] = Form(None),
    short1: Optional[str] = Form(None),
    short2: Optional[str] = Form(None),
    long: Optional[str] = Form(None),
    threads: int = Form(4),
    update_db: bool = Form(False),
    assemble_only: bool = Form(False),
    session: Session = Depends(get_session),
):
    """Create a BtToxin analysis job from form fields plus uploaded files.

    Flow: validate the form with ``JobCreateRequest``; persist a ``Job`` row;
    create a workspace and copy the non-empty uploads into it; record the
    uploaded-file metadata on the job; dispatch the Celery pipeline and store
    the resulting task id.

    Raises:
        HTTPException: 400 when no files were sent, 422 when the form
            parameters fail Pydantic validation.
    """
    if not files:
        raise HTTPException(status_code=400, detail="至少需要上传一个文件")

    try:
        params = JobCreateRequest(
            name=name,
            description=description,
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            orfs_suffix=orfs_suffix,
            prot_suffix=prot_suffix,
            platform=platform,
            reads1_suffix=reads1_suffix,
            reads2_suffix=reads2_suffix,
            genome_size=genome_size,
            suffix_len=suffix_len,
            short1=short1,
            short2=short2,
            long=long,
            threads=threads,
            update_db=update_db,
            assemble_only=assemble_only,
        )
    except ValueError as e:
        # Chain the Pydantic error so the original validation detail survives
        # in tracebacks/logs.
        raise HTTPException(status_code=422, detail=f"参数验证失败: {e}") from e

    job = Job(
        name=params.name,
        description=params.description,
        sequence_type=params.sequence_type.value,
        scaf_suffix=params.scaf_suffix or "",
        threads=params.threads,
        update_db=params.update_db,
        status=JobStatus.PENDING,
    )
    session.add(job)
    session.commit()
    session.refresh(job)  # populate the generated job.id

    workspace = workspace_mgr.create_workspace(job.id)
    job.workspace_path = str(workspace["root"])

    uploaded: List[FileUploadInfo] = []
    for f in files:
        # Read once and take len(); the original seek/tell size pass followed
        # by a second full read was redundant.
        data = await f.read()
        if not data:
            continue  # skip empty uploads, as before
        # Security: keep only the base name — a client-controlled filename
        # such as "../../etc/passwd" must not escape the inputs directory.
        safe_name = Path(f.filename or "upload").name
        dst = workspace["inputs"] / safe_name
        dst.write_bytes(data)
        uploaded.append(
            FileUploadInfo(
                filename=safe_name,
                size=len(data),
                content_type=f.content_type,
                path=str(dst),
            )
        )

    job.input_files = [u.model_dump() for u in uploaded]
    session.add(job)
    session.commit()
    session.refresh(job)

    # Pipeline configuration forwarded verbatim to the Celery orchestrator.
    cfg = {
        "sequence_type": params.sequence_type.value,
        "scaf_suffix": params.scaf_suffix,
        "orfs_suffix": params.orfs_suffix,
        "prot_suffix": params.prot_suffix,
        "threads": params.threads,
        "update_db": params.update_db,
        "assemble_only": params.assemble_only,
        "platform": params.platform.value if params.platform else None,
        "reads1_suffix": params.reads1_suffix,
        "reads2_suffix": params.reads2_suffix,
        "suffix_len": params.suffix_len,
        "genome_size": params.genome_size,
        "short1": params.short1,
        "short2": params.short2,
        "long": params.long,
    }

    task = orchestrate_pipeline.delay(job.id, cfg)
    job.celery_task_id = task.id
    session.add(job)
    session.commit()
    session.refresh(job)

    return JobCreateResponse(
        job_id=job.id,
        message="任务创建成功,正在执行分析",
        uploaded_files=uploaded,
        workspace_path=job.workspace_path,
        celery_task_id=task.id,
    )


@router.get("/", response_model=List[JobStatusResponse])
|
||
async def list_jobs(
|
||
skip: int = 0,
|
||
limit: int = 50,
|
||
status: Optional[JobStatus] = None,
|
||
session: Session = Depends(get_session),
|
||
):
|
||
"""获取任务列表"""
|
||
|
||
statement = select(Job)
|
||
|
||
if status:
|
||
statement = statement.where(Job.status == status)
|
||
|
||
statement = statement.offset(skip).limit(limit).order_by(col(Job.created_at).desc())
|
||
|
||
jobs = session.exec(statement).all()
|
||
return [
|
||
JobStatusResponse(
|
||
job_id=j.id,
|
||
name=j.name,
|
||
status=j.status.value if hasattr(j.status, "value") else j.status,
|
||
progress=j.progress,
|
||
current_step=j.current_step,
|
||
error_message=j.error_message,
|
||
created_at=j.created_at.isoformat(),
|
||
started_at=j.started_at.isoformat() if j.started_at else None,
|
||
completed_at=j.completed_at.isoformat() if j.completed_at else None,
|
||
)
|
||
for j in jobs
|
||
]
|
||
|
||
|
||
@router.get("/{job_id}", response_model=JobStatusResponse)
|
||
async def get_job(job_id: str, session: Session = Depends(get_session)):
|
||
"""获取任务详情"""
|
||
|
||
job = session.get(Job, job_id)
|
||
|
||
if not job:
|
||
raise HTTPException(status_code=404, detail="Job not found")
|
||
|
||
return JobStatusResponse(
|
||
job_id=job.id,
|
||
name=job.name,
|
||
status=job.status.value if hasattr(job.status, "value") else job.status,
|
||
progress=job.progress,
|
||
current_step=job.current_step,
|
||
error_message=job.error_message,
|
||
created_at=job.created_at.isoformat(),
|
||
started_at=job.started_at.isoformat() if job.started_at else None,
|
||
completed_at=job.completed_at.isoformat() if job.completed_at else None,
|
||
)
|
||
|
||
|
||
@router.patch("/{job_id}", response_model=JobRead)
|
||
async def update_job(
|
||
job_id: str, job_update: JobUpdate, session: Session = Depends(get_session)
|
||
):
|
||
"""更新任务"""
|
||
|
||
job = session.get(Job, job_id)
|
||
|
||
if not job:
|
||
raise HTTPException(status_code=404, detail="Job not found")
|
||
|
||
update_data = job_update.model_dump(exclude_unset=True)
|
||
for key, value in update_data.items():
|
||
setattr(job, key, value)
|
||
|
||
session.add(job)
|
||
session.commit()
|
||
session.refresh(job)
|
||
|
||
return job
|
||
|
||
|
||
@router.delete("/{job_id}")
|
||
async def delete_job(job_id: str, session: Session = Depends(get_session)):
|
||
"""删除任务"""
|
||
|
||
job = session.get(Job, job_id)
|
||
|
||
if not job:
|
||
raise HTTPException(status_code=404, detail="Job not found")
|
||
|
||
workspace_mgr.cleanup_workspace(job_id, keep_results=False)
|
||
|
||
session.delete(job)
|
||
session.commit()
|
||
|
||
return {"message": "Job deleted successfully"}
|
||
|
||
|
||
@router.get("/{job_id}/steps", response_model=List[StepRead])
|
||
async def get_job_steps(job_id: str, session: Session = Depends(get_session)):
|
||
"""获取任务步骤列表"""
|
||
|
||
job = session.get(Job, job_id)
|
||
|
||
if not job:
|
||
raise HTTPException(status_code=404, detail="Job not found")
|
||
|
||
statement = select(Step).where(Step.job_id == job_id).order_by(Step.step_order)
|
||
steps = session.exec(statement).all()
|
||
|
||
return steps
|
||
|
||
|
||
@router.get("/{job_id}/progress")
|
||
async def get_job_progress(job_id: str, session: Session = Depends(get_session)):
|
||
"""获取任务进度"""
|
||
|
||
job = session.get(Job, job_id)
|
||
|
||
if not job:
|
||
raise HTTPException(status_code=404, detail="Job not found")
|
||
|
||
statement = select(Step).where(Step.job_id == job_id)
|
||
steps = session.exec(statement).all()
|
||
|
||
total_steps = len(steps)
|
||
completed_steps = sum(1 for s in steps if s.status == "completed")
|
||
failed_steps = sum(1 for s in steps if s.status == "failed")
|
||
|
||
return {
|
||
"job_id": job_id,
|
||
"status": job.status,
|
||
"progress": job.progress,
|
||
"current_step": job.current_step,
|
||
"total_steps": total_steps,
|
||
"completed_steps": completed_steps,
|
||
"failed_steps": failed_steps,
|
||
"steps": [
|
||
{"name": s.step_name, "status": s.status, "order": s.step_order}
|
||
for s in sorted(steps, key=lambda x: x.step_order)
|
||
],
|
||
}
|
||
|
||
|