163 lines
5.2 KiB
Python
163 lines
5.2 KiB
Python
"""任务管理 API"""
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
from pathlib import Path
|
|
import uuid
|
|
import shutil
|
|
|
|
from ...database import get_db
|
|
from ...models.job import Job, JobStatus
|
|
from ...schemas.job import JobResponse
|
|
from ...workers.tasks import run_bttoxin_analysis, update_queue_positions
|
|
from ...config import settings
|
|
from ...core.i18n import I18n, get_i18n
|
|
|
|
router = APIRouter()

# Storage roots come from app configuration; jobs get per-id subdirectories.
UPLOAD_DIR = Path(settings.UPLOAD_DIR)
RESULTS_DIR = Path(settings.RESULTS_DIR)
# parents=True: a fresh deployment whose configured parent directories do not
# exist yet must not crash at import time (bare exist_ok=True raises
# FileNotFoundError when an intermediate path component is missing).
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
@router.post("/create", response_model=JobResponse)
async def create_job(
    files: List[UploadFile] = File(...),
    sequence_type: str = Form("nucl"),
    scaf_suffix: str = Form(".fna"),
    threads: int = Form(4),
    min_identity: float = Form(0.8),
    min_coverage: float = Form(0.6),
    allow_unknown_families: bool = Form(False),
    require_index_hit: bool = Form(True),
    crispr_fusion: bool = Form(False),
    crispr_weight: float = Form(0.0),
    db: Session = Depends(get_db),
    i18n: I18n = Depends(get_i18n)
):
    """
    Create a new analysis job.

    Args:
        files: Uploaded file list (exactly one file is accepted).
        sequence_type: Sequence type ("nucl" = nucleotide, "prot" = protein).
        scaf_suffix: Input file suffix.
        threads: Worker thread count.
        min_identity: Minimum identity, 0-1.
        min_coverage: Minimum coverage, 0-1.
        allow_unknown_families: Whether unknown families are allowed.
        require_index_hit: Whether an index hit is required.
        crispr_fusion: Whether CRISPR fusion is enabled.
        crispr_weight: CRISPR weight, 0-1.

    Raises:
        HTTPException: 400 on a disallowed/missing filename extension or on a
            file count other than one.
    """
    # Validate file names. Reduce the client-supplied filename to its basename
    # before any use: UploadFile.filename is attacker-controlled, and using it
    # raw as a path component allows path traversal (e.g. "../../x.fna").
    # It may also be None, which bare Path(file.filename) would reject.
    allowed_extensions = {".fna", ".fa", ".fasta", ".faa"}
    safe_names = []
    for file in files:
        safe_name = Path(file.filename or "").name
        ext = Path(safe_name).suffix.lower()
        # An empty basename yields ext == "", which also fails this check.
        if ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=i18n.t("invalid_extension", ext=ext, allowed=', '.join(allowed_extensions))
            )
        safe_names.append(safe_name)

    # Only single-file uploads are supported.
    if len(files) != 1:
        raise HTTPException(
            status_code=400,
            detail=i18n.t("single_file_only")
        )

    job_id = str(uuid.uuid4())

    # Per-job input/output directories keyed by the generated UUID.
    job_input_dir = UPLOAD_DIR / job_id
    job_output_dir = RESULTS_DIR / job_id
    job_input_dir.mkdir(parents=True, exist_ok=True)
    job_output_dir.mkdir(parents=True, exist_ok=True)

    # Persist each upload under its sanitized basename.
    uploaded_files = []
    for file, safe_name in zip(files, safe_names):
        file_path = job_input_dir / safe_name
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        uploaded_files.append(safe_name)

    job = Job(
        id=job_id,
        status=JobStatus.PENDING,
        input_files=uploaded_files,
        sequence_type=sequence_type,
        scaf_suffix=scaf_suffix,
        threads=threads,
        min_identity=int(min_identity * 100),  # stored as a percentage
        min_coverage=int(min_coverage * 100),
        allow_unknown_families=int(allow_unknown_families),
        require_index_hit=int(require_index_hit),
        crispr_fusion=int(crispr_fusion),
        crispr_weight=int(crispr_weight * 100),
    )

    db.add(job)
    db.commit()
    db.refresh(job)

    # Dispatch the Celery analysis task; fractions (not percentages) are
    # passed through to the worker.
    task = run_bttoxin_analysis.delay(
        job_id=job_id,
        input_dir=str(job_input_dir),
        output_dir=str(job_output_dir),
        sequence_type=sequence_type,
        scaf_suffix=scaf_suffix,
        threads=threads,
        min_identity=min_identity,
        min_coverage=min_coverage,
        allow_unknown_families=allow_unknown_families,
        require_index_hit=require_index_hit,
        crispr_fusion=crispr_fusion,
        crispr_weight=crispr_weight,
    )

    # Record the task id so progress endpoints can query Celery state.
    job.celery_task_id = task.id
    db.commit()

    return job
|
|
|
|
|
|
@router.get("/{job_id}", response_model=JobResponse)
async def get_job(job_id: str, db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Return the job record for *job_id*; raise a localized 404 if absent."""
    found = db.query(Job).filter(Job.id == job_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail=i18n.t("job_not_found"))
    return found
|
|
|
|
|
|
@router.get("/{job_id}/progress")
async def get_job_progress(job_id: str, db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Report queue position and progress for a job, plus Celery task state."""
    job = db.query(Job).filter(Job.id == job_id).first()
    if job is None:
        raise HTTPException(status_code=404, detail=i18n.t("job_not_found"))

    # The status column may hold either the enum or its raw value.
    status_value = job.status.value if isinstance(job.status, JobStatus) else job.status
    payload = {
        'job_id': job_id,
        'status': status_value,
        'current_stage': job.current_stage,
        'progress_percent': job.progress_percent,
        'queue_position': job.queue_position,
    }

    # Enrich with live Celery state when a task has been dispatched.
    if job.celery_task_id:
        from ...core.celery_app import celery_app
        async_result = celery_app.AsyncResult(job.celery_task_id)
        payload['celery_state'] = async_result.state
        if async_result.state == 'PROGRESS':
            payload['celery_info'] = async_result.info

    return payload
|
|
|
|
|
|
@router.post("/update-queue-positions")
async def trigger_queue_update(db: Session = Depends(get_db), i18n: I18n = Depends(get_i18n)):
    """Manually kick off the background queue-position recalculation."""
    queued = update_queue_positions.delay()
    return {"message": i18n.t("queue_update_triggered"), "task_id": queued.id}
|