Files
bttoxin-pipeline/backend/app/workers/tasks.py

204 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Celery 任务 - 支持并发控制和多阶段 pipeline"""
from celery import Task
from pathlib import Path
import shutil
import logging
import asyncio
from ..core.celery_app import celery_app
from ..core.docker_client import DockerManager
from ..database import SessionLocal
from ..models.job import Job, JobStatus
from ..services.concurrency_manager import get_concurrency_manager
logger = logging.getLogger(__name__)
# Pipeline stage definitions, in execution order. Only "digger" is fully
# implemented so far — "shoter" and "plots" are TODO stubs inside
# run_bttoxin_analysis, and "bundle" just writes manifest.json.
PIPELINE_STAGES = ["digger", "shoter", "plots", "bundle"]
def _write_bundle_manifest(job_id: str, output_dir: str, parameters: dict):
    """Write ``manifest.json`` into *output_dir* describing the pipeline run.

    The file list is stored as sorted, output-relative string paths so the
    manifest is deterministic (``rglob`` order is filesystem-dependent) and
    plainly JSON-serializable.

    Returns the path of the written manifest file.
    """
    import json

    out_path = Path(output_dir)
    manifest = {
        "job_id": job_id,
        "stages_completed": ["digger"],
        # shoter/plots are not implemented yet; "bundle" is this very step,
        # so it is neither skipped nor listed as completed ahead of time.
        "stages_skipped": ["shoter", "plots"],
        "output_files": sorted(
            str(p.relative_to(out_path)) for p in out_path.rglob("*")
        ),
        "parameters": parameters,
    }
    manifest_path = out_path / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest_path


@celery_app.task(bind=True, max_retries=3)
def run_bttoxin_analysis(
    self,
    job_id: str,
    input_dir: str,
    output_dir: str,
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    min_identity: float = 0.8,
    min_coverage: float = 0.6,
    allow_unknown_families: bool = False,
    require_index_hit: bool = True,
    crispr_fusion: bool = False,
    crispr_weight: float = 0.0,
):
    """Run the full 4-stage analysis pipeline for one job.

    Stages:
        1. digger - BtToxin_Digger toxin-gene identification (implemented)
        2. shoter - BtToxin_Shoter toxicity assessment (TODO, skipped)
        3. plots  - heatmap generation (TODO, skipped)
        4. bundle - write manifest.json and mark the job complete

    Args:
        job_id: Primary key of the ``Job`` row to process.
        input_dir: Directory containing input assemblies.
        output_dir: Directory the pipeline writes results into.
        sequence_type: "nucl" or equivalent, passed to the Digger container.
        scaf_suffix: File suffix of scaffold/assembly files.
        threads: Worker thread count for the Digger container.
        min_identity, min_coverage, allow_unknown_families, require_index_hit,
        crispr_fusion, crispr_weight: Analysis parameters; currently only
            recorded in the manifest (downstream stages are TODO).

    Returns:
        A status dict; on success ``{'job_id', 'status', 'stages',
        'output_dir'}``.

    Raises:
        Exception: re-raised after the job row is marked FAILED, so Celery
            records the failure. NOTE(review): ``max_retries=3`` has no
            effect without an explicit ``self.retry(...)`` call — confirm
            whether retries are actually wanted here.
    """
    import json

    db = SessionLocal()
    # Bind `job` before the try block so the except handler can never hit a
    # NameError if the initial query itself raises.
    job = None
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error("Job %s not found", job_id)
            return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}

        job.status = JobStatus.QUEUED
        db.commit()

        # TODO: acquire an execution slot via get_concurrency_manager().
        # Celery tasks are synchronous, so this needs a sync Redis client;
        # until that is wired in, the task starts running immediately.
        job.status = JobStatus.RUNNING
        job.current_stage = "digger"
        job.progress_percent = 0
        db.commit()

        # Stage 1: Digger - identify toxin genes in the input assemblies.
        logger.info("Job %s: Starting Digger stage", job_id)
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'digger', 'progress': 10, 'status': 'Running BtToxin_Digger...'}
        )
        docker_manager = DockerManager()
        digger_result = docker_manager.run_bttoxin_digger(
            input_dir=Path(input_dir),
            output_dir=Path(output_dir),
            sequence_type=sequence_type,
            scaf_suffix=scaf_suffix,
            threads=threads,
        )
        if not digger_result['success']:
            raise Exception(f"Digger stage failed: {digger_result.get('error', 'Unknown error')}")
        job.progress_percent = 40
        db.commit()

        # Stage 2: Shoter - toxicity assessment (not implemented yet).
        logger.info("Job %s: Starting Shoter stage", job_id)
        job.current_stage = "shoter"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'shoter', 'progress': 50, 'status': 'Running BtToxin_Shoter...'}
        )
        # TODO: shoter_result = run_shoter_pipeline(...)
        logger.info("Job %s: Shoter stage not implemented yet, skipping", job_id)
        job.progress_percent = 70
        db.commit()

        # Stage 3: Plots - heatmap generation (not implemented yet).
        logger.info("Job %s: Starting Plots stage", job_id)
        job.current_stage = "plots"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'plots', 'progress': 80, 'status': 'Generating plots...'}
        )
        # TODO: implement plot generation
        logger.info("Job %s: Plots stage not implemented yet, skipping", job_id)
        job.progress_percent = 90
        db.commit()

        # Stage 4: Bundle - write manifest.json describing the run.
        logger.info("Job %s: Starting Bundle stage", job_id)
        job.current_stage = "bundle"
        db.commit()
        self.update_state(
            state='PROGRESS',
            meta={'stage': 'bundle', 'progress': 95, 'status': 'Bundling results...'}
        )
        _write_bundle_manifest(
            job_id,
            output_dir,
            {
                "sequence_type": sequence_type,
                "min_identity": min_identity,
                "min_coverage": min_coverage,
                "allow_unknown_families": allow_unknown_families,
                "require_index_hit": require_index_hit,
                "crispr_fusion": crispr_fusion,
                "crispr_weight": crispr_weight,
            },
        )

        # Done - persist final state on the job row.
        job.status = JobStatus.COMPLETED
        job.progress_percent = 100
        job.current_stage = "completed"
        job.logs = json.dumps({"stages": ["digger"], "output": str(output_dir)})
        db.commit()
        logger.info("Job %s: Completed successfully", job_id)
        return {
            'job_id': job_id,
            'status': 'completed',
            'stages': ['digger'],
            'output_dir': str(output_dir)
        }
    except Exception as e:
        logger.error("Job %s failed: %s", job_id, e)
        if job is not None:
            # Discard any partial uncommitted state before recording the
            # failure, so the FAILED commit is not poisoned by it.
            db.rollback()
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.current_stage = "failed"
            db.commit()
        raise
    finally:
        db.close()
@celery_app.task
def update_queue_positions():
    """Recompute 1-based queue positions for all jobs waiting in QUEUED state.

    Jobs are ordered by creation time, so the oldest queued job gets
    position 1. Intended to be triggered periodically via Celery Beat.
    Errors are logged and the transaction rolled back; nothing is raised.
    """
    session = SessionLocal()
    try:
        pending = (
            session.query(Job)
            .filter(Job.status == JobStatus.QUEUED)
            .order_by(Job.created_at)
            .all()
        )
        for position, queued_job in enumerate(pending, start=1):
            queued_job.queue_position = position
        # Single commit after the loop: one transaction for all updates.
        session.commit()
        queued_jobs = pending
        logger.info(f"Updated queue positions for {len(queued_jobs)} jobs")
    except Exception as e:
        logger.error(f"Failed to update queue positions: {e}")
        session.rollback()
    finally:
        session.close()