#!/usr/bin/env python3
"""
Bttoxin_Shoter v1

- Read BPPRC specificity CSV (positive-only) to build name/family -> target_order distributions
- Read BtToxin_Digger All_Toxins.txt to collect hits per strain
- Compute per-hit similarity weight and order contributions
- Combine contributions to strain-level potential activity scores per insect order
- Save per-hit and per-strain results (TSV + JSON) using dataclasses with save methods

Assumptions and notes are documented in docs/shotter_workflow.md
"""
from __future__ import annotations

import argparse
import json
import math
import os
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import pandas as pd

# -----------------------------
# Helpers for parsing families
# -----------------------------
FAMILY_PREFIXES = (
    "Cry", "Cyt", "Vip", "Vpa", "Vpb", "Mpp", "Tpp", "Spp", "App",
    "Mcf", "Mpf", "Pra", "Prb", "Txp", "Gpp", "Mtx", "Xpp",
)

# Canonical spelling of every prefix, keyed by its lowercase form.  The regex
# below is case-insensitive, so matches are mapped back through this table.
_CANONICAL_PREFIX = {p.lower(): p for p in FAMILY_PREFIXES}

# <prefix><number><optional letters>, e.g. "Cry1Ac" -> ("Cry", "1", "Ac").
# Built from FAMILY_PREFIXES so the two can never drift apart.
FAMILY_RE = re.compile(r"(?i)(" + "|".join(FAMILY_PREFIXES) + r")(\d+)([A-Za-z]*)")


def normalize_hit_id(hit_id: str) -> str:
    """Strip trailing markers like '-other', spaces, and keep core token.

    Examples: 'Spp1Aa1' -> 'Spp1Aa1'; 'Bmp1-other' -> 'Bmp1'

    Non-string input yields the empty string.
    """
    if not isinstance(hit_id, str):
        return ""
    x = hit_id.strip()
    x = re.sub(r"-other$", "", x)
    x = re.sub(r"\s+", "", x)
    return x


def extract_family_keys(name: str) -> Tuple[Optional[str], Optional[str]]:
    """From a protein name like 'Cry1Ac1' or '5618-Cry1Ia' extract:

    - family_key: e.g., 'Cry1'
    - subfamily_key: e.g., 'Cry1A' (first subclass letter if present)

    The match is case-insensitive and the returned keys are canonicalised
    (prefix spelled as in FAMILY_PREFIXES, subclass letter upper-cased), so
    'cry1ac1' and 'Cry1Ac1' produce the same keys.

    Returns (None, None) if no family prefix is found.
    """
    if not isinstance(name, str) or not name:
        return None, None
    m = FAMILY_RE.search(name)
    if not m:
        return None, None
    fam = _CANONICAL_PREFIX.get(m.group(1).lower(), m.group(1))
    num = m.group(2)
    sub = m.group(3)  # may be ''
    family_key = f"{fam}{num}"
    subfamily_key = f"{fam}{num}{sub[0].upper()}" if sub else None
    return family_key, subfamily_key


# ---------------------------------
# Small shared I/O helpers
# ---------------------------------
def _write_tsv(tsv_path: Path, header: List[str], rows: Iterable[List[str]]) -> None:
    """Write *header* and *rows* as tab-separated UTF-8 text, creating parent dirs."""
    tsv_path.parent.mkdir(parents=True, exist_ok=True)
    with tsv_path.open("w", encoding="utf-8") as f:
        f.write("\t".join(header) + "\n")
        f.writelines("\t".join(row) + "\n" for row in rows)


def _write_json(json_path: Path, items: Iterable[Any]) -> None:
    """Dump dataclass *items* as a JSON list (UTF-8, indent=2), creating parent dirs."""
    json_path.parent.mkdir(parents=True, exist_ok=True)
    data = [asdict(item) for item in items]
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _optional_column(df: pd.DataFrame, col: str, default: Any) -> pd.Series:
    """Return df[col], or a constant column of *default* when *col* is absent."""
    if col in df.columns:
        return df[col]
    return pd.Series(default, index=df.index)


