# File: bttoxin-pipeline/backend/app/api/v1/jobs.py (163 lines, 5.2 KiB, Python)
"""任务管理 API"""
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from sqlalchemy.orm import Session
from typing import List, Optional
from pathlib import Path
import uuid
import shutil
from ...database import get_db
from ...models.job import Job, JobStatus
from ...schemas.job import JobResponse
from ...workers.tasks import run_bttoxin_analysis, update_queue_positions
from ...config import settings
from ...core.i18n import I18n, get_i18n
router = APIRouter()

# Resolve the upload/result locations from configuration and make sure they
# exist before any request handler tries to write into them.
UPLOAD_DIR = Path(settings.UPLOAD_DIR)
RESULTS_DIR = Path(settings.RESULTS_DIR)
# parents=True: also create missing intermediate directories so a fresh
# deployment (where the configured parent path does not exist yet) does not
# crash at import time with FileNotFoundError.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
@router.post("/create", response_model=JobResponse)
async def create_job(
    files: List[UploadFile] = File(...),
    sequence_type: str = Form("nucl"),
    scaf_suffix: str = Form(".fna"),
    threads: int = Form(4),
    min_identity: float = Form(0.8),
    min_coverage: float = Form(0.6),
    allow_unknown_families: bool = Form(False),
    require_index_hit: bool = Form(True),
    crispr_fusion: bool = Form(False),
    crispr_weight: float = Form(0.0),
    db: Session = Depends(get_db),
    i18n: I18n = Depends(get_i18n)
):
    """
    Create a new analysis job.

    Args:
        files: Uploaded file list (exactly one file is accepted).
        sequence_type: Sequence type ("nucl" = nucleotide, "prot" = protein).
        scaf_suffix: Input file suffix.
        threads: Worker thread count.
        min_identity: Minimum identity (0-1).
        min_coverage: Minimum coverage (0-1).
        allow_unknown_families: Whether unknown families are allowed.
        require_index_hit: Whether an index hit is required.

    Raises:
        HTTPException: 400 if more than one file is sent, the filename is
            missing, or the extension is not an accepted FASTA variant.
    """
    # Enforce the single-file limit first, so a multi-file request gets the
    # correct error instead of a per-file validation message.
    if len(files) != 1:
        raise HTTPException(
            status_code=400,
            detail=i18n.t("single_file_only")
        )
    # Validate the file type.
    allowed_extensions = {".fna", ".fa", ".fasta", ".faa"}
    for file in files:
        # Keep only the base name: client-supplied filenames are untrusted
        # and may contain path separators (e.g. "../../evil.fna"), which
        # would otherwise escape the job's input directory.
        safe_name = Path(file.filename or "").name
        ext = Path(safe_name).suffix.lower()
        if not safe_name or ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=i18n.t("invalid_extension", ext=ext, allowed=', '.join(allowed_extensions))
            )

    # Each job gets a private input/output directory keyed by a fresh UUID.
    job_id = str(uuid.uuid4())
    job_input_dir = UPLOAD_DIR / job_id
    job_output_dir = RESULTS_DIR / job_id
    job_input_dir.mkdir(parents=True, exist_ok=True)
    job_output_dir.mkdir(parents=True, exist_ok=True)

    # Persist the upload(s) into the job's input directory under the
    # sanitized base name.
    uploaded_files = []
    for file in files:
        safe_name = Path(file.filename).name
        file_path = job_input_dir / safe_name
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        uploaded_files.append(safe_name)

    job = Job(
        id=job_id,
        status=JobStatus.PENDING,
        input_files=uploaded_files,
        sequence_type=sequence_type,
        scaf_suffix=scaf_suffix,
        threads=threads,
        min_identity=int(min_identity * 100),  # stored as a percentage
        min_coverage=int(min_coverage * 100),
        allow_unknown_families=int(allow_unknown_families),
        require_index_hit=int(require_index_hit),
        crispr_fusion=int(crispr_fusion),
        crispr_weight=int(crispr_weight * 100),
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Dispatch the Celery analysis task and remember its id on the job row
    # so progress queries can consult the broker.
    task = run_bttoxin_analysis.delay(
        job_id=job_id,
        input_dir=str(job_input_dir),
        output_dir=str(job_output_dir),
        sequence_type=sequence_type,
        scaf_suffix=scaf_suffix,
        threads=threads,
        min_identity=min_identity,
        min_coverage=min_coverage,
        allow_unknown_families=allow_unknown_families,
        require_index_hit=require_index_hit,
        crispr_fusion=crispr_fusion,
        crispr_weight=crispr_weight,
    )
    job.celery_task_id = task.id
    db.commit()
    return job
@router.get("/{job_id}", response_model=JobResponse)
async def get_job(job_id: str, db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Return the job record for ``job_id``, or respond 404 if none exists."""
    found = db.query(Job).filter(Job.id == job_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail=i18n.t("job_not_found"))
    return found
@router.get("/{job_id}/progress")
async def get_job_progress(job_id: str, db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Report a job's progress: status, current stage, percent, and queue position."""
    job = db.query(Job).filter(Job.id == job_id).first()
    if job is None:
        raise HTTPException(status_code=404, detail=i18n.t("job_not_found"))

    # job.status may be either the enum or its raw value depending on how the
    # row was loaded; normalize to the plain value for the JSON payload.
    status_value = job.status.value if isinstance(job.status, JobStatus) else job.status
    payload = {
        'job_id': job_id,
        'status': status_value,
        'current_stage': job.current_stage,
        'progress_percent': job.progress_percent,
        'queue_position': job.queue_position,
    }

    # Enrich the payload with live Celery state once a worker task exists.
    if job.celery_task_id:
        from ...core.celery_app import celery_app
        async_result = celery_app.AsyncResult(job.celery_task_id)
        payload['celery_state'] = async_result.state
        if async_result.state == 'PROGRESS':
            payload['celery_info'] = async_result.info
    return payload
@router.post("/update-queue-positions")
async def trigger_queue_update(db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Manually trigger an asynchronous recomputation of all queue positions."""
    # NOTE(review): the `db` dependency is unused here; kept to preserve the
    # endpoint's signature.
    queued = update_queue_positions.delay()
    return {"message": i18n.t("queue_update_triggered"), "task_id": queued.id}