# Changelog (from commit message; pasted UI residue converted to a comment):
# - Add filter to skip .zip and .tar.gz files when creating result archive
# - Update CRISPR feature with CASFinder dependencies (hmmer, blast, vmatch, etc.)
# - Add install-casfinder task for macsydata installation
# - Remove obsolete CRISPR test files
"""Celery 任务 - 支持并发控制和多阶段 pipeline"""
|
||
from celery import Task
|
||
from pathlib import Path
|
||
import shutil
|
||
import logging
|
||
import asyncio
|
||
import subprocess
|
||
import os
|
||
import zipfile
|
||
import json
|
||
|
||
from ..core.celery_app import celery_app
|
||
from ..core.tool_runner import ToolRunner
|
||
from ..database import SessionLocal
|
||
from ..models.job import Job, JobStatus
|
||
from ..services.concurrency_manager import get_concurrency_manager
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def run_local_command(cmd: list, cwd: Path = None, env: dict = None) -> dict:
|
||
"""Run a command locally in the container"""
|
||
try:
|
||
logger.info(f"Running command: {' '.join(cmd)}")
|
||
result = subprocess.run(
|
||
cmd,
|
||
cwd=cwd,
|
||
env=env or os.environ.copy(),
|
||
capture_output=True,
|
||
text=True,
|
||
check=False
|
||
)
|
||
return {
|
||
'success': result.returncode == 0,
|
||
'stdout': result.stdout,
|
||
'stderr': result.stderr,
|
||
'exit_code': result.returncode
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"Command failed: {e}")
|
||
return {'success': False, 'error': str(e)}
|
||
|
||
|
||
@celery_app.task(bind=True, max_retries=3, name="backend.app.workers.tasks.run_bttoxin_analysis")
|
||
def run_bttoxin_analysis(
|
||
self,
|
||
job_id: str,
|
||
input_dir: str,
|
||
output_dir: str,
|
||
sequence_type: str = "nucl",
|
||
scaf_suffix: str = ".fna",
|
||
threads: int = 4,
|
||
min_identity: float = 0.8,
|
||
min_coverage: float = 0.6,
|
||
allow_unknown_families: bool = False,
|
||
require_index_hit: bool = True,
|
||
crispr_fusion: bool = False,
|
||
crispr_weight: float = 0.0,
|
||
lang: str = "zh"
|
||
):
|
||
"""
|
||
执行分析任务 - 使用 scripts/run_single_fna_pipeline.py 统一脚本
|
||
"""
|
||
db = SessionLocal()
|
||
job = None
|
||
|
||
try:
|
||
job = db.query(Job).filter(Job.id == job_id).first()
|
||
if not job:
|
||
logger.error(f"Job {job_id} not found")
|
||
return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}
|
||
|
||
# 更新状态为 RUNNING
|
||
job.status = JobStatus.RUNNING
|
||
job.current_stage = "running"
|
||
job.progress_percent = 0
|
||
db.commit()
|
||
|
||
# 准备路径
|
||
input_path = Path(input_dir)
|
||
output_path = Path(output_dir)
|
||
|
||
# 查找输入文件 (由于 input_dir 是上传目录,里面应该有一个文件)
|
||
input_files = list(input_path.glob(f"*{scaf_suffix}"))
|
||
if not input_files:
|
||
# 尝试查找任意文件
|
||
files = [f for f in input_path.iterdir() if f.is_file()]
|
||
if files:
|
||
input_file = files[0]
|
||
else:
|
||
raise FileNotFoundError(f"No input file found in {input_dir}")
|
||
else:
|
||
input_file = input_files[0]
|
||
|
||
logger.info(f"Job {job_id}: Starting pipeline for {input_file}")
|
||
|
||
# 构建 pipeline 命令
|
||
# 使用 pixi run -e pipeline 来执行脚本,确保环境一致
|
||
pipeline_cmd = [
|
||
"pixi", "run", "-e", "pipeline", "python", "scripts/run_single_fna_pipeline.py",
|
||
"--input", str(input_file),
|
||
"--out_root", str(output_path),
|
||
"--toxicity_csv", "Data/toxicity-data.csv",
|
||
"--min_identity", str(min_identity),
|
||
"--min_coverage", str(min_coverage),
|
||
"--threads", str(threads),
|
||
"--lang", lang
|
||
]
|
||
|
||
if not allow_unknown_families:
|
||
pipeline_cmd.append("--disallow_unknown_families")
|
||
|
||
if require_index_hit:
|
||
pipeline_cmd.append("--require_index_hit")
|
||
|
||
# 执行脚本
|
||
res = run_local_command(pipeline_cmd, cwd=Path("/app"))
|
||
|
||
if not res['success']:
|
||
error_msg = f"Pipeline execution failed (exit={res['exit_code']}): {res['stderr']}"
|
||
logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
|
||
logger.info(f"Job {job_id}: Pipeline script completed")
|
||
|
||
# 结果打包 (Zip)
|
||
logger.info(f"Job {job_id}: Creating zip bundle")
|
||
zip_path = output_path / f"pipeline_results_{job_id}.zip"
|
||
|
||
# 在创建新 ZIP 前,删除目录下任何现有的 zip/tar.gz 文件,防止递归打包
|
||
for existing_archive in output_path.glob("*.zip"):
|
||
try:
|
||
existing_archive.unlink()
|
||
except Exception:
|
||
pass
|
||
for existing_archive in output_path.glob("*.tar.gz"):
|
||
try:
|
||
existing_archive.unlink()
|
||
except Exception:
|
||
pass
|
||
|
||
# 定义映射关系:原始目录 -> 压缩包内展示名称
|
||
dir_mapping = {
|
||
"digger": "1_Toxin_Mining",
|
||
"shotter": "2_Toxicity_Scoring",
|
||
"logs": "Logs"
|
||
}
|
||
|
||
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||
# 1. 添加输入文件 (放入 Input 目录)
|
||
zipf.write(input_file, arcname=f"Input/{input_file.name}")
|
||
|
||
# 2. 添加结果目录 (重命名)
|
||
for src_name, dest_name in dir_mapping.items():
|
||
src_path = output_path / src_name
|
||
if src_path.exists():
|
||
for root, dirs, files in os.walk(src_path):
|
||
for file in files:
|
||
file_path = Path(root) / file
|
||
# 排除压缩包自己(如果有意外情况)
|
||
if file_path == zip_path:
|
||
continue
|
||
|
||
# 防止嵌套打包:忽略可能存在的 zip 或 tar.gz 文件
|
||
if file.endswith('.zip') or file.endswith('.tar.gz'):
|
||
continue
|
||
|
||
# 计算相对路径,例如 digger/Results/foo.txt -> Results/foo.txt
|
||
rel_path = file_path.relative_to(src_path)
|
||
# 构造新的归档路径 -> 1_Toxin_Mining/Results/foo.txt
|
||
arcname = Path(dest_name) / rel_path
|
||
zipf.write(file_path, arcname=str(arcname))
|
||
|
||
# 删除原始结果目录
|
||
logger.info(f"Job {job_id}: Cleaning up intermediate files")
|
||
# 需要清理的原始目录名
|
||
dirs_to_clean = ["digger", "shotter", "context", "logs", "stage"]
|
||
for d in dirs_to_clean:
|
||
d_path = output_path / d
|
||
if d_path.exists():
|
||
shutil.rmtree(d_path)
|
||
|
||
# 删除 tar.gz (如果脚本生成了)
|
||
tar_gz = output_path / "pipeline_results.tar.gz"
|
||
if tar_gz.exists():
|
||
tar_gz.unlink()
|
||
|
||
# 验证 Zip 是否生成
|
||
if not zip_path.exists():
|
||
raise Exception("Failed to create result zip file")
|
||
|
||
# 重命名为标准下载名 (或者保持这样,由 API 决定下载名)
|
||
# 这里的 output_dir 就是 API 下载时寻找的地方
|
||
# downloadResult API 默认可能找 pipeline_results.tar.gz?
|
||
# 我们需要确保 frontend 下载链接改为 zip,并且后端 API 能找到这个文件
|
||
# 目前后端 API (backend/app/api/v1/results.py) 可能需要调整,或者我们把 zip 命名为 API 期望的名字?
|
||
# 假设 API 期望 output_dir 下有文件。
|
||
# 为了兼容,我们把 zip 命名为 pipeline_results.zip (通用)
|
||
# 但前端生成的下载链接是 pipeline_results_{id}.zip
|
||
|
||
# 完成
|
||
job.status = JobStatus.COMPLETED
|
||
job.progress_percent = 100
|
||
job.current_stage = "completed"
|
||
# 记录日志摘要
|
||
job.logs = res['stdout'][-2000:] if res['stdout'] else "No output"
|
||
db.commit()
|
||
|
||
logger.info(f"Job {job_id}: Completed successfully")
|
||
|
||
return {
|
||
'job_id': job_id,
|
||
'status': 'completed',
|
||
'output_dir': str(output_dir)
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Job {job_id} failed: {e}")
|
||
if job:
|
||
try:
|
||
job.status = JobStatus.FAILED
|
||
job.error_message = str(e)
|
||
job.current_stage = "failed"
|
||
db.commit()
|
||
except Exception as commit_error:
|
||
logger.error(f"Failed to update job status to FAILED: {commit_error}")
|
||
raise
|
||
|
||
finally:
|
||
db.close()
|
||
|
||
|
||
@celery_app.task
|
||
def update_queue_positions():
|
||
"""
|
||
定期更新排队任务的位置
|
||
可以通过 Celery Beat 定期调用
|
||
"""
|
||
db = SessionLocal()
|
||
try:
|
||
# 获取所有 QUEUED 状态的任务
|
||
queued_jobs = db.query(Job).filter(
|
||
Job.status == JobStatus.QUEUED
|
||
).order_by(Job.created_at).all()
|
||
|
||
for idx, job in enumerate(queued_jobs, start=1):
|
||
job.queue_position = idx
|
||
|
||
db.commit()
|
||
logger.info(f"Updated queue positions for {len(queued_jobs)} jobs")
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to update queue positions: {e}")
|
||
db.rollback()
|
||
finally:
|
||
db.close()
|