# Changelog (from commit message):
# - Add property-based tests for PixiRunner
# - Add HAN055.fna test data file
# - Update README with pixi installation and usage guide
# - Update .gitignore for pixi and test artifacts
# - Update CLI to remove Docker-related arguments
#!/usr/bin/env python3
|
|
"""Bttoxin Pipeline API (pixi-based).
|
|
|
|
This module provides the API for running the BtToxin pipeline using pixi environments:
|
|
- digger environment: BtToxin_Digger with bioconda dependencies
|
|
- pipeline environment: Python analysis with pandas/matplotlib/seaborn
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import tarfile
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import Dict, Any, Optional
|
|
import sys as _sys
|
|
|
|
# Ensure repo-relative imports for scripts when running from installed package
_REPO_ROOT = Path(__file__).resolve().parents[1]
_SCRIPTS_DIR = _REPO_ROOT / "scripts"
# Append (not insert) so already-configured sys.path entries keep priority.
# The original looped over a single-element tuple; a direct check is clearer.
if str(_SCRIPTS_DIR) not in _sys.path:
    _sys.path.append(str(_SCRIPTS_DIR))

# Import PixiRunner from scripts
from pixi_runner import PixiRunner  # type: ignore

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _lazy_import_shoter():
    """Import and return the ``bttoxin_shoter`` module from the scripts dir.

    Import is deferred so that loading this API module does not require the
    Shotter dependencies to be installed.

    Returns:
        The imported ``bttoxin_shoter`` module.

    Raises:
        ImportError: if the module (or one of its dependencies) cannot be
            imported; the original error is chained as ``__cause__``.
    """
    try:
        import bttoxin_shoter as shoter  # type: ignore

        return shoter
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ImportError(
            f"Failed to import bttoxin_shoter from {_SCRIPTS_DIR}. Ensure repo is present in the image.\n{e}"
        ) from e
|
|
|
|
|
|
def _lazy_import_plotter():
    """Import and return the ``plot_shotter`` module from the scripts dir.

    Import is deferred so that loading this API module does not require the
    plotting dependencies (matplotlib/seaborn) to be installed.

    Returns:
        The imported ``plot_shotter`` module.

    Raises:
        ImportError: if the module (or one of its dependencies) cannot be
            imported; the original error is chained as ``__cause__``.
    """
    try:
        import plot_shotter as plotter  # type: ignore

        return plotter
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ImportError(
            f"Failed to import plot_shotter from {_SCRIPTS_DIR}. Ensure repo is present in the image.\n{e}"
        ) from e
|
|
|
|
|
|
class BtToxinRunner:
    """Wrap BtToxin_Digger pixi invocation for a single FNA."""

    def __init__(
        self,
        base_workdir: Optional[Path] = None,
        bttoxin_db_dir: Optional[Path] = None,
    ) -> None:
        # Default run area lives under the repo so artifacts are easy to locate.
        self.base_workdir = base_workdir if base_workdir is not None else _REPO_ROOT / "runs" / "bttoxin"
        self.base_workdir.mkdir(parents=True, exist_ok=True)
        # Optional external BtToxin database directory; forwarded to the runner.
        self.bttoxin_db_dir = bttoxin_db_dir
        self.runner = PixiRunner(pixi_project_dir=_REPO_ROOT, env_name="digger")

    def _prepare_layout(self, fna_path: Path) -> tuple[Path, Path, Path, Path, str]:
        """Create the per-sample directory layout and stage the input FNA.

        Returns (input_dir, digger_out, log_dir, run_root, sample_name).
        Raises FileNotFoundError when the FNA does not exist.
        """
        if not fna_path.exists():
            raise FileNotFoundError(f"FNA file not found: {fna_path}")

        sample_name = fna_path.stem
        run_root = self.base_workdir / sample_name
        input_dir = run_root / "input"
        digger_out = run_root / "output" / "digger"
        log_dir = run_root / "logs"
        for directory in (input_dir, digger_out, log_dir):
            directory.mkdir(parents=True, exist_ok=True)

        staged = input_dir / fna_path.name
        if staged.exists():
            staged.unlink()
        # Prefer a hard link (no data duplication); fall back to a plain copy
        # when linking fails (e.g. source and target on different filesystems).
        try:
            os.link(fna_path, staged)
            logger.info("Hard-linked FNA: %s → %s", fna_path, staged)
        except OSError:
            shutil.copy2(fna_path, staged)
            logger.info("Copied FNA: %s → %s", fna_path, staged)
        return input_dir, digger_out, log_dir, run_root, sample_name

    def run_single_fna(self, fna_path: Path | str, sequence_type: str = "nucl", threads: int = 4) -> Dict[str, Any]:
        """Run BtToxin_Digger on one FNA file and collect expected result paths.

        Returns a dict with ``success``, the layout paths, the expected
        output ``files`` mapping, and the raw runner result.
        """
        fna_path = Path(fna_path)
        input_dir, digger_out, log_dir, run_root, sample_name = self._prepare_layout(fna_path)

        logger.info("Start BtToxin_Digger: %s (sample=%s)", fna_path, sample_name)
        raw = self.runner.run_bttoxin_digger(
            input_dir=input_dir,
            output_dir=digger_out,
            log_dir=log_dir,
            sequence_type=sequence_type,
            scaf_suffix=fna_path.suffix or ".fna",
            threads=threads,
            bttoxin_db_dir=self.bttoxin_db_dir,
        )

        toxins_dir = digger_out / "Results" / "Toxins"
        # Map logical names to the filenames BtToxin_Digger is expected to emit.
        filenames = {
            "list": f"{sample_name}.list",
            "gbk": f"{sample_name}.gbk",
            "all_genes": "Bt_all_genes.table",
            "all_toxins": "All_Toxins.txt",
        }
        files = {key: toxins_dir / fname for key, fname in filenames.items()}

        # Success requires both a clean runner exit and the main output file.
        succeeded = bool(raw.get("success")) and files["all_toxins"].exists()
        return {
            "success": succeeded,
            "sample": sample_name,
            "run_root": run_root,
            "input_dir": input_dir,
            "digger_out": digger_out,
            "log_dir": log_dir,
            "toxins_dir": toxins_dir,
            "files": files,
            "raw_result": raw,
        }
|
|
|
|
|
|
class ShotterAPI:
    """Pure Python Shotter scoring and saving (no subprocess)."""

    def score(
        self,
        toxicity_csv: Path,
        all_toxins: Path,
        out_dir: Path,
        min_identity: float = 0.0,
        min_coverage: float = 0.0,
        allow_unknown_families: bool = True,
        require_index_hit: bool = False,
    ) -> Dict[str, Any]:
        """Score toxin hits per strain and write TSV/JSON outputs to *out_dir*.

        Filters the parsed All_Toxins table by identity, coverage, family
        and (optionally) index membership, then scores each strain via
        ``shoter.score_strain``. Returns a dict of column lists and output
        file paths.
        """
        shoter = _lazy_import_shoter()
        out_dir.mkdir(parents=True, exist_ok=True)

        index = shoter.SpecificityIndex.from_csv(toxicity_csv)
        df = shoter.parse_all_toxins(all_toxins)

        # Optional quality filters, applied in order.
        if min_identity > 0:
            df = df[df["identity01"].astype(float) >= float(min_identity)]
        if min_coverage > 0:
            df = df[df["coverage"].astype(float) >= float(min_coverage)]
        if not allow_unknown_families:
            df = df[df["family_key"].astype(str) != "unknown"]
        if require_index_hit:
            # Keep rows whose hit name (or, as a fallback, family) resolves
            # to at least one order in the specificity index.
            def _has_index_orders(row) -> bool:
                orders = index.orders_for_name_or_backoff(str(row.get("Hit_id_norm", "")))
                if not orders:
                    orders = index.orders_for_name_or_backoff(str(row.get("family_key", "")))
                return bool(orders)

            df = df[df.apply(_has_index_orders, axis=1)]

        hits: list[shoter.ToxinHit] = []
        strain_scores: list[shoter.StrainScores] = []
        species_scores: list[shoter.StrainSpeciesScores] = []
        for strain in sorted(df["Strain"].astype(str).unique().tolist()):
            subset = df[df["Strain"].astype(str).eq(strain)].copy()
            per_hit, per_strain, per_species = shoter.score_strain(strain, subset, index)
            hits.extend(per_hit)
            strain_scores.append(per_strain)
            if per_species is not None:
                species_scores.append(per_species)

        # Always include the synthetic "other"/"unknown" buckets in the columns.
        order_columns = sorted({*index.all_orders, "other", "unknown"}) or ["unknown"]
        species_columns = sorted(index.all_species)

        shoter.ToxinHit.save_list_tsv(out_dir / "toxin_support.tsv", hits, order_columns)
        shoter.StrainScores.save_list_tsv(out_dir / "strain_target_scores.tsv", strain_scores, order_columns)
        shoter.StrainScores.save_list_json(out_dir / "strain_scores.json", strain_scores)
        # Species-level outputs are written only when there is anything to report.
        if species_columns and species_scores:
            shoter.StrainSpeciesScores.save_list_tsv(out_dir / "strain_target_species_scores.tsv", species_scores, species_columns)
            shoter.StrainSpeciesScores.save_list_json(out_dir / "strain_species_scores.json", species_scores)

        return {
            "orders": order_columns,
            "species": species_columns,
            "strain_scores": out_dir / "strain_target_scores.tsv",
            "toxin_support": out_dir / "toxin_support.tsv",
            "strain_scores_json": out_dir / "strain_scores.json",
            "species_scores": out_dir / "strain_target_species_scores.tsv",
            "species_scores_json": out_dir / "strain_species_scores.json",
        }
|
|
|
|
|
|
class PlotAPI:
    """Plot heatmaps and write Markdown report (no subprocess)."""

    def render(
        self,
        shotter_dir: Path,
        lang: str = "zh",
        merge_unresolved: bool = True,
        per_hit_strain: Optional[str] = None,
        cmap: str = "viridis",
        vmin: float = 0.0,
        vmax: float = 1.0,
    ) -> Dict[str, Any]:
        """Render heatmaps from the Shotter outputs in *shotter_dir* and write
        the Markdown report.

        Returns a dict of the generated PNG paths (None where skipped) and
        the report path.
        """
        plotter = _lazy_import_plotter()

        # Input tables as produced by ShotterAPI.score in the same directory.
        strain_scores = shotter_dir / "strain_target_scores.tsv"
        toxin_support = shotter_dir / "toxin_support.tsv"
        species_scores = shotter_dir / "strain_target_species_scores.tsv"

        strain_png = shotter_dir / "strain_target_scores.png"
        plotter.plot_strain_scores(strain_scores, strain_png, cmap, vmin, vmax, None, merge_unresolved)

        # Per-hit heatmap is optional: only for a named strain with support data.
        per_hit_png = None
        if per_hit_strain and toxin_support.exists():
            per_hit_png = shotter_dir / f"per_hit_{per_hit_strain}.png"
            plotter.plot_per_hit_for_strain(toxin_support, per_hit_strain, per_hit_png, cmap, vmin, vmax, None, merge_unresolved)

        # Species heatmap is optional: only when the species table was written.
        species_png = None
        if species_scores.exists():
            species_png = shotter_dir / "strain_target_species_scores.png"
            plotter.plot_species_scores(species_scores, species_png, cmap, vmin, vmax, None)

        # Filter thresholds are not tracked at this layer, so pass None for them.
        report_args = SimpleNamespace(
            allow_unknown_families=None,
            require_index_hit=None,
            min_identity=None,
            min_coverage=None,
            lang=lang,
        )
        report_path = shotter_dir / "shotter_report_paper.md"
        plotter.write_report_md(
            out_path=report_path,
            mode="paper",
            lang=lang,
            strain_scores_path=strain_scores,
            toxin_support_path=toxin_support if toxin_support.exists() else None,
            species_scores_path=species_scores if species_scores.exists() else None,
            strain_heatmap_path=strain_png,
            per_hit_heatmap_path=per_hit_png,
            species_heatmap_path=species_png,
            merge_unresolved=merge_unresolved,
            args_namespace=report_args,
        )

        return {
            "strain_orders_png": strain_png,
            "per_hit_png": per_hit_png,
            "species_png": species_png,
            "report_md": report_path,
        }
|
|
|
|
|
|
class BtSingleFnaPipeline:
    """End-to-end single-FNA pipeline: Digger → Shotter → Plot → Bundle (pixi-based)."""

    def __init__(
        self,
        base_workdir: Optional[Path] = None,
    ) -> None:
        # base_workdir is forwarded to BtToxinRunner (which applies its default).
        self.base_workdir = base_workdir
        self.shotter = ShotterAPI()
        self.plotter = PlotAPI()

    @staticmethod
    def _first_strain(strain_scores_path: Path) -> Optional[str]:
        """Best-effort: return the first strain listed in the scores TSV.

        Returns None when the table is empty or unreadable; the per-hit
        heatmap is optional, so failures are logged rather than raised.
        """
        try:
            import pandas as pd

            df = pd.read_csv(strain_scores_path, sep="\t")
            if len(df):
                return str(df.iloc[0]["Strain"])
        except Exception:
            # Previously swallowed silently; keep best-effort but leave a trace.
            logger.warning(
                "Could not determine strain for per-hit plot from %s",
                strain_scores_path,
                exc_info=True,
            )
        return None

    def run(
        self,
        fna: Path | str,
        toxicity_csv: Path | str = Path("Data/toxicity-data.csv"),
        min_identity: float = 0.0,
        min_coverage: float = 0.0,
        allow_unknown_families: bool = True,
        require_index_hit: bool = False,
        lang: str = "zh",
        threads: int = 4,
        bttoxin_db_dir: Optional[Path] = None,
    ) -> Dict[str, Any]:
        """Run Digger → Shotter → Plot on one FNA and bundle the outputs.

        Args:
            fna: Path to the input FNA file.
            toxicity_csv: Specificity index CSV for Shotter scoring.
            min_identity / min_coverage: Optional hit-quality filters (0 = off).
            allow_unknown_families: Keep hits whose family is "unknown".
            require_index_hit: Drop hits with no orders in the index.
            lang: Report language code passed to the plotter.
            threads: Worker threads for BtToxin_Digger.
            bttoxin_db_dir: Optional external BtToxin database directory.

        Returns:
            On digger failure: ``{"ok": False, "stage": "digger", "detail": ...}``.
            On success: dict with ``ok``, ``run_root``, ``digger_dir``,
            ``shotter_dir``, ``bundle`` and ``strain`` (all paths as str).
        """
        # Create digger runner with optional external database
        digger = BtToxinRunner(base_workdir=self.base_workdir, bttoxin_db_dir=bttoxin_db_dir)
        dig = digger.run_single_fna(fna_path=fna, sequence_type="nucl", threads=threads)
        if not dig.get("success"):
            return {"ok": False, "stage": "digger", "detail": dig}

        run_root: Path = dig["run_root"]
        shotter_dir = run_root / "output" / "shotter"
        shot = self.shotter.score(
            toxicity_csv=Path(toxicity_csv),
            all_toxins=Path(dig["files"]["all_toxins"]),
            out_dir=shotter_dir,
            min_identity=min_identity,
            min_coverage=min_coverage,
            allow_unknown_families=allow_unknown_families,
            require_index_hit=require_index_hit,
        )

        # Pick a strain for the optional per-hit heatmap (None skips it).
        strain_for_plot = self._first_strain(Path(shot["strain_scores"]))
        _ = self.plotter.render(
            shotter_dir=shotter_dir,
            lang=lang,
            merge_unresolved=True,
            per_hit_strain=strain_for_plot,
        )

        # Bundle both output trees into one tarball for easy retrieval.
        bundle = run_root / "pipeline_results.tar.gz"
        with tarfile.open(bundle, "w:gz") as tar:
            tar.add(run_root / "output" / "digger", arcname="digger")
            tar.add(run_root / "output" / "shotter", arcname="shotter")

        return {
            "ok": True,
            "run_root": str(run_root),
            "digger_dir": str(run_root / "output" / "digger"),
            "shotter_dir": str(shotter_dir),
            "bundle": str(bundle),
            "strain": strain_for_plot or "",
        }
|