"""Celery tasks: run the single-genome analysis pipeline and refresh queue positions."""

from celery import Task
from pathlib import Path
from typing import Optional
import shutil
import logging
import asyncio
import subprocess
import os
import zipfile
import json

from ..core.celery_app import celery_app
from ..core.tool_runner import ToolRunner
from ..database import SessionLocal
from ..models.job import Job, JobStatus
from ..services.concurrency_manager import get_concurrency_manager

logger = logging.getLogger(__name__)

# Result sub-directory on disk -> folder name shown inside the result zip.
_ARCHIVE_DIR_NAMES = {
    "digger": "1_Toxin_Mining",
    "shotter": "2_Toxicity_Scoring",
    "logs": "Logs",
}

# Intermediate directories removed from the output directory after bundling.
_INTERMEDIATE_DIRS = ("digger", "shotter", "context", "logs", "stage")

# File suffixes never copied into the bundle, to avoid nesting archives.
_ARCHIVE_SUFFIXES = ('.zip', '.tar.gz')


def run_local_command(cmd: list, cwd: Optional[Path] = None, env: Optional[dict] = None) -> dict:
    """Run a command locally in the container and capture its output.

    Args:
        cmd: Command and arguments as a list (executed without a shell).
        cwd: Working directory for the child process, or None for the current one.
        env: Environment for the child process; a copy of ``os.environ`` is
            used when this is None or empty.

    Returns:
        A dict that always contains ``success``, ``stdout``, ``stderr`` and
        ``exit_code``. When the command could not be run at all (for example
        the executable is missing), ``success`` is False, ``exit_code`` is
        None, ``stderr`` carries the exception text and an extra ``error``
        key holds the same text.
    """
    try:
        logger.info(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(
            cmd,
            cwd=cwd,
            env=env or os.environ.copy(),
            capture_output=True,
            text=True,
            check=False,
        )
        return {
            'success': result.returncode == 0,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'exit_code': result.returncode,
        }
    except Exception as e:
        logger.error(f"Command failed: {e}")
        # Return the same keys as the success path so callers that read
        # 'stderr' / 'exit_code' report the real failure instead of raising
        # KeyError.
        return {
            'success': False,
            'stdout': '',
            'stderr': str(e),
            'exit_code': None,
            'error': str(e),
        }


def _find_input_file(input_path: Path, scaf_suffix: str, input_dir: str) -> Path:
    """Pick the input file: first match for the suffix, else any regular file."""
    # Sorting makes the choice deterministic if the directory holds several files.
    matches = sorted(input_path.glob(f"*{scaf_suffix}"))
    if matches:
        return matches[0]
    # Fall back to any regular file in the upload directory.
    candidates = sorted(f for f in input_path.iterdir() if f.is_file())
    if candidates:
        return candidates[0]
    raise FileNotFoundError(f"No input file found in {input_dir}")


def _build_pipeline_command(
    input_file: Path,
    output_path: Path,
    threads: int,
    min_identity: float,
    min_coverage: float,
    allow_unknown_families: bool,
    require_index_hit: bool,
    lang: str,
) -> list:
    """Build the ``pixi run`` command line for scripts/run_single_fna_pipeline.py."""
    # The script is run through `pixi run -e pipeline` so it uses the pipeline environment.
    pipeline_cmd = [
        "pixi", "run", "-e", "pipeline",
        "python", "scripts/run_single_fna_pipeline.py",
        "--input", str(input_file),
        "--out_root", str(output_path),
        "--toxicity_csv", "Data/toxicity-data.csv",
        "--min_identity", str(min_identity),
        "--min_coverage", str(min_coverage),
        "--threads", str(threads),
        "--lang", lang,
    ]
    if not allow_unknown_families:
        pipeline_cmd.append("--disallow_unknown_families")
    if require_index_hit:
        pipeline_cmd.append("--require_index_hit")
    return pipeline_cmd


def _remove_stale_archives(output_path: Path) -> None:
    """Delete top-level .zip / .tar.gz files in the output directory (best effort)."""
    for pattern in ("*.zip", "*.tar.gz"):
        for existing_archive in output_path.glob(pattern):
            try:
                existing_archive.unlink()
            except OSError as exc:
                # Still best-effort, but no longer silent.
                logger.warning(f"Could not remove stale archive {existing_archive}: {exc}")


def _write_result_zip(zip_path: Path, input_file: Path, output_path: Path) -> None:
    """Write the input file and the renamed result directories into ``zip_path``."""
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # 1. The input file goes under Input/.
        zipf.write(input_file, arcname=f"Input/{input_file.name}")

        # 2. Each result directory is stored under its display name.
        for src_name, dest_name in _ARCHIVE_DIR_NAMES.items():
            src_path = output_path / src_name
            if not src_path.exists():
                continue
            for root, _dirs, files in os.walk(src_path):
                for file in files:
                    file_path = Path(root) / file
                    # Never add the bundle to itself.
                    if file_path == zip_path:
                        continue
                    # Skip nested archives.
                    if file.endswith(_ARCHIVE_SUFFIXES):
                        continue
                    # e.g. digger/Results/foo.txt -> 1_Toxin_Mining/Results/foo.txt
                    rel_path = file_path.relative_to(src_path)
                    arcname = Path(dest_name) / rel_path
                    zipf.write(file_path, arcname=str(arcname))


def _cleanup_intermediates(output_path: Path) -> None:
    """Remove intermediate result directories and the script's own tar.gz, if present."""
    for d in _INTERMEDIATE_DIRS:
        d_path = output_path / d
        if d_path.exists():
            shutil.rmtree(d_path)

    # Remove the tar.gz if the pipeline script produced one.
    tar_gz = output_path / "pipeline_results.tar.gz"
    if tar_gz.exists():
        tar_gz.unlink()


@celery_app.task(bind=True, max_retries=3, name="backend.app.workers.tasks.run_bttoxin_analysis")
def run_bttoxin_analysis(
    self,
    job_id: str,
    input_dir: str,
    output_dir: str,
    sequence_type: str = "nucl",
    scaf_suffix: str = ".fna",
    threads: int = 4,
    min_identity: float = 0.8,
    min_coverage: float = 0.6,
    allow_unknown_families: bool = False,
    require_index_hit: bool = True,
    crispr_fusion: bool = False,
    crispr_weight: float = 0.0,
    lang: str = "zh"
):
    """Run the analysis for one job via scripts/run_single_fna_pipeline.py.

    Steps: mark the job RUNNING, locate the input file in ``input_dir``, run
    the pipeline script, bundle the input and results into
    ``pipeline_results_{job_id}.zip`` inside ``output_dir``, delete the
    intermediate directories, and mark the job COMPLETED.

    Args:
        job_id: Primary key of the ``Job`` row to update.
        input_dir: Directory containing the uploaded input file.
        output_dir: Directory passed to the script as ``--out_root``; the
            result zip is written here.
        scaf_suffix: Suffix used to find the input file; if nothing matches,
            any regular file in ``input_dir`` is used.
        threads, min_identity, min_coverage, lang: Forwarded to the script.
        allow_unknown_families: When False, ``--disallow_unknown_families``
            is passed.
        require_index_hit: When True, ``--require_index_hit`` is passed.
        sequence_type, crispr_fusion, crispr_weight: Accepted but currently
            not forwarded to the pipeline command.

    Returns:
        ``{'job_id', 'status': 'completed', 'output_dir'}`` on success, or
        ``{'job_id', 'status': 'error', 'error': 'Job not found'}`` when no
        job row exists for ``job_id``.

    Raises:
        FileNotFoundError: No input file exists in ``input_dir``.
        RuntimeError: The pipeline command failed or the zip was not created.
        Any exception is re-raised after the job has been marked FAILED.
    """
    db = SessionLocal()
    job = None

    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return {'job_id': job_id, 'status': 'error', 'error': 'Job not found'}

        # Mark the job as RUNNING.
        job.status = JobStatus.RUNNING
        job.current_stage = "running"
        job.progress_percent = 0
        db.commit()

        input_path = Path(input_dir)
        output_path = Path(output_dir)

        # input_dir is the upload directory and is expected to hold one file.
        input_file = _find_input_file(input_path, scaf_suffix, input_dir)
        logger.info(f"Job {job_id}: Starting pipeline for {input_file}")

        pipeline_cmd = _build_pipeline_command(
            input_file,
            output_path,
            threads,
            min_identity,
            min_coverage,
            allow_unknown_families,
            require_index_hit,
            lang,
        )

        res = run_local_command(pipeline_cmd, cwd=Path("/app"))
        if not res['success']:
            error_msg = f"Pipeline execution failed (exit={res['exit_code']}): {res['stderr']}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        logger.info(f"Job {job_id}: Pipeline script completed")

        # Bundle the results into a zip.
        logger.info(f"Job {job_id}: Creating zip bundle")
        zip_path = output_path / f"pipeline_results_{job_id}.zip"

        # Drop archives already in the directory before creating the new zip.
        _remove_stale_archives(output_path)
        _write_result_zip(zip_path, input_file, output_path)

        # Delete the original result directories.
        logger.info(f"Job {job_id}: Cleaning up intermediate files")
        _cleanup_intermediates(output_path)

        if not zip_path.exists():
            raise RuntimeError("Failed to create result zip file")

        # NOTE(review): output_dir is where the download API looks for the
        # result. That API (backend/app/api/v1/results.py) may still expect
        # pipeline_results.tar.gz, while the frontend link is said to use
        # pipeline_results_{id}.zip; confirm both agree with the name used here.

        job.status = JobStatus.COMPLETED
        job.progress_percent = 100
        job.current_stage = "completed"
        # Keep only the tail of the script output as a log summary.
        job.logs = res['stdout'][-2000:] if res['stdout'] else "No output"
        db.commit()

        logger.info(f"Job {job_id}: Completed successfully")

        return {
            'job_id': job_id,
            'status': 'completed',
            'output_dir': str(output_dir)
        }

    except Exception as e:
        logger.exception(f"Job {job_id} failed: {e}")
        if job is not None:
            try:
                # Reset the session first: if the failure came from the
                # database, the session cannot commit until rolled back.
                db.rollback()
                job.status = JobStatus.FAILED
                job.error_message = str(e)
                job.current_stage = "failed"
                db.commit()
            except Exception as commit_error:
                logger.error(f"Failed to update job status to FAILED: {commit_error}")
        raise

    finally:
        db.close()


@celery_app.task
def update_queue_positions():
    """Renumber the queue positions of all QUEUED jobs.

    Jobs are ordered by ``created_at`` and numbered from 1. Intended to be
    called periodically (for example from Celery Beat). Errors are logged
    and rolled back rather than raised.
    """
    db = SessionLocal()
    try:
        # All jobs currently in the QUEUED state, oldest first.
        queued_jobs = db.query(Job).filter(
            Job.status == JobStatus.QUEUED
        ).order_by(Job.created_at).all()

        for idx, job in enumerate(queued_jobs, start=1):
            job.queue_position = idx

        db.commit()
        logger.info(f"Updated queue positions for {len(queued_jobs)} jobs")

    except Exception as e:
        logger.exception(f"Failed to update queue positions: {e}")
        db.rollback()
    finally:
        db.close()