# ---------------------------------
# Specificity index from CSV (BPPRC)
# ---------------------------------
@dataclass
class SpecificityIndex:
    """Potency-weighted target distributions with name -> subfamily -> family back-off."""

    # Distributions P(order | name) and P(order | family/subfamily)
    name_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Sorted list of all observed target orders
    all_orders: List[str] = field(default_factory=list)
    # Distributions P(species | name) and P(species | family/subfamily)
    name_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Sorted list of all observed target species
    all_species: List[str] = field(default_factory=list)
    # Partner requirement heuristics at family level
    partner_pairs: List[Tuple[str, str]] = field(default_factory=lambda: [
        ("Vip1", "Vip2"),
        ("Vpa", "Vpb"),
        ("BinA", "BinB"),  # included for completeness if ever present
    ])

    @staticmethod
    def _categorical_potency(row: pd.Series) -> float:
        """Map the free-text 'percentage_mortality' field to a coarse weight in [0, 1].

        Returns 0.9 / 0.65 / 0.25 for the recognised high / medium / low
        phrasings and 0.55 (default positive evidence) for anything else,
        including a missing value.
        """
        raw = row.get("percentage_mortality")
        perc = "" if raw is None or pd.isna(raw) else str(raw).strip().lower()
        if perc:
            if ">80%" in perc or "80-100%" in perc or "greater than 80%" in perc:
                return 0.9
            if "60-100%" in perc or "60-80%" in perc or "50-80%" in perc:
                return 0.65
            if (
                "0-60%" in perc or "0-50%" in perc or "25%" in perc
                or "some" in perc or "stunting" in perc
            ):
                return 0.25
        return 0.55

    @staticmethod
    def _potency_from_row(row: pd.Series) -> float:
        """Return the raw potency signal for one positive row.

        - numeric 'lc50': returned as-is (raw value, NOT yet a 0..1 weight)
        - otherwise: the categorical weight from `_categorical_potency`

        Kept for backward compatibility.  `from_csv` handles the two cases
        separately, because the raw LC50 and the categorical weight are both
        plain floats and cannot be told apart once merged into one column.
        """
        lc50 = row.get("lc50")
        if pd.notnull(lc50):
            try:
                return float(lc50)
            except (TypeError, ValueError):
                pass  # non-numeric LC50 text: fall through to categorical
        return SpecificityIndex._categorical_potency(row)

    @staticmethod
    def _distributions(df: pd.DataFrame, key_col: str, target_col: str) -> Dict[str, Dict[str, float]]:
        """Sum '_potency' per (key, target) and normalise per key.

        Rows whose key or target is missing are ignored.  Keys whose total
        potency is not positive are omitted so look-ups fall through to the
        next back-off level.
        """
        if key_col not in df.columns or target_col not in df.columns:
            return {}
        sums: Dict[Any, Dict[str, float]] = {}
        for key, target, potency in zip(df[key_col], df[target_col], df["_potency"]):
            if pd.isna(key) or pd.isna(target):
                continue
            bucket = sums.setdefault(key, {})
            tgt = str(target)
            bucket[tgt] = bucket.get(tgt, 0.0) + float(potency)
        out: Dict[str, Dict[str, float]] = {}
        for key, bucket in sums.items():
            total = sum(bucket.values())
            if total > 0:
                out[key] = {tgt: val / total for tgt, val in bucket.items()}
        return out

    @classmethod
    def from_csv(cls, csv_path: Path) -> "SpecificityIndex":
        """Build the index from the BPPRC specificity CSV.

        Only rows with activity == 'yes' (case-insensitive) are used.
        Required columns: 'activity', 'name', 'target_order'.  Optional
        columns: 'units', 'lc50', 'percentage_mortality', 'target_species'.

        Potency per row:
        - numeric 'lc50': inverted percentile rank within its unit bucket
          (smaller LC50 => larger potency);
        - otherwise: categorical weight from 'percentage_mortality'.
        """
        df = pd.read_csv(csv_path)

        # Positive-only evidence: keep rows whose activity is 'Yes'.
        df = df[df["activity"].astype(str).str.strip().str.lower().eq("yes")]

        # Unit buckets for LC50 normalisation: ppm is treated as ug/g (diet
        # context); every other unit forms its own bucket.
        units = _optional_column(df, "units", "").astype(str).str.strip().str.lower()
        ug_per_g_mask = units.isin(["µg/g", "ug/g", "μg/g", "ppm"])
        df = df.assign(_unit_bucket=units.where(~ug_per_g_mask, other="ug_per_g"))

        # Decide numeric-vs-categorical from the LC50 field itself.
        lc50 = pd.to_numeric(_optional_column(df, "lc50", float("nan")), errors="coerce")
        has_lc50 = lc50.notna()
        potency = pd.Series(0.55, index=df.index, dtype=float)

        categorical_rows = df.loc[~has_lc50]
        if not categorical_rows.empty:
            potency.loc[categorical_rows.index] = (
                categorical_rows.apply(cls._categorical_potency, axis=1).astype(float)
            )

        if has_lc50.any():
            ranks = (
                lc50[has_lc50]
                .groupby(df.loc[has_lc50, "_unit_bucket"])
                .rank(method="average", pct=True)
            )
            # NOTE(review): the largest LC50 in a bucket (and the only LC50 of
            # a single-row bucket) gets potency 0, i.e. contributes nothing.
            # Confirm this is intended for positive-only evidence.
            potency.loc[ranks.index] = 1.0 - ranks

        df = df.assign(_potency=potency)

        # Family/subfamily keys; names can be like 'Spp1Aa1' or '5618-Cry1Ia'.
        keys = [extract_family_keys(name) for name in df["name"].astype(str)]
        df = df.assign(
            _family=[fam for fam, _ in keys],
            _subfamily=[subfam for _, subfam in keys],
        )

        all_orders = sorted({str(x) for x in df["target_order"].dropna().unique()})
        if "target_species" in df.columns:
            all_species = sorted({str(x) for x in df["target_species"].dropna().unique()})
        else:
            all_species = []

        return cls(
            name_to_orders=cls._distributions(df, "name", "target_order"),
            subfam_to_orders=cls._distributions(df, "_subfamily", "target_order"),
            fam_to_orders=cls._distributions(df, "_family", "target_order"),
            all_orders=all_orders,
            name_to_species=cls._distributions(df, "name", "target_species"),
            subfam_to_species=cls._distributions(df, "_subfamily", "target_species"),
            fam_to_species=cls._distributions(df, "_family", "target_species"),
            all_species=all_species,
        )

    def partner_needed_for_family(self, family_or_subfam: str) -> Optional[str]:
        """Return the partner prefix required by *family_or_subfam*, or None.

        Heuristic: a key starting with either member of a `partner_pairs`
        entry requires the other member.
        """
        for a, b in self.partner_pairs:
            if family_or_subfam.startswith(a):
                return b
            if family_or_subfam.startswith(b):
                return a
        return None

    @staticmethod
    def _lookup_with_backoff(
        key: str,
        by_name: Dict[str, Dict[str, float]],
        by_subfam: Dict[str, Dict[str, float]],
        by_fam: Dict[str, Dict[str, float]],
    ) -> Dict[str, float]:
        """Return the first non-empty distribution among name, subfamily, family."""
        exact = by_name.get(key)
        if exact:  # an empty exact entry must not block the back-off
            return exact
        fam, subfam = extract_family_keys(key)
        if subfam and by_subfam.get(subfam):
            return by_subfam[subfam]
        if fam and by_fam.get(fam):
            return by_fam[fam]
        return {}

    def orders_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """P(order | key): exact name, then subfamily, then family; {} if none."""
        return self._lookup_with_backoff(
            name_or_family, self.name_to_orders, self.subfam_to_orders, self.fam_to_orders
        )

    def species_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """P(species | key): exact name, then subfamily, then family; {} if none."""
        return self._lookup_with_backoff(
            name_or_family, self.name_to_species, self.subfam_to_species, self.fam_to_species
        )


