Files
bttoxin-pipeline/backend/app/workers/tasks.py

318 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Celery 任务 - 支持并发控制和多阶段 pipeline"""
from celery import Task
from pathlib import Path
import shutil
import logging
import asyncio
import subprocess
import os
from ..core.celery_app import celery_app
from ..core.docker_client import DockerManager
from ..database import SessionLocal
from ..models.job import Job, JobStatus
from ..services.concurrency_manager import get_concurrency_manager
logger = logging.getLogger(__name__)
# Pipeline 阶段定义
PIPELINE_STAGES = ["digger", "crispr", "shoter", "plots", "bundle"]
def run_local_command(cmd: list, cwd: Path = None, env: dict = None) -> dict:
    """Run a command locally in the container (no shell involved).

    Args:
        cmd: Argument vector passed straight to ``subprocess.run``.
        cwd: Working directory for the child process, or None to inherit.
        env: Environment mapping; defaults to a copy of the current environment.

    Returns:
        dict with keys ``success``, ``stdout``, ``stderr`` and ``exit_code``.
        If the process could not even be launched, an ``error`` key holds the
        exception message and the other keys carry neutral defaults so callers
        always see the same shape.
    """
    # Same instance as the module-level logger; bound here so the function is
    # self-contained.
    logger = logging.getLogger(__name__)
    try:
        logger.info(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(
            cmd,
            cwd=cwd,
            env=env or os.environ.copy(),
            capture_output=True,
            text=True,
            check=False,
        )
        return {
            'success': result.returncode == 0,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'exit_code': result.returncode,
        }
    except Exception as e:
        logger.error(f"Command failed: {e}")
        # Bug fix: keep the return shape consistent with the success path so
        # callers using res['stdout'] / res.get('stderr') never hit a missing
        # key; 'error' is additionally set for diagnostics.
        return {
            'success': False,
            'stdout': '',
            'stderr': '',
            'exit_code': -1,
            'error': str(e),
        }
@celery_app.task(bind=True, max_retries=3)
def run_bttoxin_analysis(
    self,
    job_id: str,
    input_dir: str,
    output_dir: str,
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    min_identity: float = 0.8,
    min_coverage: float = 0.6,
    allow_unknown_families: bool = False,
    require_index_hit: bool = True,
    crispr_fusion: bool = False,
    crispr_weight: float = 0.0,
):
    """Execute the full analysis pipeline for one job.

    Stages:
        1.   digger - BtToxin_Digger toxin-gene identification (runs in Docker)
        1.5  crispr - optional CRISPR-Cas detection and toxin/CRISPR fusion
        2.   shoter - BtToxin_Shoter toxicity-activity assessment
        3.   plots  - heatmap generation
        4.   bundle - manifest creation / result packaging

    Progress and the current stage are mirrored into both the Job row and the
    Celery task state (``PROGRESS`` updates) so either one can be polled.

    Returns:
        dict with 'job_id', 'status', 'stages' (the stages that actually
        completed) and 'output_dir'.

    Raises:
        Exception: re-raised after marking the Job FAILED so that Celery's
        retry policy (max_retries=3) can apply.
    """
    import json  # local import kept deliberately; module import block untouched

    db = SessionLocal()
    job = None  # lets the except-handler know whether the row was ever loaded
    # Bug fix: track the stages that actually ran instead of hard-coding
    # ["digger"] into the manifest / logs / return value.
    completed_stages = []
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}

        # Mark as queued, then immediately as running.
        # NOTE(review): execution-slot acquisition via the concurrency manager
        # is not wired in yet (Celery workers are synchronous); the task runs
        # right away. TODO: integrate get_concurrency_manager().
        job.status = JobStatus.QUEUED
        db.commit()
        job.status = JobStatus.RUNNING
        job.current_stage = "digger"
        job.progress_percent = 0
        db.commit()

        # ---- Stage 1: Digger - identify toxin genes -------------------------
        logger.info(f"Job {job_id}: Starting Digger stage")
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'digger', 'progress': 10, 'status': 'Running BtToxin_Digger...'}
        )
        docker_manager = DockerManager()
        digger_result = docker_manager.run_bttoxin_digger(
            input_dir=Path(input_dir),
            output_dir=Path(output_dir),
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            threads=threads
        )
        if not digger_result['success']:
            raise Exception(f"Digger stage failed: {digger_result.get('error', 'Unknown error')}")
        completed_stages.append("digger")
        job.progress_percent = 40
        db.commit()

        # ---- Stage 1.5: CRISPR-Cas (optional) -------------------------------
        crispr_results_file = None
        input_files = []
        if crispr_fusion:
            logger.info(f"Job {job_id}: Starting CRISPR stage")
            job.current_stage = "crispr"
            db.commit()
            self.update_state(
                state='PROGRESS',
                meta={'stage': 'crispr', 'progress': 45, 'status': 'Running CRISPR Detection...'}
            )
            crispr_out = Path(output_dir) / "crispr" / "results.json"
            crispr_out.parent.mkdir(parents=True, exist_ok=True)
            # 1. Detection
            detect_cmd = [
                "pixi", "run", "-e", "crispr", "python", "crispr_cas/scripts/detect_crispr.py",
                "--input", str(Path(input_dir) / f"{job_id}{scaf_suffix}"),  # placeholder; replaced below
                "--output", str(crispr_out),
                "--mock"  # the real detector is not installed yet
            ]
            # The upload may be named differently; use the first file matching
            # the scaffold suffix instead of assuming "<job_id><suffix>".
            input_files = list(Path(input_dir).glob(f"*{scaf_suffix}"))
            if input_files:
                detect_cmd[7] = str(input_files[0])  # index 7 = the --input value
            res = run_local_command(detect_cmd, cwd=Path("/app"))
            if not res['success']:
                logger.warning(f"CRISPR detection failed: {res.get('stderr')}")
            else:
                crispr_results_file = crispr_out
                completed_stages.append("crispr")
            # 2. Fusion analysis, combining CRISPR hits with Digger toxin calls.
            # Digger output is assumed at Results/Toxins/All_Toxins.txt — TODO
            # confirm against DockerManager's actual output layout.
            fusion_out = Path(output_dir) / "crispr" / "fusion_analysis.json"
            toxins_file = Path(output_dir) / "Results" / "Toxins" / "All_Toxins.txt"
            # Bug fix: also require a genome file — the original indexed
            # input_files[0] unconditionally and could raise IndexError.
            if toxins_file.exists() and input_files:
                fusion_cmd = [
                    "pixi", "run", "-e", "crispr", "python", "crispr_cas/scripts/fusion_analysis.py",
                    "--crispr-results", str(crispr_out),
                    "--toxin-results", str(toxins_file),
                    "--genome", str(input_files[0]),
                    "--output", str(fusion_out),
                    "--mock"
                ]
                run_local_command(fusion_cmd, cwd=Path("/app"))

        # ---- Stage 2: Shoter - assess toxicity activity ---------------------
        logger.info(f"Job {job_id}: Starting Shoter stage")
        job.current_stage = "shoter"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'shoter', 'progress': 50, 'status': 'Running BtToxin_Shoter...'}
        )
        # Digger output is expected at output_dir/Results/Toxins/All_Toxins.txt;
        # the Shoter script tolerates a missing/empty file, so we run it even
        # when Digger produced no hits.
        toxins_file = Path(output_dir) / "Results" / "Toxins" / "All_Toxins.txt"
        shoter_out_dir = Path(output_dir) / "shoter"
        shoter_cmd = [
            "pixi", "run", "-e", "pipeline", "python", "scripts/bttoxin_shoter.py",
            "--all_toxins", str(toxins_file),
            "--output_dir", str(shoter_out_dir),
            "--min_identity", str(min_identity),
            "--min_coverage", str(min_coverage)
        ]
        if allow_unknown_families:
            shoter_cmd.append("--allow_unknown_families")
        if require_index_hit:
            shoter_cmd.append("--require_index_hit")
        # CRISPR integration: only pass CRISPR flags when detection succeeded,
        # so --crispr_fusion never appears without --crispr_results.
        if crispr_results_file:
            shoter_cmd.extend(["--crispr_results", str(crispr_results_file)])
            shoter_cmd.extend(["--crispr_weight", str(crispr_weight)])
            if crispr_fusion:
                shoter_cmd.append("--crispr_fusion")
        shoter_res = run_local_command(shoter_cmd, cwd=Path("/app"))
        if shoter_res['success']:
            completed_stages.append("shoter")
        job.progress_percent = 70
        db.commit()

        # ---- Stage 3: Plots - generate heatmaps -----------------------------
        logger.info(f"Job {job_id}: Starting Plots stage")
        job.current_stage = "plots"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'plots', 'progress': 80, 'status': 'Generating plots...'}
        )
        plot_cmd = [
            "pixi", "run", "-e", "pipeline", "python", "scripts/plot_shotter.py",
            "--strain_scores", str(shoter_out_dir / "strain_target_scores.tsv"),
            "--toxin_support", str(shoter_out_dir / "toxin_support.tsv"),
            "--species_scores", str(shoter_out_dir / "strain_target_species_scores.tsv"),
            "--out_dir", str(shoter_out_dir),
            "--output_prefix", "Activity_Heatmap"
        ]
        if crispr_results_file:
            plot_cmd.extend(["--crispr_results", str(crispr_results_file)])
            if crispr_fusion:
                plot_cmd.append("--crispr_fusion")
        plot_res = run_local_command(plot_cmd, cwd=Path("/app"))
        if plot_res['success']:
            completed_stages.append("plots")
        job.progress_percent = 90
        db.commit()

        # ---- Stage 4: Bundle - write manifest / package results -------------
        logger.info(f"Job {job_id}: Starting Bundle stage")
        job.current_stage = "bundle"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'bundle', 'progress': 95, 'status': 'Bundling results...'}
        )
        completed_stages.append("bundle")
        manifest = {
            "job_id": job_id,
            # Bug fix: report what actually ran, not the hard-coded
            # ["digger"] / everything-else-skipped placeholder.
            "stages_completed": completed_stages,
            "stages_skipped": [s for s in PIPELINE_STAGES if s not in completed_stages],
            "output_files": [str(p) for p in Path(output_dir).rglob("*")],
            "parameters": {
                "sequence_type": sequence_type,
                "min_identity": min_identity,
                "min_coverage": min_coverage,
                "allow_unknown_families": allow_unknown_families,
                "require_index_hit": require_index_hit,
                "crispr_fusion": crispr_fusion,
                "crispr_weight": crispr_weight,
            }
        }
        manifest_path = Path(output_dir) / "manifest.json"
        with open(manifest_path, "w") as f:
            json.dump(manifest, f, indent=2, default=str)

        # Done.
        job.status = JobStatus.COMPLETED
        job.progress_percent = 100
        job.current_stage = "completed"
        job.logs = json.dumps({"stages": completed_stages, "output": str(output_dir)})
        db.commit()
        logger.info(f"Job {job_id}: Completed successfully")
        return {
            'job_id': job_id,
            'status': 'completed',
            'stages': completed_stages,
            'output_dir': str(output_dir)
        }
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}")
        # Bug fix: `job` is None/unbound if the failure happened before the
        # row was loaded (e.g. the query itself raised); guard before use.
        if job is not None:
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.current_stage = "failed"
            db.commit()
        raise
    finally:
        db.close()
@celery_app.task
def update_queue_positions():
    """Recompute the queue position of every QUEUED job.

    Jobs are ordered by creation time and numbered from 1; intended to be
    triggered periodically (e.g. via Celery Beat).
    """
    session = SessionLocal()
    try:
        # All queued jobs, oldest first.
        queued_jobs = (
            session.query(Job)
            .filter(Job.status == JobStatus.QUEUED)
            .order_by(Job.created_at)
            .all()
        )
        position = 0
        for queued_job in queued_jobs:
            position += 1
            queued_job.queue_position = position
        session.commit()
        logger.info(f"Updated queue positions for {len(queued_jobs)} jobs")
    except Exception as e:
        logger.error(f"Failed to update queue positions: {e}")
        session.rollback()
    finally:
        session.close()