# Listing metadata: 99 lines, 2.7 KiB, Python.
"""任务管理 API"""
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
|
from sqlalchemy.orm import Session
|
|
from typing import List
|
|
from pathlib import Path
|
|
import uuid
|
|
import shutil
|
|
|
|
from ...database import get_db
|
|
from ...models.job import Job, JobStatus
|
|
from ...schemas.job import JobResponse
|
|
from ...workers.tasks import run_bttoxin_analysis
|
|
from ...config import settings
|
|
|
|
# Router holding the job endpoints defined in this module.
router = APIRouter()

# Root directories for per-job uploaded inputs and per-job analysis results,
# both taken from application settings.
UPLOAD_DIR = Path(settings.UPLOAD_DIR)
RESULTS_DIR = Path(settings.RESULTS_DIR)

# Ensure both roots exist at import time. parents=True so a nested configured
# path (e.g. "data/uploads") does not raise FileNotFoundError when an
# intermediate directory is missing; exist_ok=True keeps re-imports harmless.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
@router.post("/create", response_model=JobResponse)
async def create_job(
    files: List[UploadFile] = File(...),
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    db: Session = Depends(get_db),
):
    """Create a new job: store the uploads, persist a Job row, queue the analysis.

    Each uploaded file is written to ``UPLOAD_DIR/<job_id>/<name>`` and an empty
    ``RESULTS_DIR/<job_id>/`` directory is created for the worker's output.
    A ``Job`` row is inserted with status ``PENDING``, then
    ``run_bttoxin_analysis`` is dispatched through Celery and the returned
    task id is stored on the job.

    Args:
        files: Uploaded input files. Only the final path component of each
            client-supplied filename is used.
        sequence_type: Stored on the job and forwarded to the worker task.
        scaf_suffix: Stored on the job and forwarded to the worker task.
        threads: Stored on the job and forwarded to the worker task; must be >= 1.
        db: Database session provided by ``get_db``.

    Returns:
        The persisted ``Job``, serialized through ``JobResponse``.

    Raises:
        HTTPException: 400 if no files were sent, ``threads`` < 1, a filename is
            empty or not a usable name, or two uploads resolve to the same name.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")
    if threads < 1:
        raise HTTPException(status_code=400, detail="threads must be >= 1")

    # Client-supplied filenames are untrusted. Keep only the final path
    # component so names like "../../x" or "/etc/x" cannot escape the job
    # directory (joining a Path with an absolute path discards the left side).
    # Validation happens before any directory is created.
    safe_names = []
    for upload in files:
        name = Path(upload.filename or "").name
        if name in ("", ".."):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload filename: {upload.filename!r}",
            )
        if name in safe_names:
            # Two uploads mapping to one name would silently overwrite each other.
            raise HTTPException(
                status_code=400,
                detail=f"Duplicate upload filename: {name!r}",
            )
        safe_names.append(name)

    job_id = str(uuid.uuid4())
    job_input_dir = UPLOAD_DIR / job_id
    job_output_dir = RESULTS_DIR / job_id
    job_input_dir.mkdir(parents=True)
    job_output_dir.mkdir(parents=True)

    try:
        for upload, name in zip(files, safe_names):
            # NOTE(review): this synchronous copy (like the synchronous DB calls
            # below) runs on the event loop inside an ``async def`` handler and
            # blocks it for large uploads; consider a plain ``def`` endpoint or
            # offloading the work to a thread pool.
            with open(job_input_dir / name, "wb") as buffer:
                shutil.copyfileobj(upload.file, buffer)

        job = Job(
            id=job_id,
            status=JobStatus.PENDING,
            input_files=safe_names,
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            threads=threads,
        )
        db.add(job)
        db.commit()
    except Exception:
        # Do not leave orphaned job directories or a failed transaction behind
        # when writing the uploads or inserting the row fails.
        db.rollback()
        shutil.rmtree(job_input_dir, ignore_errors=True)
        shutil.rmtree(job_output_dir, ignore_errors=True)
        raise
    db.refresh(job)

    task = run_bttoxin_analysis.delay(
        job_id=job_id,
        input_dir=str(job_input_dir),
        output_dir=str(job_output_dir),
        sequence_type=sequence_type,
        scaf_suffix=scaf_suffix,
        threads=threads,
    )

    # Record the Celery task id; the progress endpoint reads it to look up
    # the task's state.
    job.celery_task_id = task.id
    db.commit()
    db.refresh(job)

    return job
|
|
|
|
@router.get("/{job_id}", response_model=JobResponse)
async def get_job(job_id: str, db: Session = Depends(get_db)):
    """Fetch the details of a single job.

    Args:
        job_id: Primary key of the job to look up.
        db: Database session provided by ``get_db``.

    Returns:
        The matching ``Job``, serialized through ``JobResponse``.

    Raises:
        HTTPException: 404 when no job has the given id.
    """
    matching = db.query(Job).filter(Job.id == job_id)
    job_row = matching.first()
    if job_row:
        return job_row
    raise HTTPException(status_code=404, detail="Job not found")
|
|
|
|
@router.get("/{job_id}/progress")
async def get_job_progress(job_id: str, db: Session = Depends(get_db)):
    """Report a job's stored status and, once dispatched, its Celery task state.

    Args:
        job_id: Primary key of the job to look up.
        db: Database session provided by ``get_db``.

    Returns:
        A dict with ``'job_id'`` and ``'status'``. When the job has a Celery
        task id it also contains ``'celery_state'`` (the task's state) and
        ``'progress'`` (the task's ``info`` while the state is ``'PROGRESS'``,
        otherwise ``None``).

    Raises:
        HTTPException: 404 when no job has the given id.
    """
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    payload = {'job_id': job_id, 'status': job.status}
    if not job.celery_task_id:
        return payload

    # Imported inside the function, as in the original layout (possibly to
    # avoid an import cycle — not verifiable from this module).
    from ...core.celery_app import celery_app

    task = celery_app.AsyncResult(job.celery_task_id)

    # Read the state once. For a task that is not yet finished, each access to
    # ``AsyncResult.state`` may query the result backend again. Reading it
    # twice let 'celery_state' and the PROGRESS check disagree (e.g.
    # 'PROGRESS' reported alongside progress=None).
    state = task.state
    payload['celery_state'] = state
    payload['progress'] = task.info if state == 'PROGRESS' else None
    return payload
|