# -----------------------------
# Dataclasses for outputs
# -----------------------------
_HIT_TSV_HEADER = [
    "Strain", "Protein_id", "Hit_id", "Identity", "Aln_length", "Hit_length",
    "Coverage", "HMM", "Family", "NameKey", "PartnerFulfilled", "Weight",
    "TopOrder", "TopScore",
]
_STRAIN_TSV_HEADER = ["Strain", "TopOrder", "TopScore"]
_SPECIES_TSV_HEADER = ["Strain", "TopSpecies", "TopSpeciesScore"]


@dataclass
class ToxinHit:
    """One toxin hit with its similarity weight and per-order contributions."""

    strain: str
    protein_id: str
    hit_id: str
    identity: float  # 0..1
    aln_length: int
    hit_length: int
    coverage: float  # 0..1
    hmm: bool
    family: str
    name_key: str
    partner_fulfilled: bool
    weight: float
    order_contribs: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Return the row cells: fixed columns, then one contribution per order column."""
        cols = [
            self.strain,
            self.protein_id,
            self.hit_id,
            f"{self.identity:.4f}",
            str(self.aln_length),
            str(self.hit_length),
            f"{self.coverage:.4f}",
            "YES" if self.hmm else "NO",
            self.family or "unknown",
            self.name_key or "",
            "YES" if self.partner_fulfilled else "NO",
            f"{self.weight:.4f}",
            self.top_order or "",
            f"{self.top_score:.6f}",
        ]
        cols.extend(f"{self.order_contribs.get(o, 0.0):.6f}" for o in order_columns)
        return cols

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["ToxinHit"], order_columns: List[str]) -> None:
        """Write *items* as a TSV with one column per entry of *order_columns*."""
        _write_tsv(
            tsv_path,
            _HIT_TSV_HEADER + list(order_columns),
            (h.to_tsv_row(order_columns) for h in items),
        )


@dataclass
class StrainScores:
    """Strain-level score per target order, plus the best resolved order."""

    strain: str
    order_scores: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Return [strain, top order, top score] followed by one score per order column."""
        row = [self.strain, self.top_order or "", f"{self.top_score:.6f}"]
        row.extend(f"{self.order_scores.get(o, 0.0):.6f}" for o in order_columns)
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainScores"], order_columns: List[str]) -> None:
        """Write *items* as a TSV with one column per entry of *order_columns*."""
        _write_tsv(
            tsv_path,
            _STRAIN_TSV_HEADER + list(order_columns),
            (s.to_tsv_row(order_columns) for s in items),
        )

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainScores"]) -> None:
        """Write *items* as a JSON list of objects."""
        _write_json(json_path, items)


