Files
bttoxin-pipeline/backend/app/workers/tasks.py
zly 963215de2d Fix(pipeline): prevent nested zip packaging and update CRISPR dependencies
- Add filter to skip .zip and .tar.gz files when creating result archive
- Update CRISPR feature with CASFinder dependencies (hmmer, blast, vmatch, etc.)
- Add install-casfinder task for macsydata installation
- Remove obsolete CRISPR test files

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-28 20:06:41 +08:00

256 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Celery 任务 - 支持并发控制和多阶段 pipeline"""
from celery import Task
from pathlib import Path
import shutil
import logging
import asyncio
import subprocess
import os
import zipfile
import json
from ..core.celery_app import celery_app
from ..core.tool_runner import ToolRunner
from ..database import SessionLocal
from ..models.job import Job, JobStatus
from ..services.concurrency_manager import get_concurrency_manager
logger = logging.getLogger(__name__)
def run_local_command(cmd: list, cwd: Path = None, env: dict = None) -> dict:
    """Run a command locally in the container.

    Args:
        cmd: Command and arguments as a list (executed without a shell).
        cwd: Working directory for the child process, or None to inherit.
        env: Environment mapping for the child; defaults to a copy of the
            current process environment.

    Returns:
        dict with keys 'success' (bool), 'stdout' (str), 'stderr' (str) and
        'exit_code' (int). When the command cannot even be launched (e.g.
        executable not found), an additional 'error' key holds the exception
        text; 'exit_code' is -1 and 'stderr' carries the same text so that
        callers formatting res['exit_code'] / res['stderr'] on failure never
        hit a KeyError.
    """
    try:
        logger.info(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(
            cmd,
            cwd=cwd,
            env=env or os.environ.copy(),
            capture_output=True,
            text=True,
            check=False
        )
        return {
            'success': result.returncode == 0,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'exit_code': result.returncode
        }
    except Exception as e:
        logger.error(f"Command failed: {e}")
        # Bug fix: mirror the success-path key set so downstream code that
        # reads res['exit_code'] / res['stderr'] (see run_bttoxin_analysis)
        # does not raise KeyError and mask the original failure.
        return {
            'success': False,
            'error': str(e),
            'stdout': '',
            'stderr': str(e),
            'exit_code': -1
        }
@celery_app.task(bind=True, max_retries=3, name="backend.app.workers.tasks.run_bttoxin_analysis")
def run_bttoxin_analysis(
    self,
    job_id: str,
    input_dir: str,
    output_dir: str,
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    min_identity: float = 0.8,
    min_coverage: float = 0.6,
    allow_unknown_families: bool = False,
    require_index_hit: bool = True,
    crispr_fusion: bool = False,
    crispr_weight: float = 0.0,
    lang: str = "zh"
):
    """Run the analysis job via the unified scripts/run_single_fna_pipeline.py script.

    Locates the single uploaded input file in ``input_dir``, executes the
    pipeline script through ``pixi run -e pipeline``, packages the results
    into a zip bundle under ``output_dir``, cleans up intermediate
    directories, and updates the Job row's status/progress as it goes.

    Args:
        job_id: Primary key of the Job row to track.
        input_dir: Upload directory expected to contain exactly one input file.
        output_dir: Directory where the pipeline writes its results and where
            the final zip bundle is placed.
        sequence_type: Accepted for API compatibility; not referenced in this
            body (TODO confirm whether the pipeline script should receive it).
        scaf_suffix: Filename suffix used to find the input file (e.g. ".fna").
        threads: Thread count forwarded to the pipeline script.
        min_identity: Minimum identity threshold forwarded to the script.
        min_coverage: Minimum coverage threshold forwarded to the script.
        allow_unknown_families: When False, passes --disallow_unknown_families.
        require_index_hit: When True, passes --require_index_hit.
        crispr_fusion: Accepted but not used in this body — TODO confirm.
        crispr_weight: Accepted but not used in this body — TODO confirm.
        lang: Language code forwarded to the script (affects report language).

    Returns:
        dict with 'job_id', 'status' and 'output_dir' on success, or an error
        dict when the Job row is missing.

    Raises:
        Exception: Re-raised after marking the job FAILED when the pipeline
            script fails or the result zip cannot be produced.
    """
    db = SessionLocal()
    job = None
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}
        # Mark the job as RUNNING before kicking off the pipeline
        job.status = JobStatus.RUNNING
        job.current_stage = "running"
        job.progress_percent = 0
        db.commit()
        # Prepare filesystem paths
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        # Locate the input file (input_dir is the upload directory and should hold a single file)
        input_files = list(input_path.glob(f"*{scaf_suffix}"))
        if not input_files:
            # Fall back to any regular file in the directory
            files = [f for f in input_path.iterdir() if f.is_file()]
            if files:
                input_file = files[0]
            else:
                raise FileNotFoundError(f"No input file found in {input_dir}")
        else:
            input_file = input_files[0]
        logger.info(f"Job {job_id}: Starting pipeline for {input_file}")
        # Build the pipeline command.
        # Run through `pixi run -e pipeline` so the execution environment is consistent.
        pipeline_cmd = [
            "pixi", "run", "-e", "pipeline", "python", "scripts/run_single_fna_pipeline.py",
            "--input", str(input_file),
            "--out_root", str(output_path),
            "--toxicity_csv", "Data/toxicity-data.csv",
            "--min_identity", str(min_identity),
            "--min_coverage", str(min_coverage),
            "--threads", str(threads),
            "--lang", lang
        ]
        if not allow_unknown_families:
            pipeline_cmd.append("--disallow_unknown_families")
        if require_index_hit:
            pipeline_cmd.append("--require_index_hit")
        # Execute the script
        res = run_local_command(pipeline_cmd, cwd=Path("/app"))
        if not res['success']:
            error_msg = f"Pipeline execution failed (exit={res['exit_code']}): {res['stderr']}"
            logger.error(error_msg)
            raise Exception(error_msg)
        logger.info(f"Job {job_id}: Pipeline script completed")
        # Package the results into a zip bundle
        logger.info(f"Job {job_id}: Creating zip bundle")
        zip_path = output_path / f"pipeline_results_{job_id}.zip"
        # Before creating the new ZIP, delete any existing zip/tar.gz files in the
        # directory to prevent recursively packaging old archives
        for existing_archive in output_path.glob("*.zip"):
            try:
                existing_archive.unlink()
            except Exception:
                pass
        for existing_archive in output_path.glob("*.tar.gz"):
            try:
                existing_archive.unlink()
            except Exception:
                pass
        # Mapping: original directory name -> display name inside the archive
        dir_mapping = {
            "digger": "1_Toxin_Mining",
            "shotter": "2_Toxicity_Scoring",
            "logs": "Logs"
        }
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # 1. Add the input file (under an Input/ directory)
            zipf.write(input_file, arcname=f"Input/{input_file.name}")
            # 2. Add the result directories (renamed per dir_mapping)
            for src_name, dest_name in dir_mapping.items():
                src_path = output_path / src_name
                if src_path.exists():
                    for root, dirs, files in os.walk(src_path):
                        for file in files:
                            file_path = Path(root) / file
                            # Exclude the archive itself (defensive; should not occur)
                            if file_path == zip_path:
                                continue
                            # Prevent nested packaging: skip any stray zip or tar.gz files
                            if file.endswith('.zip') or file.endswith('.tar.gz'):
                                continue
                            # Compute path relative to the source dir,
                            # e.g. digger/Results/foo.txt -> Results/foo.txt
                            rel_path = file_path.relative_to(src_path)
                            # Build the archive path, e.g. 1_Toxin_Mining/Results/foo.txt
                            arcname = Path(dest_name) / rel_path
                            zipf.write(file_path, arcname=str(arcname))
        # Remove the original result directories
        logger.info(f"Job {job_id}: Cleaning up intermediate files")
        # Raw directory names to clean up
        dirs_to_clean = ["digger", "shotter", "context", "logs", "stage"]
        for d in dirs_to_clean:
            d_path = output_path / d
            if d_path.exists():
                shutil.rmtree(d_path)
        # Remove the tar.gz if the pipeline script produced one
        tar_gz = output_path / "pipeline_results.tar.gz"
        if tar_gz.exists():
            tar_gz.unlink()
        # Verify the zip was actually created
        if not zip_path.exists():
            raise Exception("Failed to create result zip file")
        # NOTE(review): the download API is expected to look for the result file
        # under output_dir. The zip is named pipeline_results_{job_id}.zip here,
        # matching the frontend-generated download link; confirm that
        # backend/app/api/v1/results.py resolves this name (and not the legacy
        # pipeline_results.tar.gz) — TODO verify against the results API.
        # Mark the job completed
        job.status = JobStatus.COMPLETED
        job.progress_percent = 100
        job.current_stage = "completed"
        # Record a tail-of-stdout log summary (last 2000 chars)
        job.logs = res['stdout'][-2000:] if res['stdout'] else "No output"
        db.commit()
        logger.info(f"Job {job_id}: Completed successfully")
        return {
            'job_id': job_id,
            'status': 'completed',
            'output_dir': str(output_dir)
        }
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}")
        if job:
            try:
                job.status = JobStatus.FAILED
                job.error_message = str(e)
                job.current_stage = "failed"
                db.commit()
            except Exception as commit_error:
                logger.error(f"Failed to update job status to FAILED: {commit_error}")
        raise
    finally:
        db.close()
@celery_app.task
def update_queue_positions():
    """Recompute the queue_position of every queued job.

    Orders all jobs currently in the QUEUED state by creation time and
    writes their 1-based position back to the row. Intended to be invoked
    periodically, e.g. from Celery Beat.
    """
    session = SessionLocal()
    try:
        # All QUEUED jobs, oldest first
        pending = (
            session.query(Job)
            .filter(Job.status == JobStatus.QUEUED)
            .order_by(Job.created_at)
            .all()
        )
        for position, queued_job in enumerate(pending, start=1):
            queued_job.queue_position = position
        session.commit()
        logger.info(f"Updated queue positions for {len(pending)} jobs")
    except Exception as e:
        logger.error(f"Failed to update queue positions: {e}")
        session.rollback()
    finally:
        session.close()