"""任务管理 API""" from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form from sqlalchemy.orm import Session from typing import List, Optional from pathlib import Path import uuid import shutil from ...database import get_db from ...models.job import Job, JobStatus from ...schemas.job import JobResponse from ...workers.tasks import run_bttoxin_analysis, update_queue_positions from ...config import settings router = APIRouter() UPLOAD_DIR = Path(settings.UPLOAD_DIR) RESULTS_DIR = Path(settings.RESULTS_DIR) UPLOAD_DIR.mkdir(exist_ok=True) RESULTS_DIR.mkdir(exist_ok=True) @router.post("/create", response_model=JobResponse) async def create_job( files: List[UploadFile] = File(...), sequence_type: str = Form("nucl"), scaf_suffix: str = Form(".fna"), threads: int = Form(4), min_identity: float = Form(0.8), min_coverage: float = Form(0.6), allow_unknown_families: bool = Form(False), require_index_hit: bool = Form(True), db: Session = Depends(get_db) ): """ 创建新分析任务 Args: files: 上传的文件列表(单文件) sequence_type: 序列类型 (nucl=核酸, prot=蛋白) scaf_suffix: 文件后缀 threads: 线程数 min_identity: 最小相似度 (0-1) min_coverage: 最小覆盖度 (0-1) allow_unknown_families: 是否允许未知家族 require_index_hit: 是否要求索引命中 """ # 验证文件类型 allowed_extensions = {".fna", ".fa", ".fasta", ".faa"} for file in files: ext = Path(file.filename).suffix.lower() if ext not in allowed_extensions: raise HTTPException( status_code=400, detail=f"Invalid file extension: {ext}. Allowed: {', '.join(allowed_extensions)}" ) # 限制单文件上传 if len(files) != 1: raise HTTPException( status_code=400, detail="Only one file allowed per task" ) job_id = str(uuid.uuid4()) job_input_dir = UPLOAD_DIR / job_id job_output_dir = RESULTS_DIR / job_id job_input_dir.mkdir(parents=True, exist_ok=True) job_output_dir.mkdir(parents=True, exist_ok=True) uploaded_files = [] for file in files: file_path = job_input_dir / file.filename with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) uploaded_files.append(file.filename) job = Job( id=job_id, status=JobStatus.PENDING, input_files=uploaded_files, sequence_type=sequence_type, scaf_suffix=scaf_suffix, threads=threads, min_identity=int(min_identity * 100), # 存储为百分比 min_coverage=int(min_coverage * 100), allow_unknown_families=int(allow_unknown_families), require_index_hit=int(require_index_hit), ) db.add(job) db.commit() db.refresh(job) # 启动 Celery 任务 task = run_bttoxin_analysis.delay( job_id=job_id, input_dir=str(job_input_dir), output_dir=str(job_output_dir), sequence_type=sequence_type, scaf_suffix=scaf_suffix, threads=threads, min_identity=min_identity, min_coverage=min_coverage, allow_unknown_families=allow_unknown_families, require_index_hit=require_index_hit, ) job.celery_task_id = task.id db.commit() return job @router.get("/{job_id}", response_model=JobResponse) async def get_job(job_id: str, db: Session = Depends(get_db)): """获取任务详情""" job = db.query(Job).filter(Job.id == job_id).first() if not job: raise HTTPException(status_code=404, detail="Job not found") return job @router.get("/{job_id}/progress") async def get_job_progress(job_id: str, db: Session = Depends(get_db)): """获取任务进度""" job = db.query(Job).filter(Job.id == job_id).first() if not job: raise HTTPException(status_code=404, detail="Job not found") result = { 'job_id': job_id, 'status': job.status.value if isinstance(job.status, JobStatus) else job.status, 'current_stage': job.current_stage, 'progress_percent': job.progress_percent, 'queue_position': job.queue_position, } if job.celery_task_id: from ...core.celery_app import celery_app task = celery_app.AsyncResult(job.celery_task_id) result['celery_state'] = task.state if task.state == 'PROGRESS': result['celery_info'] = task.info return result @router.post("/update-queue-positions") async def trigger_queue_update(db: Session = Depends(get_db)): """手动触发队列位置更新""" task = update_queue_positions.delay() return {"message": "Queue update triggered", "task_id": task.id}