@dataclass
class StrainSpeciesScores:
    """Strain-level score per target species, plus the best species."""

    strain: str
    species_scores: Dict[str, float]
    top_species: str
    top_species_score: float

    def to_tsv_row(self, species_columns: List[str]) -> List[str]:
        """Return [strain, top species, top score] followed by one score per species column."""
        row = [self.strain, self.top_species or "", f"{self.top_species_score:.6f}"]
        row.extend(f"{self.species_scores.get(sp, 0.0):.6f}" for sp in species_columns)
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainSpeciesScores"], species_columns: List[str]) -> None:
        """Write *items* as a TSV with one column per entry of *species_columns*."""
        _write_tsv(
            tsv_path,
            _SPECIES_TSV_HEADER + list(species_columns),
            (s.to_tsv_row(species_columns) for s in items),
        )

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainSpeciesScores"]) -> None:
        """Write *items* as a JSON list of objects."""
        _write_json(json_path, items)


# -----------------------------
# Parser and scoring
# -----------------------------
def compute_similarity_weight(identity: float, coverage: float, hmm: bool) -> float:
    """Weight formula (bounded [0,1]): base(identity, coverage) x coverage, with HMM bonus.

    - If identity >= 0.78 and coverage >= 0.8: base = 1.0
    - If 0.45 <= identity < 0.78: base = (identity-0.45)/(0.78-0.45)
    - Else: base = 0

    Final w = min(1.0, base * coverage + (0.1 if hmm else 0.0))

    NOTE(review): a hit with identity >= 0.78 but coverage < 0.8 gets base 0,
    i.e. less than a hit with identity 0.77 and the same coverage.  This
    follows the rules above literally; confirm it is intended.
    """
    if identity >= 0.78 and coverage >= 0.8:
        base = 1.0
    elif 0.45 <= identity < 0.78:
        base = max(0.0, min(1.0, (identity - 0.45) / (0.78 - 0.45)))
    else:
        base = 0.0
    w = base * max(0.0, min(1.0, coverage))
    if hmm:
        w = min(1.0, w + 0.1)
    return float(max(0.0, min(1.0, w)))


