318 lines
11 KiB
Python
318 lines
11 KiB
Python
"""Celery 任务 - 支持并发控制和多阶段 pipeline"""
|
||
from celery import Task
|
||
from pathlib import Path
|
||
import shutil
|
||
import logging
|
||
import asyncio
|
||
import subprocess
|
||
import os
|
||
|
||
from ..core.celery_app import celery_app
|
||
from ..core.docker_client import DockerManager
|
||
from ..database import SessionLocal
|
||
from ..models.job import Job, JobStatus
|
||
from ..services.concurrency_manager import get_concurrency_manager
|
||
|
||
logger = logging.getLogger(__name__)

# Pipeline stage definitions (names used for job.current_stage / manifest)
PIPELINE_STAGES = ["digger", "crispr", "shoter", "plots", "bundle"]
|
||
|
||
|
||
def run_local_command(cmd: list, cwd: Path = None, env: dict = None) -> dict:
    """Run a command locally in the container and capture its output.

    Args:
        cmd: Argument vector (no shell is invoked), e.g. ["pixi", "run", ...].
        cwd: Working directory for the child process; None inherits ours.
        env: Environment mapping; defaults to a copy of the current environment.

    Returns:
        dict with keys 'success', 'stdout', 'stderr', 'exit_code'.  If the
        process could not be spawned at all (missing binary, permissions, ...)
        an 'error' key holds the exception text and the other keys are still
        present so callers can index the result uniformly.
    """
    log = logging.getLogger(__name__)
    try:
        log.info(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(
            cmd,
            cwd=cwd,
            env=env or os.environ.copy(),
            capture_output=True,
            text=True,
            check=False,
        )
        return {
            'success': result.returncode == 0,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'exit_code': result.returncode,
        }
    except Exception as e:
        # Spawn failure: return the same dict shape (previously stdout/stderr/
        # exit_code were missing here) instead of raising, so the pipeline can
        # degrade gracefully.
        log.error(f"Command failed: {e}")
        return {
            'success': False,
            'error': str(e),
            'stdout': '',
            'stderr': str(e),
            'exit_code': None,
        }
|
||
|
||
|
||
@celery_app.task(bind=True, max_retries=3)
def run_bttoxin_analysis(
    self,
    job_id: str,
    input_dir: str,
    output_dir: str,
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    min_identity: float = 0.8,
    min_coverage: float = 0.6,
    allow_unknown_families: bool = False,
    require_index_hit: bool = True,
    crispr_fusion: bool = False,
    crispr_weight: float = 0.0,
):
    """Execute the full multi-stage analysis pipeline for one job.

    Stages:
        1.   digger - BtToxin_Digger identifies toxin genes
        1.5. crispr - optional CRISPR-Cas detection + fusion analysis
        2.   shoter - BtToxin_Shoter scores toxicity activity
        3.   plots  - generates activity heatmaps
        4.   bundle - writes manifest.json and finalizes the job

    Args:
        job_id: Primary key of the Job row to process.
        input_dir / output_dir: Filesystem paths (strings, since Celery
            serializes task arguments).
        sequence_type / scaf_suffix / threads: Digger options.
        min_identity / min_coverage / allow_unknown_families /
            require_index_hit: Shoter options.
        crispr_fusion / crispr_weight: enable and weight the CRISPR stage.

    Returns:
        Summary dict with job_id, status, completed stages and output_dir.

    Raises:
        Exception: when the Digger stage fails (Celery retries up to
        max_retries).  Later stages are best-effort and only affect the
        manifest's completed/skipped bookkeeping.
    """
    import json

    db = SessionLocal()
    job = None  # bound before the try so the except handler can test it safely
    stages_completed = []  # appended to as each stage actually finishes

    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}

        # Mark as QUEUED.  NOTE(review): slot acquisition through the async
        # concurrency manager is not integrated yet (Celery tasks run
        # synchronously), so the task proceeds straight to RUNNING.
        job.status = JobStatus.QUEUED
        db.commit()

        job.status = JobStatus.RUNNING
        job.current_stage = "digger"
        job.progress_percent = 0
        db.commit()

        # ---- Stage 1: Digger - identify toxin genes ----------------------
        logger.info(f"Job {job_id}: Starting Digger stage")
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'digger', 'progress': 10, 'status': 'Running BtToxin_Digger...'}
        )

        docker_manager = DockerManager()
        digger_result = docker_manager.run_bttoxin_digger(
            input_dir=Path(input_dir),
            output_dir=Path(output_dir),
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            threads=threads
        )

        if not digger_result['success']:
            raise Exception(f"Digger stage failed: {digger_result.get('error', 'Unknown error')}")
        stages_completed.append("digger")

        job.progress_percent = 40
        db.commit()

        # Genome files discovered once, reused by the CRISPR stage below.
        input_files = list(Path(input_dir).glob(f"*{scaf_suffix}"))

        # Digger's toxin table; path assumed from the standard output layout
        # (output_dir/Results/Toxins/All_Toxins.txt) -- TODO confirm against
        # DockerManager's actual output structure.
        toxins_file = Path(output_dir) / "Results" / "Toxins" / "All_Toxins.txt"

        # ---- Stage 1.5: CRISPR-Cas (optional) ----------------------------
        crispr_results_file = None
        if crispr_fusion:
            logger.info(f"Job {job_id}: Starting CRISPR stage")
            job.current_stage = "crispr"
            db.commit()
            self.update_state(
                state='PROGRESS',
                meta={'stage': 'crispr', 'progress': 45, 'status': 'Running CRISPR Detection...'}
            )

            crispr_out = Path(output_dir) / "crispr" / "results.json"
            crispr_out.parent.mkdir(parents=True, exist_ok=True)

            # Prefer a genome file found by globbing; fall back to the
            # conventional "<job_id><suffix>" name.  (Replaces the previous
            # fragile detect_cmd[7] index patching.)
            genome = input_files[0] if input_files else Path(input_dir) / f"{job_id}{scaf_suffix}"

            detect_cmd = [
                "pixi", "run", "-e", "crispr", "python", "crispr_cas/scripts/detect_crispr.py",
                "--input", str(genome),
                "--output", str(crispr_out),
                "--mock"  # always mocked for now: the real tool is not installed
            ]

            res = run_local_command(detect_cmd, cwd=Path("/app"))
            if not res['success']:
                logger.warning(f"CRISPR detection failed: {res.get('stderr')}")
            else:
                crispr_results_file = crispr_out
                stages_completed.append("crispr")

            # Fusion analysis needs both the toxin table and a genome file.
            # BUGFIX: an empty glob previously raised IndexError on
            # input_files[0]; now the fusion step is skipped instead.
            fusion_out = Path(output_dir) / "crispr" / "fusion_analysis.json"
            if toxins_file.exists() and input_files:
                fusion_cmd = [
                    "pixi", "run", "-e", "crispr", "python", "crispr_cas/scripts/fusion_analysis.py",
                    "--crispr-results", str(crispr_out),
                    "--toxin-results", str(toxins_file),
                    "--genome", str(input_files[0]),
                    "--output", str(fusion_out),
                    "--mock"
                ]
                run_local_command(fusion_cmd, cwd=Path("/app"))

        # ---- Stage 2: Shoter - score toxicity activity -------------------
        logger.info(f"Job {job_id}: Starting Shoter stage")
        job.current_stage = "shoter"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'shoter', 'progress': 50, 'status': 'Running BtToxin_Shoter...'}
        )

        shoter_out_dir = Path(output_dir) / "shoter"

        # Runs even if Digger produced no hits; the script tolerates an
        # empty or missing toxin table.
        shoter_cmd = [
            "pixi", "run", "-e", "pipeline", "python", "scripts/bttoxin_shoter.py",
            "--all_toxins", str(toxins_file),
            "--output_dir", str(shoter_out_dir),
            "--min_identity", str(min_identity),
            "--min_coverage", str(min_coverage)
        ]

        if allow_unknown_families:
            shoter_cmd.append("--allow_unknown_families")
        if require_index_hit:
            shoter_cmd.append("--require_index_hit")

        # Forward CRISPR results so scores can be weighted / fused.
        if crispr_results_file:
            shoter_cmd.extend(["--crispr_results", str(crispr_results_file)])
            shoter_cmd.extend(["--crispr_weight", str(crispr_weight)])
            if crispr_fusion:
                shoter_cmd.append("--crispr_fusion")

        shoter_res = run_local_command(shoter_cmd, cwd=Path("/app"))
        if shoter_res['success']:
            stages_completed.append("shoter")

        job.progress_percent = 70
        db.commit()

        # ---- Stage 3: Plots - generate heatmaps --------------------------
        logger.info(f"Job {job_id}: Starting Plots stage")
        job.current_stage = "plots"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'plots', 'progress': 80, 'status': 'Generating plots...'}
        )

        plot_cmd = [
            "pixi", "run", "-e", "pipeline", "python", "scripts/plot_shotter.py",
            "--strain_scores", str(shoter_out_dir / "strain_target_scores.tsv"),
            "--toxin_support", str(shoter_out_dir / "toxin_support.tsv"),
            "--species_scores", str(shoter_out_dir / "strain_target_species_scores.tsv"),
            "--out_dir", str(shoter_out_dir),
            "--output_prefix", "Activity_Heatmap"
        ]

        if crispr_results_file:
            plot_cmd.extend(["--crispr_results", str(crispr_results_file)])
        if crispr_fusion:
            plot_cmd.append("--crispr_fusion")

        plot_res = run_local_command(plot_cmd, cwd=Path("/app"))
        if plot_res['success']:
            stages_completed.append("plots")

        job.progress_percent = 90
        db.commit()

        # ---- Stage 4: Bundle - write manifest.json -----------------------
        logger.info(f"Job {job_id}: Starting Bundle stage")
        job.current_stage = "bundle"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'bundle', 'progress': 95, 'status': 'Bundling results...'}
        )

        manifest = {
            "job_id": job_id,
            # BUGFIX: previously hard-coded to ["digger"] completed and
            # ["shoter", "plots", "bundle"] skipped, misreporting the
            # stages that actually ran.
            "stages_completed": stages_completed,
            "stages_skipped": [s for s in PIPELINE_STAGES if s not in stages_completed],
            # Serialize paths explicitly instead of relying on default=str.
            "output_files": [str(p) for p in Path(output_dir).rglob("*")],
            "parameters": {
                "sequence_type": sequence_type,
                "min_identity": min_identity,
                "min_coverage": min_coverage,
                "allow_unknown_families": allow_unknown_families,
                "require_index_hit": require_index_hit,
                "crispr_fusion": crispr_fusion,
                "crispr_weight": crispr_weight,
            }
        }

        manifest_path = Path(output_dir) / "manifest.json"
        with open(manifest_path, "w") as f:
            json.dump(manifest, f, indent=2, default=str)
        stages_completed.append("bundle")

        # ---- Done --------------------------------------------------------
        job.status = JobStatus.COMPLETED
        job.progress_percent = 100
        job.current_stage = "completed"
        job.logs = json.dumps({"stages": stages_completed, "output": str(output_dir)})
        db.commit()

        logger.info(f"Job {job_id}: Completed successfully")

        return {
            'job_id': job_id,
            'status': 'completed',
            'stages': stages_completed,
            'output_dir': str(output_dir)
        }

    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}")
        # BUGFIX: guard against `job` being unbound/None when the initial
        # query itself raised, which previously masked the real error with
        # a NameError/AttributeError.
        if job is not None:
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.current_stage = "failed"
            db.commit()
        raise

    finally:
        db.close()
|
||
|
||
|
||
@celery_app.task
def update_queue_positions():
    """Recompute the 1-based queue position of every QUEUED job.

    Intended to be triggered periodically via Celery Beat.  Positions are
    assigned in creation order (oldest job first).  Failures are logged and
    rolled back; the task never raises.
    """
    session = SessionLocal()
    try:
        # All jobs still waiting for a slot, oldest first.
        pending = (
            session.query(Job)
            .filter(Job.status == JobStatus.QUEUED)
            .order_by(Job.created_at)
            .all()
        )

        position = 0
        for queued_job in pending:
            position += 1
            queued_job.queue_position = position

        session.commit()
        logger.info(f"Updated queue positions for {len(pending)} jobs")

    except Exception as e:
        logger.error(f"Failed to update queue positions: {e}")
        session.rollback()
    finally:
        session.close()
|