feat(backend): add missing API endpoints, concurrency control, and queue management\n\n- Add /api/v1/tasks router for task management\n- Add DELETE endpoint for task deletion\n- Add GET /download endpoint for result bundling (tar.gz)\n- Add GET /queue endpoint for queue position queries\n- Create ConcurrencyManager service with Redis Semaphore (16 concurrent limit)\n- Add QUEUED status to JobStatus enum\n- Update Job model with queue_position, current_stage, progress_percent fields\n- Add scoring parameters (min_identity, min_coverage, etc.) to jobs API\n- Implement pipeline stages: digger -> shoter -> plots -> bundle\n- Add update_queue_positions Celery task for periodic queue updates\n- Clean up duplicate code in main.py\n\nCo-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
"""API v1 路由"""
|
||||
from . import jobs, upload, results, tasks
|
||||
|
||||
__all__ = ["jobs", "upload", "results", "tasks"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""任务管理 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import shutil
|
||||
@@ -9,7 +9,7 @@ import shutil
|
||||
from ...database import get_db
|
||||
from ...models.job import Job, JobStatus
|
||||
from ...schemas.job import JobResponse
|
||||
from ...workers.tasks import run_bttoxin_analysis
|
||||
from ...workers.tasks import run_bttoxin_analysis, update_queue_positions
|
||||
from ...config import settings
|
||||
|
||||
router = APIRouter()
|
||||
@@ -19,21 +19,55 @@ RESULTS_DIR = Path(settings.RESULTS_DIR)
|
||||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
RESULTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
@router.post("/create", response_model=JobResponse)
|
||||
async def create_job(
|
||||
files: List[UploadFile] = File(...),
|
||||
sequence_type: str = "nucl",
|
||||
scaf_suffix: str = ".fna",
|
||||
threads: int = 4,
|
||||
sequence_type: str = Form("nucl"),
|
||||
scaf_suffix: str = Form(".fna"),
|
||||
threads: int = Form(4),
|
||||
min_identity: float = Form(0.8),
|
||||
min_coverage: float = Form(0.6),
|
||||
allow_unknown_families: bool = Form(False),
|
||||
require_index_hit: bool = Form(True),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建新任务"""
|
||||
"""
|
||||
创建新分析任务
|
||||
|
||||
Args:
|
||||
files: 上传的文件列表(单文件)
|
||||
sequence_type: 序列类型 (nucl=核酸, prot=蛋白)
|
||||
scaf_suffix: 文件后缀
|
||||
threads: 线程数
|
||||
min_identity: 最小相似度 (0-1)
|
||||
min_coverage: 最小覆盖度 (0-1)
|
||||
allow_unknown_families: 是否允许未知家族
|
||||
require_index_hit: 是否要求索引命中
|
||||
"""
|
||||
# 验证文件类型
|
||||
allowed_extensions = {".fna", ".fa", ".fasta", ".faa"}
|
||||
for file in files:
|
||||
ext = Path(file.filename).suffix.lower()
|
||||
if ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file extension: {ext}. Allowed: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# 限制单文件上传
|
||||
if len(files) != 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Only one file allowed per task"
|
||||
)
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
|
||||
job_input_dir = UPLOAD_DIR / job_id
|
||||
job_output_dir = RESULTS_DIR / job_id
|
||||
job_input_dir.mkdir(parents=True)
|
||||
job_output_dir.mkdir(parents=True)
|
||||
job_input_dir.mkdir(parents=True, exist_ok=True)
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
uploaded_files = []
|
||||
for file in files:
|
||||
@@ -48,20 +82,29 @@ async def create_job(
|
||||
input_files=uploaded_files,
|
||||
sequence_type=sequence_type,
|
||||
scaf_suffix=scaf_suffix,
|
||||
threads=threads
|
||||
threads=threads,
|
||||
min_identity=int(min_identity * 100), # 存储为百分比
|
||||
min_coverage=int(min_coverage * 100),
|
||||
allow_unknown_families=int(allow_unknown_families),
|
||||
require_index_hit=int(require_index_hit),
|
||||
)
|
||||
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
# 启动 Celery 任务
|
||||
task = run_bttoxin_analysis.delay(
|
||||
job_id=job_id,
|
||||
input_dir=str(job_input_dir),
|
||||
output_dir=str(job_output_dir),
|
||||
sequence_type=sequence_type,
|
||||
scaf_suffix=scaf_suffix,
|
||||
threads=threads
|
||||
threads=threads,
|
||||
min_identity=min_identity,
|
||||
min_coverage=min_coverage,
|
||||
allow_unknown_families=allow_unknown_families,
|
||||
require_index_hit=require_index_hit,
|
||||
)
|
||||
|
||||
job.celery_task_id = task.id
|
||||
@@ -69,6 +112,7 @@ async def create_job(
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobResponse)
|
||||
async def get_job(job_id: str, db: Session = Depends(get_db)):
|
||||
"""获取任务详情"""
|
||||
@@ -77,6 +121,7 @@ async def get_job(job_id: str, db: Session = Depends(get_db)):
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/{job_id}/progress")
|
||||
async def get_job_progress(job_id: str, db: Session = Depends(get_db)):
|
||||
"""获取任务进度"""
|
||||
@@ -84,15 +129,26 @@ async def get_job_progress(job_id: str, db: Session = Depends(get_db)):
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
result = {
|
||||
'job_id': job_id,
|
||||
'status': job.status.value if isinstance(job.status, JobStatus) else job.status,
|
||||
'current_stage': job.current_stage,
|
||||
'progress_percent': job.progress_percent,
|
||||
'queue_position': job.queue_position,
|
||||
}
|
||||
|
||||
if job.celery_task_id:
|
||||
from ...core.celery_app import celery_app
|
||||
task = celery_app.AsyncResult(job.celery_task_id)
|
||||
result['celery_state'] = task.state
|
||||
if task.state == 'PROGRESS':
|
||||
result['celery_info'] = task.info
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'status': job.status,
|
||||
'celery_state': task.state,
|
||||
'progress': task.info if task.state == 'PROGRESS' else None
|
||||
}
|
||||
return result
|
||||
|
||||
return {'job_id': job_id, 'status': job.status}
|
||||
|
||||
@router.post("/update-queue-positions")
async def trigger_queue_update(db: Session = Depends(get_db)):
    """Manually trigger a queue-position refresh.

    Enqueues the ``update_queue_positions`` Celery task and returns
    immediately; the refresh itself runs asynchronously in a worker.

    Args:
        db: Database session injected by FastAPI (not used by the body).

    Returns:
        A dict with a confirmation message and the id of the enqueued task.
    """
    async_result = update_queue_positions.delay()
    payload = {"message": "Queue update triggered"}
    payload["task_id"] = async_result.id
    return payload
|
||||
|
||||
@@ -1,8 +1,76 @@
|
||||
"""结果查询 API"""
|
||||
from fastapi import APIRouter
|
||||
"""结果下载 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import io
|
||||
import shutil
|
||||
|
||||
from ...database import get_db
|
||||
from ...models.job import Job, JobStatus
|
||||
from ...config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/")
|
||||
async def results_info():
|
||||
return {"message": "Results endpoint"}
|
||||
RESULTS_DIR = Path(settings.RESULTS_DIR)
|
||||
|
||||
|
||||
def _build_results_archive(directory: Path) -> bytes:
    """Pack every regular file under *directory* into an in-memory .tar.gz.

    Member names are stored relative to *directory* using forward slashes.
    """
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        # Sorted so the member order does not depend on directory listing order.
        for file_path in sorted(directory.rglob("*")):
            if file_path.is_file():
                arcname = file_path.relative_to(directory).as_posix()
                tar.add(file_path, arcname=arcname)
    # Read the bytes only after the archive is closed, so the gzip trailer
    # has been flushed into the buffer.
    return buffer.getvalue()


@router.get("/{job_id}/download")
async def download_results(job_id: str, db: Session = Depends(get_db)):
    """Download a job's results bundled as a .tar.gz archive.

    Args:
        job_id: Identifier of the job whose results are requested.
        db: Database session injected by FastAPI.

    Returns:
        A ``Response`` with media type ``application/gzip`` carrying the
        archive as an attachment named ``bttoxin_<job_id>.tar.gz``.

    Raises:
        HTTPException: 404 if the job does not exist or its results
            directory is missing; 400 if the job has not completed.
    """
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    if job.status != JobStatus.COMPLETED:
        # Report the plain status value rather than the enum repr, matching
        # how the progress endpoint serialises the status.
        status = job.status.value if isinstance(job.status, JobStatus) else job.status
        raise HTTPException(
            status_code=400,
            detail=f"Job not completed. Current status: {status}",
        )

    job_output_dir = RESULTS_DIR / job_id
    # A path that exists but is not a directory holds no results either.
    if not job_output_dir.is_dir():
        raise HTTPException(status_code=404, detail="Results not found on disk")

    # NOTE: the whole archive is built in memory; very large result sets
    # would be better served from a temporary file.
    archive = _build_results_archive(job_output_dir)

    return Response(
        content=archive,
        media_type="application/gzip",
        headers={
            "Content-Disposition": f"attachment; filename=bttoxin_{job_id}.tar.gz"
        },
    )
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
async def delete_job(job_id: str, db: Session = Depends(get_db)):
    """Delete a job together with its uploaded inputs and generated results.

    Args:
        job_id: Identifier of the job to remove.
        db: Database session injected by FastAPI.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 404 if no job with this id exists.
    """
    record = db.query(Job).filter(Job.id == job_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")

    # Remove the on-disk artefacts first (upload dir, then results dir) ...
    job_dirs = (Path(settings.UPLOAD_DIR) / job_id, RESULTS_DIR / job_id)
    for directory in job_dirs:
        if directory.exists():
            shutil.rmtree(directory)

    # ... then drop the database row.
    db.delete(record)
    db.commit()

    confirmation = f"Job {job_id} deleted successfully"
    return {"message": confirmation}
|
||||
|
||||
70
backend/app/api/v1/tasks.py
Normal file
70
backend/app/api/v1/tasks.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""任务管理 API - 兼容 /api/v1/tasks 路径"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...database import get_db
|
||||
from ...models.job import Job, JobStatus
|
||||
from ...schemas.job import JobResponse
|
||||
from ...config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class TaskCreateRequest(BaseModel):
    """Request body for creating a task."""
    files: List[str]  # list of file names
    # NOTE(review): the fields below share names and defaults with the form
    # parameters of POST /jobs/create — presumably the same meaning; confirm.
    sequence_type: str = "nucl"
    min_identity: float = 0.8
    min_coverage: float = 0.6
    allow_unknown_families: bool = False
    require_index_hit: bool = True
|
||||
|
||||
|
||||
class QueuePosition(BaseModel):
    """Queue position information for a waiting task."""
    position: int  # get_queue_position fills this with (jobs ahead + 1)
    # NOTE(review): annotated as int but defaults to None; Optional[int]
    # looks intended (Optional is not imported in this module) — confirm.
    estimated_wait_minutes: int = None
|
||||
|
||||
|
||||
@router.post("/", response_model=JobResponse)
async def create_task(request: TaskCreateRequest, db: Session = Depends(get_db)):
    """Create a new task (frontend-compatible path); currently a stub.

    Always responds with 501 and points the caller at the jobs endpoint.

    Args:
        request: Parsed task creation payload (not used yet).
        db: Database session injected by FastAPI (not used yet).

    Raises:
        HTTPException: 501 on every call.
    """
    # Placeholder until this route reuses the jobs logic.
    # TODO: implement full file upload and processing.
    not_implemented = HTTPException(
        status_code=501,
        detail="Use POST /api/v1/jobs/create for now",
    )
    raise not_implemented
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=JobResponse)
async def get_task(task_id: str, db: Session = Depends(get_db)):
    """Return the job record for a task.

    Args:
        task_id: Identifier of the task; matched against ``Job.id``.
        db: Database session injected by FastAPI.

    Returns:
        The matching ``Job`` row, serialised through ``JobResponse``.

    Raises:
        HTTPException: 404 if no job with this id exists.
    """
    lookup = db.query(Job).filter(Job.id == task_id)
    found = lookup.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
@router.get("/{task_id}/queue")
async def get_queue_position(task_id: str, db: Session = Depends(get_db)):
    """Report where a task sits in the waiting queue.

    Args:
        task_id: Identifier of the task; matched against ``Job.id``.
        db: Database session injected by FastAPI.

    Returns:
        A dict with position 0 and a message when the task is not waiting
        (status is neither PENDING nor QUEUED); otherwise a ``QueuePosition``
        with the 1-based position and an estimated wait in minutes.

    Raises:
        HTTPException: 404 if no job with this id exists.
    """
    current = db.query(Job).filter(Job.id == task_id).first()
    if not current:
        raise HTTPException(status_code=404, detail="Task not found")

    waiting_states = [JobStatus.PENDING, JobStatus.QUEUED]
    if current.status not in waiting_states:
        return {"position": 0, "message": "Task is not in queue"}

    # Count the other waiting jobs that were created before this one.
    earlier_waiting = db.query(Job).filter(
        Job.id != task_id,
        Job.status.in_(waiting_states),
        Job.created_at < current.created_at,
    )
    place = earlier_waiting.count() + 1

    # Rough estimate: assume about five minutes per task.
    return QueuePosition(position=place, estimated_wait_minutes=place * 5)
|
||||
Reference in New Issue
Block a user