def parse_all_toxins(tsv_path: Path) -> pd.DataFrame:
    """Read BtToxin_Digger 'All_Toxins.txt' and add the derived scoring columns.

    Reads 'Identity', 'Aln_length', 'Hit_length', 'HMM' and 'Hit_id'.
    Added columns:
    - coverage:    Aln_length / Hit_length as float in [0, 1]
                   (0 when Hit_length is missing or not positive)
    - identity01:  Identity / 100, clipped to [0, 1]
    - HMM:         bool, True when the field is 'YES' (case-insensitive)
    - Hit_id_norm: normalize_hit_id(Hit_id)
    - family_key:  family from extract_family_keys, or 'unknown'
    """
    df = pd.read_csv(tsv_path, sep="\t", dtype=str, engine="python")

    for col in ("Identity", "Aln_length", "Hit_length"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df["Identity"] = df["Identity"].astype(float)
    df["Aln_length"] = df["Aln_length"].fillna(0).astype(int)
    df["Hit_length"] = df["Hit_length"].fillna(0).astype(int)

    # Stay in float dtype: non-positive hit lengths become NaN, then coverage 0.
    hit_len = df["Hit_length"].astype(float)
    aln_len = df["Aln_length"].clip(lower=0).astype(float)
    df["coverage"] = (
        (aln_len / hit_len.where(hit_len > 0)).fillna(0.0).clip(lower=0.0, upper=1.0)
    )
    df["identity01"] = (df["Identity"].fillna(0.0) / 100.0).clip(lower=0.0, upper=1.0)
    df["HMM"] = df["HMM"].astype(str).str.strip().str.upper().eq("YES")

    df["Hit_id_norm"] = df["Hit_id"].astype(str).map(normalize_hit_id)
    df["family_key"] = [
        extract_family_keys(hit)[0] or "unknown" for hit in df["Hit_id_norm"]
    ]
    return df


def partner_fulfilled_for_hit(hit_family: str, strain_df: pd.DataFrame, index: SpecificityIndex) -> bool:
    """Return True when *hit_family* needs no partner or its partner is present.

    The partner counts as present when any hit in *strain_df* has a
    family_key starting with the partner prefix and a similarity weight >= 0.3.
    """
    partner = index.partner_needed_for_family(hit_family or "")
    if not partner:
        return True
    for _, row in strain_df.iterrows():
        fam = str(row.get("family_key", "") or "")
        if not fam.startswith(partner):
            continue
        w = compute_similarity_weight(
            float(row.get("identity01", 0.0)),
            float(row.get("coverage", 0.0)),
            bool(row.get("HMM", False)),
        )
        if w >= 0.3:
            return True
    return False


def _as_number(value: Any) -> float:
    """Coerce *value* to float, treating None / non-numeric input as 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _sigmoid(x: float) -> float:
    """Logistic function that cannot overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # underflows to 0.0 instead of overflowing
    return z / (1.0 + z)


def apply_logit_prior(
    score: float,
    bgc_data: Dict[str, int],
    mobilome_data: Dict[str, int],
    crispr_state: int,
    betas: Dict[str, float],
) -> float:
    """Apply Logit Prior adjustment: S_final = sigmoid(logit(S_tox) + Delta).

    Delta = beta_z*b_z + beta_t*b_t + beta_a*b_a + beta_m*g(m) + beta_c*h(c)

    - b_z, b_t, b_a: bgc_data['ZWA'], ['Thu'], ['TAA'] (missing -> 0)
    - g(m) = ln(1 + m), m = mobilome_data['mobile_elements_count']
      (missing or negative -> 0)
    - h(c) = 1 - c/2 for crispr_state c (0 -> 1, 1 -> 0.5, 2 -> 0)
    - betas: keys 'z', 't', 'a', 'm', 'c' (missing -> 0)

    Scores <= 0 and >= 1 are returned as 0.0 / 1.0 unchanged.
    """
    if score <= 0:
        return 0.0
    if score >= 1:
        return 1.0

    epsilon = 1e-6  # keeps the logit finite near 0 and 1
    p = max(epsilon, min(1.0 - epsilon, score))
    logit_p = math.log(p / (1.0 - p))

    g_m = math.log1p(max(0.0, _as_number(mobilome_data.get("mobile_elements_count", 0))))
    h_c = 1.0 - (_as_number(crispr_state) / 2.0)

    delta = (
        _as_number(betas.get("z", 0.0)) * _as_number(bgc_data.get("ZWA", 0))
        + _as_number(betas.get("t", 0.0)) * _as_number(bgc_data.get("Thu", 0))
        + _as_number(betas.get("a", 0.0)) * _as_number(bgc_data.get("TAA", 0))
        + _as_number(betas.get("m", 0.0)) * g_m
        + _as_number(betas.get("c", 0.0)) * h_c
    )
    return _sigmoid(logit_p + delta)


def _apply_prior_in_place(
    scores: Dict[str, float],
    context_data: Optional[Dict[str, Any]],
    betas: Optional[Dict[str, float]],
) -> None:
    """Replace every positive score by its logit-prior adjusted value.

    No-op when *context_data* or *betas* is empty / None.
    """
    if not (context_data and betas):
        return
    bgc = context_data.get("bgc") or {}
    mobi = context_data.get("mobilome") or {}
    crispr_st = context_data.get("crispr_state", 0)
    for key in list(scores):
        if scores[key] > 0:
            scores[key] = apply_logit_prior(scores[key], bgc, mobi, crispr_st, betas)


def score_strain(
    strain: str,
    sdf: pd.DataFrame,
    index: SpecificityIndex,
    crispr_associations: Optional[Dict[str, Any]] = None,
    crispr_weight: float = 0.0,
    context_data: Optional[Dict[str, Any]] = None,
    betas: Optional[Dict[str, float]] = None,
) -> Tuple[List[ToxinHit], StrainScores, Optional[StrainSpeciesScores]]:
    """Score one strain from its hits.

    Args:
        strain: strain identifier copied into the outputs.
        sdf: the strain's rows as produced by `parse_all_toxins`.
        index: specificity index used to look up target distributions.
        crispr_associations: optional mapping toxin name -> association; a hit
            whose normalized id is a key gets `crispr_weight` added to its
            weight (capped at 1.0).
        crispr_weight: size of that boost; 0 disables it.
        context_data: optional dict with 'bgc', 'mobilome', 'crispr_state';
            when given together with *betas*, the logit prior is applied.
        betas: prior weights with keys 'z', 't', 'a', 'm', 'c'.

    Returns:
        (per-hit records, order scores, species scores or None when the index
        has no species).  Scores combine hits as 1 - prod(1 - w * p).
    """
    # Always include the special buckets.
    order_set = sorted({*index.all_orders, "other", "unknown"})
    species_set = sorted(index.all_species) if index.all_species else []

    # Running products of (1 - contribution) for the noisy-OR combination.
    one_minus = {o: 1.0 for o in order_set}
    sp_one_minus = {sp: 1.0 for sp in species_set}

    per_hit: List[ToxinHit] = []
    # The partner check depends only on (family, sdf): compute once per family.
    partner_cache: Dict[str, bool] = {}

    for _, row in sdf.iterrows():
        hit_id = str(row.get("Hit_id_norm", ""))
        protein_id = str(row.get("Protein_id", ""))
        identity01 = float(row.get("identity01", 0.0))
        coverage = float(row.get("coverage", 0.0))
        hmm = bool(row.get("HMM", False))
        family = str(row.get("family_key", "")) or "unknown"
        name_key = hit_id

        order_dist = index.orders_for_name_or_backoff(name_key)
        if not order_dist:
            order_dist = index.orders_for_name_or_backoff(family)
        if not order_dist:
            # parsable family without evidence -> 'other';
            # unparseable family -> 'unknown'
            order_dist = {"other": 1.0} if family != "unknown" else {"unknown": 1.0}

        w = compute_similarity_weight(identity01, coverage, hmm)
        if family not in partner_cache:
            partner_cache[family] = partner_fulfilled_for_hit(family, sdf, index)
        fulfilled = partner_cache[family]
        if not fulfilled:
            w *= 0.2  # partner penalty if required but not present

        # Hit-level CRISPR boost (association keys are toxin names).
        if crispr_associations and crispr_weight > 0 and crispr_associations.get(hit_id):
            w = min(1.0, w + crispr_weight)

        contribs: Dict[str, float] = {}
        for o in order_set:
            c = w * float(order_dist.get(o, 0.0))
            if c > 0:
                one_minus[o] *= (1.0 - c)
                contribs[o] = c

        # Species contributions: known distributions only, no special buckets.
        if species_set:
            sp_dist = (
                index.species_for_name_or_backoff(name_key)
                or index.species_for_name_or_backoff(family)
            )
            for sp, psp in sp_dist.items():
                csp = w * float(psp)
                if csp > 0 and sp in sp_one_minus:
                    sp_one_minus[sp] *= (1.0 - csp)

        if contribs:
            hit_top_order, hit_top_score = max(contribs.items(), key=lambda kv: kv[1])
        else:
            hit_top_order, hit_top_score = "", 0.0

        per_hit.append(
            ToxinHit(
                strain=strain,
                protein_id=protein_id,
                hit_id=hit_id,
                identity=identity01,
                aln_length=int(row.get("Aln_length", 0)),
                hit_length=int(row.get("Hit_length", 0)),
                coverage=coverage,
                hmm=hmm,
                family=family,
                name_key=name_key,
                partner_fulfilled=fulfilled,
                weight=w,
                order_contribs=contribs,
                top_order=hit_top_order,
                top_score=float(hit_top_score),
            )
        )

    # 1. S_tox (noisy-OR), 2. optional logit prior.
    order_scores = {o: (1.0 - one_minus[o]) for o in order_set}
    _apply_prior_in_place(order_scores, context_data, betas)

    # Choose the top order among resolved buckets when possible.
    preferred = [o for o in index.all_orders if o in order_scores]
    if not preferred:
        preferred = [o for o in order_set if o not in ("other", "unknown")]
    if preferred:
        top_o = max(preferred, key=lambda o: order_scores.get(o, 0.0))
        top_s = float(order_scores.get(top_o, 0.0))
    else:
        top_o, top_s = "", 0.0
    strain_scores = StrainScores(
        strain=strain, order_scores=order_scores, top_order=top_o, top_score=top_s
    )

    species_scores_obj: Optional[StrainSpeciesScores] = None
    if species_set:
        species_scores = {sp: (1.0 - sp_one_minus[sp]) for sp in species_set}
        # The prior is applied to species scores too, for consistency with orders.
        _apply_prior_in_place(species_scores, context_data, betas)
        top_sp = max(species_set, key=lambda sp: species_scores.get(sp, 0.0))
        species_scores_obj = StrainSpeciesScores(
            strain=strain,
            species_scores=species_scores,
            top_species=top_sp,
            top_species_score=float(species_scores.get(top_sp, 0.0)),
        )

    return per_hit, strain_scores, species_scores_obj


# -----------------------------
# Saving utilities
# -----------------------------
def save_toxin_support(tsv_path: Path, hits: List[ToxinHit], order_columns: List[str]) -> None:
    """Write the per-hit TSV (same output as `ToxinHit.save_list_tsv`)."""
    ToxinHit.save_list_tsv(tsv_path, hits, order_columns)


def save_strain_targets(tsv_path: Path, strains: List[StrainScores], order_columns: List[str]) -> None:
    """Write the per-strain TSV (same output as `StrainScores.save_list_tsv`)."""
    StrainScores.save_list_tsv(tsv_path, strains, order_columns)


def save_strain_json(json_path: Path, strains: List[StrainScores]) -> None:
    """Write the per-strain JSON (same output as `StrainScores.save_list_json`)."""
    StrainScores.save_list_json(json_path, strains)


# -----------------------------
# CLI
# -----------------------------
def _try_writable_dir(p: Path) -> bool:
    """Return True when *p* exists (or can be created) and a probe file can be written."""
    try:
        p.mkdir(parents=True, exist_ok=True)
        probe = p / ".__shotter_write_test__"
        with probe.open("w", encoding="utf-8") as f:
            f.write("ok")
        probe.unlink(missing_ok=True)
    except (OSError, ValueError):
        return False
    return True


def resolve_output_dir(preferred: Path) -> Path:
    """Return the first writable directory among the candidates.

    Order: *preferred*, then *preferred*/Shotter, then ./shotter_outputs.
    If none is writable, *preferred* is returned anyway (writing will
    likely fail) after printing a warning.
    """
    if _try_writable_dir(preferred):
        return preferred
    for fallback in (preferred / "Shotter", Path.cwd() / "shotter_outputs"):
        if _try_writable_dir(fallback):
            print(f"[Shotter] Output dir not writable: {preferred}. Falling back to: {fallback}")
            return fallback
    print(f"[Shotter] WARNING: could not find writable output directory. Using {preferred} (may fail).")
    return preferred


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser."""
    ap = argparse.ArgumentParser(description="Bttoxin_Shoter: infer target orders from BtToxin_Digger outputs")
    ap.add_argument("--toxicity_csv", type=Path, default=Path("Data/toxicity-data.csv"))
    ap.add_argument("--all_toxins", type=Path, default=Path("tests/output/Results/Toxins/All_Toxins.txt"))
    ap.add_argument("--output_dir", type=Path, default=Path("tests/output/Results/Toxins"))
    # Filtering and thresholds
    ap.add_argument("--allow_unknown_families", dest="allow_unknown_families", action="store_true")
    ap.add_argument("--disallow_unknown_families", dest="allow_unknown_families", action="store_false")
    ap.set_defaults(allow_unknown_families=True)
    ap.add_argument("--require_index_hit", action="store_true", default=False,
                    help="Keep only hits that map to a known name/subfamily/family in the specificity index")
    ap.add_argument("--min_identity", type=float, default=0.0, help="Minimum identity (0-1) to keep a hit")
    ap.add_argument("--min_coverage", type=float, default=0.0, help="Minimum coverage (0-1) to keep a hit")
    # CRISPR Integration
    ap.add_argument("--crispr_results", type=Path, default=None,
                    help="Path to CRISPR Fusion analysis results (JSON) for hit-level boost")
    ap.add_argument("--crispr_fusion", action="store_true",
                    help="Use fusion analysis results for stronger evidence")
    ap.add_argument("--crispr_weight", type=float, default=0.1,
                    help="Weight boost for CRISPR-associated toxins (0-1)")
    # Genome Context Priors (Strain-level)
    ap.add_argument("--context_bgc", type=Path, default=None, help="Path to BGC detection results (JSON)")
    ap.add_argument("--context_mobilome", type=Path, default=None, help="Path to Mobilome analysis results (JSON)")
    ap.add_argument("--context_crispr", type=Path, default=None,
                    help="Path to CRISPR detection results (JSON) for strain-level prior")
    # Prior Weights (Defaults from math doc)
    ap.add_argument("--beta_z", type=float, default=0.5, help="Weight for ZWA BGC")
    ap.add_argument("--beta_t", type=float, default=0.5, help="Weight for Thu BGC")
    ap.add_argument("--beta_a", type=float, default=0.5, help="Weight for TAA BGC")
    ap.add_argument("--beta_m", type=float, default=0.5, help="Weight for Mobilome")
    ap.add_argument("--beta_c", type=float, default=0.5, help="Weight for CRISPR state")
    return ap


def _load_optional_json(path: Optional[Path], what: str) -> Any:
    """Return the parsed JSON at *path*; None when not given, missing or unreadable.

    Loading is best-effort: a failure is reported on stdout and does not
    abort the run.
    """
    if not path or not path.exists():
        return None
    try:
        with path.open(encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:  # CLI boundary: report and continue without this input
        print(f"[Shoter] Failed to load {what}: {e}")
        return None


def _load_crispr_associations(path: Optional[Path]) -> Dict[str, Any]:
    """Read the 'associations' list of a fusion-analysis JSON into {toxin name: association}."""
    associations: Dict[str, Any] = {}
    cdata = _load_optional_json(path, "CRISPR results")
    if cdata is None:
        return associations
    entries = cdata.get("associations", []) if isinstance(cdata, dict) else []
    for assoc in entries or []:
        if not isinstance(assoc, dict):
            continue
        toxin_name = assoc.get("toxin")
        if toxin_name:
            # Names are kept as given; fusion results usually use specific names.
            associations[toxin_name] = assoc
    print(f"[Shoter] Loaded {len(associations)} CRISPR associations")
    return associations


def _load_context(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect strain-level context (BGC, mobilome, CRISPR state) from the optional JSON files."""
    context_data: Dict[str, Any] = {"bgc": {}, "mobilome": {}, "crispr_state": 0}

    bgc = _load_optional_json(args.context_bgc, "BGC context")
    if isinstance(bgc, dict):
        context_data["bgc"] = bgc
        print(f"[Shoter] Loaded BGC context: {context_data['bgc']}")
    elif bgc is not None:
        print("[Shoter] Failed to load BGC context: expected a JSON object")

    mobilome = _load_optional_json(args.context_mobilome, "Mobilome context")
    if isinstance(mobilome, dict):
        context_data["mobilome"] = mobilome
        print(f"[Shoter] Loaded Mobilome context: {context_data['mobilome']}")
    elif mobilome is not None:
        print("[Shoter] Failed to load Mobilome context: expected a JSON object")

    crispr = _load_optional_json(args.context_crispr, "CRISPR context")
    if isinstance(crispr, dict):
        context_data["crispr_state"] = crispr.get("crispr_state", 0)
        print(f"[Shoter] Loaded CRISPR state: {context_data['crispr_state']}")
    elif crispr is not None:
        print("[Shoter] Failed to load CRISPR context: expected a JSON object")

    return context_data


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Args:
        argv: argument list to parse; None (default) uses sys.argv[1:].
    """
    args = _build_arg_parser().parse_args(argv)

    crispr_associations = _load_crispr_associations(args.crispr_results)
    context_data = _load_context(args)
    betas = {
        "z": args.beta_z,
        "t": args.beta_t,
        "a": args.beta_a,
        "m": args.beta_m,
        "c": args.beta_c,
    }

    index = SpecificityIndex.from_csv(args.toxicity_csv)
    df = parse_all_toxins(args.all_toxins)

    # thresholds
    if args.min_identity > 0:
        df = df[df["identity01"].astype(float) >= float(args.min_identity)]
    if args.min_coverage > 0:
        df = df[df["coverage"].astype(float) >= float(args.min_coverage)]
    # unknown families handling
    if not args.allow_unknown_families:
        df = df[df["family_key"].astype(str) != "unknown"]
    # require index hit mapping
    if args.require_index_hit:
        keep = [
            bool(index.orders_for_name_or_backoff(name) or index.orders_for_name_or_backoff(fam))
            for name, fam in zip(df["Hit_id_norm"].astype(str), df["family_key"].astype(str))
        ]
        # An index-aligned boolean Series also works for an empty frame.
        df = df.loc[pd.Series(keep, index=df.index, dtype=bool)]

    all_hits: List[ToxinHit] = []
    all_strain_scores: List[StrainScores] = []
    all_species_scores: List[StrainSpeciesScores] = []

    if df.empty:
        # Header-only output files are still written below.
        print("[Shotter] No hits passed filters, creating empty output files")
    else:
        # NOTE(review): context_data is always a non-empty dict, so the logit
        # prior is applied even when no --context_* file was given (with the
        # defaults Delta = beta_c * 1.0).  Confirm this is intended.
        for strain, sdf in df.groupby(df["Strain"].astype(str), sort=True):
            per_hit, sscore, sspecies = score_strain(
                str(strain),
                sdf.copy(),
                index,
                crispr_associations=crispr_associations,
                crispr_weight=args.crispr_weight if args.crispr_fusion else 0.0,
                context_data=context_data,
                betas=betas,
            )
            all_hits.extend(per_hit)
            all_strain_scores.append(sscore)
            if sspecies is not None:
                all_species_scores.append(sspecies)

    # Always include the special buckets in outputs
    order_columns = sorted({*index.all_orders, "other", "unknown"})
    species_columns = sorted(index.all_species)

    # Resolve a writable output directory (avoid PermissionError on Docker-owned dirs)
    out_dir = resolve_output_dir(args.output_dir)

    ToxinHit.save_list_tsv(out_dir / "toxin_support.tsv", all_hits, order_columns)
    StrainScores.save_list_tsv(out_dir / "strain_target_scores.tsv", all_strain_scores, order_columns)
    StrainScores.save_list_json(out_dir / "strain_scores.json", all_strain_scores)
    write_species = bool(species_columns and all_species_scores)
    if write_species:
        StrainSpeciesScores.save_list_tsv(
            out_dir / "strain_target_species_scores.tsv", all_species_scores, species_columns
        )
        StrainSpeciesScores.save_list_json(out_dir / "strain_species_scores.json", all_species_scores)

    print(f"Saved: {out_dir / 'toxin_support.tsv'}")
    print(f"Saved: {out_dir / 'strain_target_scores.tsv'}")
    print(f"Saved: {out_dir / 'strain_scores.json'}")
    if write_species:
        print(f"Saved: {out_dir / 'strain_target_species_scores.tsv'}")
        print(f"Saved: {out_dir / 'strain_species_scores.json'}")


if __name__ == "__main__":
    main()