feat(backend): add missing API endpoints, concurrency control, and queue management

- Add /api/v1/tasks router for task management
- Add DELETE endpoint for task deletion
- Add GET /download endpoint for result bundling (tar.gz; see sketch below)
- Add GET /queue endpoint for queue position queries
- Create ConcurrencyManager service with Redis Semaphore (16 concurrent limit; see sketch below)
- Add QUEUED status to JobStatus enum
- Update Job model with queue_position, current_stage, progress_percent fields
- Add scoring parameters (min_identity, min_coverage, etc.) to jobs API
- Implement pipeline stages: digger -> shoter -> plots -> bundle
- Add update_queue_positions Celery task for periodic queue updates (see sketch after the diff)
- Clean up duplicate code in main.py

Co-Authored-By: Claude <noreply@anthropic.com>

Author: zly
Date: 2026-01-13 23:41:15 +08:00
parent 1df699b338
commit d4f0e27af8
8 changed files with 517 additions and 272 deletions
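The ConcurrencyManager named in the commit message lives in one of the other changed files and is not part of the diff shown below. A minimal sketch of the idea — a Redis-backed counting semaphore capped at 16 concurrent pipeline runs — could look like the following; the key name, method names, and constructor arguments are assumptions, not taken from the actual commit:

# Hypothetical sketch of the ConcurrencyManager service: a counting
# semaphore backed by Redis INCR/DECR, capped at 16 concurrent jobs.
import redis

class ConcurrencyManager:
    def __init__(self, url: str = "redis://localhost:6379/0", limit: int = 16):
        self.redis = redis.Redis.from_url(url)
        self.limit = limit
        self.key = "bttoxin:running_jobs"  # assumed key name

    def try_acquire(self) -> bool:
        """Atomically claim a slot; roll the claim back if over the limit."""
        current = self.redis.incr(self.key)
        if current > self.limit:
            self.redis.decr(self.key)
            return False
        return True

    def release(self) -> None:
        """Free a slot after a job finishes or fails."""
        self.redis.decr(self.key)

A worker that fails to acquire a slot would presumably leave the job in the new QUEUED status and retry later, which is what the queue_position field added to the Job model reports back to clients.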

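The GET /download endpoint from the commit message is likewise defined outside this file. A minimal sketch of bundling a job's result directory into a tar.gz with the standard library and returning it via FastAPI might look like this; the route path, lazy-bundling behavior, and hard-coded results path are assumptions (the real code reads settings.RESULTS_DIR):

# Hypothetical sketch of the result-bundling endpoint: pack the job's
# output directory into results/<job_id>.tar.gz and return it as a file.
import tarfile
from pathlib import Path

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

router = APIRouter()
RESULTS_DIR = Path("results")  # assumed; actual code uses settings.RESULTS_DIR

@router.get("/{job_id}/download")
async def download_results(job_id: str):
    job_output_dir = RESULTS_DIR / job_id
    if not job_output_dir.is_dir():
        raise HTTPException(status_code=404, detail="Results not found")

    archive = RESULTS_DIR / f"{job_id}.tar.gz"
    if not archive.exists():  # bundle lazily on first request
        with tarfile.open(archive, "w:gz") as tar:
            tar.add(job_output_dir, arcname=job_id)

    return FileResponse(archive, media_type="application/gzip",
                        filename=f"{job_id}.tar.gz")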

@@ -1,7 +1,7 @@
 """Task management API"""
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
 from sqlalchemy.orm import Session
-from typing import List
+from typing import List, Optional
 from pathlib import Path
 import uuid
 import shutil
@@ -9,7 +9,7 @@ import shutil
 from ...database import get_db
 from ...models.job import Job, JobStatus
 from ...schemas.job import JobResponse
-from ...workers.tasks import run_bttoxin_analysis
+from ...workers.tasks import run_bttoxin_analysis, update_queue_positions
 from ...config import settings

 router = APIRouter()
@@ -19,21 +19,55 @@ RESULTS_DIR = Path(settings.RESULTS_DIR)
 UPLOAD_DIR.mkdir(exist_ok=True)
 RESULTS_DIR.mkdir(exist_ok=True)


 @router.post("/create", response_model=JobResponse)
 async def create_job(
     files: List[UploadFile] = File(...),
-    sequence_type: str = "nucl",
-    scaf_suffix: str = ".fna",
-    threads: int = 4,
+    sequence_type: str = Form("nucl"),
+    scaf_suffix: str = Form(".fna"),
+    threads: int = Form(4),
+    min_identity: float = Form(0.8),
+    min_coverage: float = Form(0.6),
+    allow_unknown_families: bool = Form(False),
+    require_index_hit: bool = Form(True),
     db: Session = Depends(get_db)
 ):
-    """Create a new job"""
+    """
+    Create a new analysis job
+
+    Args:
+        files: list of uploaded files (single file only)
+        sequence_type: sequence type (nucl = nucleotide, prot = protein)
+        scaf_suffix: scaffold file suffix
+        threads: number of threads
+        min_identity: minimum identity (0-1)
+        min_coverage: minimum coverage (0-1)
+        allow_unknown_families: whether to allow unknown families
+        require_index_hit: whether to require an index hit
+    """
+    # Validate file extensions
+    allowed_extensions = {".fna", ".fa", ".fasta", ".faa"}
+    for file in files:
+        ext = Path(file.filename).suffix.lower()
+        if ext not in allowed_extensions:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid file extension: {ext}. Allowed: {', '.join(allowed_extensions)}"
+            )
+
+    # Restrict each task to a single uploaded file
+    if len(files) != 1:
+        raise HTTPException(
+            status_code=400,
+            detail="Only one file allowed per task"
+        )
+
     job_id = str(uuid.uuid4())
     job_input_dir = UPLOAD_DIR / job_id
     job_output_dir = RESULTS_DIR / job_id
-    job_input_dir.mkdir(parents=True)
-    job_output_dir.mkdir(parents=True)
+    job_input_dir.mkdir(parents=True, exist_ok=True)
+    job_output_dir.mkdir(parents=True, exist_ok=True)

     uploaded_files = []
     for file in files:
@@ -48,20 +82,29 @@ async def create_job(
         input_files=uploaded_files,
         sequence_type=sequence_type,
         scaf_suffix=scaf_suffix,
-        threads=threads
+        threads=threads,
+        min_identity=int(min_identity * 100),  # stored as a percentage
+        min_coverage=int(min_coverage * 100),
+        allow_unknown_families=int(allow_unknown_families),
+        require_index_hit=int(require_index_hit),
     )
     db.add(job)
     db.commit()
     db.refresh(job)

     # Launch the Celery task
     task = run_bttoxin_analysis.delay(
         job_id=job_id,
         input_dir=str(job_input_dir),
         output_dir=str(job_output_dir),
         sequence_type=sequence_type,
         scaf_suffix=scaf_suffix,
-        threads=threads
+        threads=threads,
+        min_identity=min_identity,
+        min_coverage=min_coverage,
+        allow_unknown_families=allow_unknown_families,
+        require_index_hit=require_index_hit,
     )

     job.celery_task_id = task.id
@@ -69,6 +112,7 @@ async def create_job(
     return job

+
 @router.get("/{job_id}", response_model=JobResponse)
 async def get_job(job_id: str, db: Session = Depends(get_db)):
     """Get job details"""
@@ -77,6 +121,7 @@ async def get_job(job_id: str, db: Session = Depends(get_db)):
         raise HTTPException(status_code=404, detail="Job not found")
     return job

+
 @router.get("/{job_id}/progress")
 async def get_job_progress(job_id: str, db: Session = Depends(get_db)):
     """Get job progress"""
@@ -84,15 +129,26 @@ async def get_job_progress(job_id: str, db: Session = Depends(get_db)):
     if not job:
         raise HTTPException(status_code=404, detail="Job not found")

+    result = {
+        'job_id': job_id,
+        'status': job.status.value if isinstance(job.status, JobStatus) else job.status,
+        'current_stage': job.current_stage,
+        'progress_percent': job.progress_percent,
+        'queue_position': job.queue_position,
+    }
+
     if job.celery_task_id:
         from ...core.celery_app import celery_app
         task = celery_app.AsyncResult(job.celery_task_id)
-        return {
-            'job_id': job_id,
-            'status': job.status,
-            'celery_state': task.state,
-            'progress': task.info if task.state == 'PROGRESS' else None
-        }
-    return {'job_id': job_id, 'status': job.status}
+        result['celery_state'] = task.state
+        if task.state == 'PROGRESS':
+            result['celery_info'] = task.info
+
+    return result
+
+
+@router.post("/update-queue-positions")
+async def trigger_queue_update(db: Session = Depends(get_db)):
+    """Manually trigger a queue position update"""
+    task = update_queue_positions.delay()
+    return {"message": "Queue update triggered", "task_id": task.id}