- Backend: Refactored tasks.py to directly invoke run_single_fna_pipeline.py for consistency. - Backend: Changed output format to ZIP and added auto-cleanup of intermediate files. - Backend: Fixed language parameter passing in API and tasks. - Frontend: Removed CRISPR Fusion UI elements from Submit and Monitor views. - Frontend: Implemented simulated progress bar for better UX. - Frontend: Restored One-click load button and added result file structure documentation. - Docker: Fixed critical Restarting loop by removing incorrect image directive in docker-compose.yml. - Docker: Optimized Dockerfile to correct .pixi environment path issues and prevent accidental deletion of frontend assets.
879 lines
34 KiB
Python
879 lines
34 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Bttoxin_Shoter v1
|
||
|
||
- Read BPPRC specificity CSV (positive-only) to build name/family -> target_order distributions
|
||
- Read BtToxin_Digger All_Toxins.txt to collect hits per strain
|
||
- Compute per-hit similarity weight and order contributions
|
||
- Combine contributions to strain-level potential activity scores per insect order
|
||
- Save per-hit and per-strain results (TSV + JSON) using dataclasses with save methods
|
||
|
||
Assumptions and notes are documented in docs/shotter_workflow.md
|
||
"""
|
||
from __future__ import annotations

import argparse
import json
import math
import os
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
|
||
|
||
|
||
# -----------------------------
|
||
# Helpers for parsing families
|
||
# -----------------------------
|
||
# Recognized toxin family prefixes (BPPRC nomenclature).
FAMILY_PREFIXES = (
    "Cry", "Cyt", "Vip", "Vpa", "Vpb", "Mpp", "Tpp", "Spp", "App",
    "Mcf", "Mpf", "Pra", "Prb", "Txp", "Gpp", "Mtx", "Xpp",
)
# Case-insensitive: family prefix, rank number, optional subclass letters.
FAMILY_RE = re.compile(r"(?i)(Cry|Cyt|Vip|Vpa|Vpb|Mpp|Tpp|Spp|App|Mcf|Mpf|Pra|Prb|Txp|Gpp|Mtx|Xpp)(\d+)([A-Za-z]*)")


def normalize_hit_id(hit_id: str) -> str:
    """Normalize a hit identifier down to its core token.

    Drops a trailing '-other' marker and removes all whitespace.
    Non-string input yields the empty string.

    Examples: 'Spp1Aa1' -> 'Spp1Aa1'; 'Bmp1-other' -> 'Bmp1'
    """
    if not isinstance(hit_id, str):
        return ""
    cleaned = re.sub(r"-other$", "", hit_id.strip())
    return re.sub(r"\s+", "", cleaned)


def extract_family_keys(name: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract (family_key, subfamily_key) from a toxin protein name.

    For a name like 'Cry1Ac1' or '5618-Cry1Ia' this yields:
      - family_key:    e.g. 'Cry1'
      - subfamily_key: e.g. 'Cry1A' (family + first subclass letter, if any)

    Returns (None, None) when no known family prefix is found.
    """
    if not isinstance(name, str) or not name:
        return None, None
    match = FAMILY_RE.search(name)
    if match is None:
        return None, None
    prefix, number, subclass = match.group(1), match.group(2), match.group(3)
    family_key = prefix + number
    subfamily_key = (family_key + subclass[0]) if subclass else None
    return family_key, subfamily_key
|
||
|
||
|
||
# ---------------------------------
|
||
# Specificity index from CSV (BPPRC)
|
||
# ---------------------------------
|
||
@dataclass
class SpecificityIndex:
    """Lookup index of toxin -> target distributions built from BPPRC data.

    Holds positive-evidence-only distributions P(target_order | key) and
    P(target_species | key) at three specificity levels (exact protein name,
    subfamily, family) so lookups can back off from specific to general.
    Built via :meth:`from_csv`.
    """

    # Distributions P(order | name) and P(order | family/subfamily)
    name_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Set of all observed target orders
    all_orders: List[str] = field(default_factory=list)
    # Distributions P(species | name) and P(species | family/subfamily)
    name_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Set of all observed target species
    all_species: List[str] = field(default_factory=list)

    # Partner requirement heuristics at family level: binary toxins need both
    # components present to be active (see partner_needed_for_family).
    partner_pairs: List[Tuple[str, str]] = field(default_factory=lambda: [
        ("Vip1", "Vip2"),
        ("Vpa", "Vpb"),
        ("BinA", "BinB"),  # included for completeness if ever present
    ])

    @staticmethod
    def _potency_from_row(row: pd.Series) -> Optional[float]:
        """Map quantitative fields to a potency weight in [0,1]. Positive-only.

        Rules:
        - lc50: will be normalized later within unit-buckets; here return the numeric.
        - percentage_mortality: map ranges to coarse weights.
        - if both missing but activity is positive: default 0.55
        """
        perc = str(row.get("percentage_mortality") or "").strip().lower()
        lc50 = row.get("lc50")
        if pd.notnull(lc50):
            try:
                # Raw LC50 value; inverted-quantile normalization happens in from_csv.
                return float(lc50)
            except Exception:
                # Non-numeric lc50 cell (free text) — fall through to mortality mapping.
                pass
        if perc:
            # coarse mappings from mortality range text to a 0-1 weight
            if ">80%" in perc or "80-100%" in perc or "greater than 80%" in perc:
                return 0.9
            if "60-100%" in perc or "60-80%" in perc or "50-80%" in perc:
                return 0.65
            if "0-60%" in perc or "0-50%" in perc or "25%" in perc or "some" in perc or "stunting" in perc:
                return 0.25
        # default positive evidence
        return 0.55

    @classmethod
    def from_csv(cls, csv_path: Path) -> "SpecificityIndex":
        """Build the index from a BPPRC specificity CSV.

        Keeps only rows with activity == 'Yes', converts each row to a potency
        weight, then aggregates potency-weighted target distributions per
        protein name, subfamily, and family.

        NOTE(review): `.apply(..., include_groups=False)` requires pandas >= 2.2
        — confirm against the pinned environment.
        """
        df = pd.read_csv(csv_path)
        # Positive-only evidence
        # many rows have activity 'Yes' and non_toxic empty; use only activity==Yes
        df = df[df["activity"].astype(str).str.lower().eq("yes")]  # keep positives

        # Normalize units bucket for lc50 normalization
        # unify ppm -> ug_per_g (diet context). Others stay in own unit bucket.
        units = df["units"].astype(str).str.strip().str.lower()
        ug_per_g_mask = units.isin(["µg/g", "ug/g", "μg/g", "ppm"])  # ppm ~ ug/g in diet
        unit_bucket = units.where(~ug_per_g_mask, other="ug_per_g")
        df = df.assign(_unit_bucket=unit_bucket)

        # Compute potency: two cases: lc50 (to be inverted/normalized) vs categorical
        df["_potency_raw"] = df.apply(cls._potency_from_row, axis=1)

        # For rows with numeric raw potency (lc50), do within-bucket inverted quantile
        for bucket, sub in df.groupby("_unit_bucket"):
            # separate numeric vs categorical
            is_num = pd.to_numeric(sub["_potency_raw"], errors="coerce").notnull()
            num_idx = sub.index[is_num]
            cat_idx = sub.index[~is_num]
            # smaller LC50 => stronger potency
            if len(num_idx) > 0:
                vals = sub.loc[num_idx, "_potency_raw"].astype(float)
                ranks = vals.rank(method="average", pct=True)
                inv = 1.0 - ranks  # 0..1
                df.loc[num_idx, "_potency"] = inv.values
            # categorical rows keep their 0-1 scores (already in _potency_raw)
            if len(cat_idx) > 0:
                df.loc[cat_idx, "_potency"] = sub.loc[cat_idx, "_potency_raw"].values
        # Coerce and backfill anything still missing with the default positive weight.
        df["_potency"] = pd.to_numeric(df["_potency"], errors="coerce").fillna(0.55)

        # Extract family/subfamily keys from name
        # Name fields can be like 'Spp1Aa1', 'Cry1Ac1', or '5618-Cry1Ia', etc.
        fam_keys: List[Optional[str]] = []
        subfam_keys: List[Optional[str]] = []
        for name in df["name"].astype(str).tolist():
            fam, subfam = extract_family_keys(name)
            fam_keys.append(fam)
            subfam_keys.append(subfam)
        df = df.assign(_family=fam_keys, _subfamily=subfam_keys)

        # Aggregate per protein name (exact)
        def agg_distribution(group: pd.DataFrame) -> Dict[str, float]:
            # Sum potencies per target_order, then normalize to a distribution
            s = group.groupby("target_order")["_potency"].sum()
            if s.sum() <= 0:
                return {}
            d = (s / s.sum()).to_dict()
            return d

        name_to_orders = (
            df.groupby("name", dropna=False, group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )

        # Aggregate per subfamily and family as back-off
        subfam_to_orders = (
            df.dropna(subset=["_subfamily"]).groupby("_subfamily", group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )
        fam_to_orders = (
            df.dropna(subset=["_family"]).groupby("_family", group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )

        # Species distributions (optional if target_species present)
        def agg_distribution_species(group: pd.DataFrame) -> Dict[str, float]:
            # Same normalization as agg_distribution, keyed by target_species.
            s = group.groupby("target_species")["_potency"].sum()
            if s.sum() <= 0:
                return {}
            d = (s / s.sum()).to_dict()
            return d

        name_to_species = (
            df.groupby("name", dropna=False, group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )
        subfam_to_species = (
            df.dropna(subset=["_subfamily"]).groupby("_subfamily", group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )
        fam_to_species = (
            df.dropna(subset=["_family"]).groupby("_family", group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )

        # collect all orders/species observed
        all_orders = sorted({str(x) for x in df["target_order"].dropna().unique().tolist()})
        all_species = sorted({str(x) for x in df["target_species"].dropna().unique().tolist()})

        return cls(
            name_to_orders=name_to_orders,
            subfam_to_orders=subfam_to_orders,
            fam_to_orders=fam_to_orders,
            all_orders=all_orders,
            name_to_species=name_to_species,
            subfam_to_species=subfam_to_species,
            fam_to_species=fam_to_species,
            all_species=all_species,
        )

    def partner_needed_for_family(self, family_or_subfam: str) -> Optional[str]:
        """Return the partner family required by *family_or_subfam*, or None.

        Heuristic partner rule at family level: each pair in partner_pairs is
        symmetric, so either member requires the other.
        """
        for a, b in self.partner_pairs:
            if family_or_subfam.startswith(a):
                return b
            if family_or_subfam.startswith(b):
                return a
        return None

    def orders_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """P(order | key), backing off: exact name -> subfamily -> family -> {}."""
        if name_or_family in self.name_to_orders:
            return self.name_to_orders[name_or_family]
        fam, subfam = extract_family_keys(name_or_family)
        if subfam and subfam in self.subfam_to_orders:
            return self.subfam_to_orders[subfam]
        if fam and fam in self.fam_to_orders:
            return self.fam_to_orders[fam]
        return {}

    def species_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """P(species | key), backing off: exact name -> subfamily -> family -> {}."""
        if name_or_family in self.name_to_species:
            return self.name_to_species[name_or_family]
        fam, subfam = extract_family_keys(name_or_family)
        if subfam and subfam in self.subfam_to_species:
            return self.subfam_to_species[subfam]
        if fam and fam in self.fam_to_species:
            return self.fam_to_species[fam]
        return {}
|
||
|
||
|
||
# -----------------------------
|
||
# Dataclasses for outputs
|
||
# -----------------------------
|
||
@dataclass
class ToxinHit:
    """One toxin hit within a strain, plus its per-order contribution scores."""

    strain: str
    protein_id: str
    hit_id: str
    identity: float  # fraction in [0, 1]
    aln_length: int
    hit_length: int
    coverage: float  # fraction in [0, 1]
    hmm: bool
    family: str
    name_key: str
    partner_fulfilled: bool
    weight: float
    order_contribs: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Serialize this hit as one TSV row; per-order scores appended last."""
        row = [
            self.strain,
            self.protein_id,
            self.hit_id,
            f"{self.identity:.4f}",
            str(self.aln_length),
            str(self.hit_length),
            f"{self.coverage:.4f}",
            "YES" if self.hmm else "NO",
            self.family or "unknown",
            self.name_key or "",
            "YES" if self.partner_fulfilled else "NO",
            f"{self.weight:.4f}",
            self.top_order or "",
            f"{self.top_score:.6f}",
        ]
        row.extend(f"{self.order_contribs.get(col, 0.0):.6f}" for col in order_columns)
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["ToxinHit"], order_columns: List[str]) -> None:
        """Write *items* to *tsv_path* (header + one row per hit); creates parent dirs."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        header = [
            "Strain", "Protein_id", "Hit_id", "Identity", "Aln_length", "Hit_length", "Coverage", "HMM",
            "Family", "NameKey", "PartnerFulfilled", "Weight", "TopOrder", "TopScore",
        ] + order_columns
        lines = ["\t".join(header)]
        lines.extend("\t".join(hit.to_tsv_row(order_columns)) for hit in items)
        with tsv_path.open("w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
|
||
|
||
|
||
@dataclass
class StrainScores:
    """Per-strain potential activity scores across insect orders."""

    strain: str
    order_scores: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Serialize as one TSV row: strain, top order/score, then each order score."""
        fixed = [self.strain, self.top_order or "", f"{self.top_score:.6f}"]
        per_order = [f"{self.order_scores.get(col, 0.0):.6f}" for col in order_columns]
        return fixed + per_order

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainScores"], order_columns: List[str]) -> None:
        """Write *items* to *tsv_path* (header + one row per strain); creates parent dirs."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        header = ["Strain", "TopOrder", "TopScore"] + order_columns
        lines = ["\t".join(header)]
        lines.extend("\t".join(item.to_tsv_row(order_columns)) for item in items)
        with tsv_path.open("w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainScores"]) -> None:
        """Dump *items* as a JSON array of dicts; creates parent dirs."""
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = [asdict(item) for item in items]
        with json_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
|
||
|
||
|
||
@dataclass
class StrainSpeciesScores:
    """Per-strain potential activity scores across individual target species."""

    strain: str
    species_scores: Dict[str, float]
    top_species: str
    top_species_score: float

    def to_tsv_row(self, species_columns: List[str]) -> List[str]:
        """Serialize as one TSV row: strain, top species/score, then each species score."""
        fixed = [self.strain, self.top_species or "", f"{self.top_species_score:.6f}"]
        per_species = [f"{self.species_scores.get(col, 0.0):.6f}" for col in species_columns]
        return fixed + per_species

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainSpeciesScores"], species_columns: List[str]) -> None:
        """Write *items* to *tsv_path* (header + one row per strain); creates parent dirs."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        header = ["Strain", "TopSpecies", "TopSpeciesScore"] + species_columns
        lines = ["\t".join(header)]
        lines.extend("\t".join(item.to_tsv_row(species_columns)) for item in items)
        with tsv_path.open("w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainSpeciesScores"]) -> None:
        """Dump *items* as a JSON array of dicts; creates parent dirs."""
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = [asdict(item) for item in items]
        with json_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
|
||
|
||
|
||
# -----------------------------
|
||
# Parser and scoring
|
||
# -----------------------------
|
||
|
||
def compute_similarity_weight(identity: float, coverage: float, hmm: bool) -> float:
    """Similarity weight for a hit, bounded to [0, 1].

    base(identity, coverage) is scaled by coverage, with an HMM bonus of 0.1,
    and everything clamped to [0, 1]:
      - identity >= 0.78 and coverage >= 0.8  -> base = 1.0
      - 0.45 <= identity < 0.78               -> base ramps linearly from 0 to 1
      - otherwise                             -> base = 0
    Final w = min(1.0, base * coverage + (0.1 if hmm else 0.0))
    """
    if identity >= 0.78 and coverage >= 0.8:
        base = 1.0
    elif 0.45 <= identity < 0.78:
        base = min(1.0, max(0.0, (identity - 0.45) / (0.78 - 0.45)))
    else:
        base = 0.0
    weight = base * min(1.0, max(0.0, coverage))
    if hmm:
        weight = min(1.0, weight + 0.1)
    return float(min(1.0, max(0.0, weight)))
|
||
|
||
|
||
def parse_all_toxins(tsv_path: Path) -> pd.DataFrame:
    """Load BtToxin_Digger All_Toxins.txt and derive normalized columns.

    Adds to the raw table:
      - coverage: Aln_length / Hit_length (0.0 when Hit_length is 0/missing)
      - identity01: Identity percentage rescaled to [0, 1]
      - HMM: boolean, True when the column reads 'YES' (case-insensitive)
      - Hit_id_norm: hit id with '-other' suffix and whitespace stripped
      - family_key: parsed family (e.g. 'Cry1') or 'unknown'

    NOTE(review): assumes columns Identity, Aln_length, Hit_length, HMM and
    Hit_id exist in the file — confirm against Digger's output format.
    """
    # Read everything as strings first; numeric coercion below is explicit.
    df = pd.read_csv(tsv_path, sep="\t", dtype=str, engine="python")
    # Coerce needed fields (non-numeric cells become NaN)
    for col in ["Identity", "Aln_length", "Hit_length"]:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df["Identity"] = df["Identity"].astype(float)
    df["Aln_length"] = df["Aln_length"].fillna(0).astype(int)
    df["Hit_length"] = df["Hit_length"].fillna(0).astype(int)
    # Replace zero lengths with NA so the division yields NA instead of inf,
    # then fill those with 0.0 coverage.
    df["coverage"] = (df["Aln_length"].clip(lower=0) / df["Hit_length"].replace(0, pd.NA)).fillna(0.0)
    df["identity01"] = (df["Identity"].fillna(0.0) / 100.0).clip(lower=0.0, upper=1.0)
    df["HMM"] = df["HMM"].astype(str).str.upper().eq("YES")

    # Normalize hit_id and family
    df["Hit_id_norm"] = df["Hit_id"].astype(str).map(normalize_hit_id)

    fams = []
    for hit in df["Hit_id_norm"].tolist():
        fam, _ = extract_family_keys(hit)
        fams.append(fam or "unknown")
    df["family_key"] = fams

    return df
|
||
|
||
|
||
def partner_fulfilled_for_hit(hit_family: str, strain_df: pd.DataFrame, index: SpecificityIndex) -> bool:
    """Check whether a binary-toxin partner requirement is satisfied.

    If *hit_family* requires a partner family (per the index heuristics), the
    strain must contain at least one hit from that partner family whose
    similarity weight is >= 0.3. Families without a partner rule always pass.
    """
    partner = index.partner_needed_for_family(hit_family or "")
    if not partner:
        return True
    # Partner satisfied when any hit in the same strain starts with the
    # partner family and carries a similarity weight of at least 0.3.
    for _, candidate in strain_df.iterrows():
        candidate_family = candidate.get("family_key", "") or ""
        if not candidate_family.startswith(partner):
            continue
        weight = compute_similarity_weight(
            float(candidate.get("identity01", 0.0)),
            float(candidate.get("coverage", 0.0)),
            bool(candidate.get("HMM", False)),
        )
        if weight >= 0.3:
            return True
    return False
|
||
|
||
|
||
def apply_logit_prior(
    score: float,
    bgc_data: Dict[str, int],
    mobilome_data: Dict[str, int],
    crispr_state: int,
    betas: Dict[str, float]
) -> float:
    """Adjust a toxicity score with genome-context priors in logit space.

    S_final = sigmoid(logit(S_tox) + Delta), where
    Delta = beta_z*b_z + beta_t*b_t + beta_a*b_a + beta_m*ln(1+m) + beta_c*(1 - c/2).

    Scores at or beyond the [0, 1] bounds are returned clamped unchanged.
    """
    if score <= 0:
        return 0.0
    if score >= 1:
        return 1.0

    # Clamp away from the endpoints so the logit stays finite.
    eps = 1e-6
    p = min(1.0 - eps, max(eps, score))
    logit_p = math.log(p / (1.0 - p))

    # BGC indicator counts for the three cluster classes.
    b_z = bgc_data.get("ZWA", 0)
    b_t = bgc_data.get("Thu", 0)
    b_a = bgc_data.get("TAA", 0)

    # Mobilome: g(m) = ln(1 + m), a diminishing-returns transform of the count.
    g_m = math.log(1.0 + mobilome_data.get("mobile_elements_count", 0))

    # CRISPR state c in {0, 1, 2}: h(c) = 1 - c/2 maps 0->1, 1->0.5, 2->0.
    h_c = 1.0 - (crispr_state / 2.0)

    delta = (
        betas["z"] * b_z
        + betas["t"] * b_t
        + betas["a"] * b_a
        + betas["m"] * g_m
        + betas["c"] * h_c
    )
    return 1.0 / (1.0 + math.exp(-(logit_p + delta)))
|
||
|
||
|
||
def score_strain(
    strain: str,
    sdf: pd.DataFrame,
    index: SpecificityIndex,
    crispr_associations: Optional[Dict[str, Any]] = None,
    crispr_weight: float = 0.0,
    context_data: Optional[Dict[str, Any]] = None,
    betas: Optional[Dict[str, float]] = None
) -> Tuple[List[ToxinHit], StrainScores, Optional[StrainSpeciesScores]]:
    """Score one strain's potential activity per insect order (and species).

    For each toxin hit in *sdf* (this strain's rows from parse_all_toxins):
    compute a similarity weight, look up the target-order distribution with
    name -> subfamily -> family back-off, and combine the weighted
    contributions across hits via a noisy-OR (1 - prod(1 - w*p)). When
    *context_data* and *betas* are given, each combined score is further
    adjusted by apply_logit_prior.

    Returns (per-hit details, strain-level order scores, optional strain-level
    species scores — None when the index has no species data).
    """
    # Include special buckets per requirements
    order_set = sorted({*index.all_orders, "other", "unknown"})
    per_hit: List[ToxinHit] = []

    # Collect contributions per order by combining across hits
    # We'll accumulate 1 - Π(1 - contrib)
    one_minus = {o: 1.0 for o in order_set}
    # Species accumulation if available
    species_set = sorted(index.all_species) if index.all_species else []
    sp_one_minus = {sp: 1.0 for sp in species_set}

    for _, row in sdf.iterrows():
        hit_id = str(row.get("Hit_id_norm", ""))
        protein_id = str(row.get("Protein_id", ""))
        identity01 = float(row.get("identity01", 0.0))
        coverage = float(row.get("coverage", 0.0))
        hmm = bool(row.get("HMM", False))
        family = str(row.get("family_key", "")) or "unknown"

        # Choose name backoff distribution
        name_key = hit_id
        order_dist = index.orders_for_name_or_backoff(name_key)
        if not order_dist:
            # backoff to family
            order_dist = index.orders_for_name_or_backoff(family)
        if not order_dist:
            # Route by whether we have a parsable family
            # - parsable family but no evidence -> 'other'
            # - unparseable family (family == 'unknown') -> 'unknown'
            if family != "unknown":
                order_dist = {"other": 1.0}
            else:
                order_dist = {"unknown": 1.0}

        # Compute weight with partner handling
        w = compute_similarity_weight(identity01, coverage, hmm)
        fulfilled = partner_fulfilled_for_hit(family, sdf, index)
        if not fulfilled:
            w *= 0.2  # partner penalty if required but not present

        # CRISPR Boost (Hit-level)
        if crispr_associations and crispr_weight > 0:
            # Check if this toxin is associated with CRISPR
            # keys in crispr_associations are toxin names
            # NOTE: hit_id and name_key are the same value here, so the second
            # lookup is redundant but harmless.
            assoc = crispr_associations.get(hit_id) or crispr_associations.get(name_key)
            if assoc:
                w = min(1.0, w + crispr_weight)

        # Per-order contribution of this hit: weight times order probability.
        contribs: Dict[str, float] = {}
        for o in order_set:
            p = float(order_dist.get(o, 0.0))
            c = w * p
            if c > 0:
                one_minus[o] *= (1.0 - c)
                contribs[o] = c

        # species contributions (only for known distributions; no other/unknown buckets)
        if species_set:
            sp_dist = index.species_for_name_or_backoff(name_key)
            if not sp_dist:
                sp_dist = index.species_for_name_or_backoff(family)
            if sp_dist:
                for sp in species_set:
                    psp = float(sp_dist.get(sp, 0.0))
                    csp = w * psp
                    if csp > 0:
                        sp_one_minus[sp] *= (1.0 - csp)

        # top for this hit (allow other/unknown if that's all we have)
        if contribs:
            hit_top_order, hit_top_score = max(contribs.items(), key=lambda kv: kv[1])
        else:
            hit_top_order, hit_top_score = "", 0.0

        per_hit.append(
            ToxinHit(
                strain=strain,
                protein_id=protein_id,
                hit_id=hit_id,
                identity=identity01,
                aln_length=int(row.get("Aln_length", 0)),
                hit_length=int(row.get("Hit_length", 0)),
                coverage=coverage,
                hmm=hmm,
                family=family,
                name_key=name_key,
                partner_fulfilled=fulfilled,
                weight=w,
                order_contribs=contribs,
                top_order=hit_top_order,
                top_score=float(hit_top_score),
            )
        )

    # 1. Calculate S_tox (Noisy-OR)
    order_scores = {o: (1.0 - one_minus[o]) for o in order_set}

    # 2. Apply Logit Prior if context data is present
    if context_data and betas:
        bgc = context_data.get("bgc", {})
        mobi = context_data.get("mobilome", {})
        crispr_st = context_data.get("crispr_state", 0)

        for o in order_scores:
            s_tox = order_scores[o]
            if s_tox > 0:
                s_final = apply_logit_prior(s_tox, bgc, mobi, crispr_st, betas)
                order_scores[o] = s_final

    # choose top order excluding unresolved buckets if possible
    preferred = [o for o in index.all_orders if o in order_scores]
    if not preferred:
        preferred = [o for o in order_set if o not in ("other", "unknown")]
    if preferred:
        top_o = max(preferred, key=lambda o: order_scores.get(o, 0.0))
        top_s = float(order_scores.get(top_o, 0.0))
    else:
        top_o, top_s = "", 0.0
    strain_scores = StrainScores(strain=strain, order_scores=order_scores, top_order=top_o, top_score=top_s)

    species_scores_obj: Optional[StrainSpeciesScores] = None
    if species_set:
        species_scores = {sp: (1.0 - sp_one_minus[sp]) for sp in species_set}

        # Apply Logit Prior to Species scores too?
        # The math doc focuses on "Strain x Order", but presumably it applies to species too if we follow the logic.
        # "combine evidence... then prior". Let's apply it for consistency.
        if context_data and betas:
            bgc = context_data.get("bgc", {})
            mobi = context_data.get("mobilome", {})
            crispr_st = context_data.get("crispr_state", 0)
            for sp in species_scores:
                s_tox = species_scores[sp]
                if s_tox > 0:
                    s_final = apply_logit_prior(s_tox, bgc, mobi, crispr_st, betas)
                    species_scores[sp] = s_final

        if species_scores:
            top_sp = max(species_set, key=lambda sp: species_scores.get(sp, 0.0))
            top_sp_score = float(species_scores.get(top_sp, 0.0))
        else:
            top_sp, top_sp_score = "", 0.0
        species_scores_obj = StrainSpeciesScores(
            strain=strain,
            species_scores=species_scores,
            top_species=top_sp,
            top_species_score=top_sp_score,
        )

    return per_hit, strain_scores, species_scores_obj
|
||
|
||
|
||
# -----------------------------
|
||
# Saving utilities
|
||
# -----------------------------
|
||
|
||
def save_toxin_support(tsv_path: Path, hits: List[ToxinHit], order_columns: List[str]) -> None:
    """Write per-hit toxin support rows to *tsv_path*; creates parent dirs.

    Header is the fixed hit columns followed by one column per order.
    """
    tsv_path.parent.mkdir(parents=True, exist_ok=True)
    header = [
        "Strain", "Protein_id", "Hit_id", "Identity", "Aln_length", "Hit_length", "Coverage", "HMM",
        "Family", "NameKey", "PartnerFulfilled", "Weight", "TopOrder", "TopScore",
    ] + order_columns
    rows = [header]
    rows.extend(hit.to_tsv_row(order_columns) for hit in hits)
    with tsv_path.open("w", encoding="utf-8") as f:
        f.writelines("\t".join(row) + "\n" for row in rows)
|
||
|
||
|
||
def save_strain_targets(tsv_path: Path, strains: List[StrainScores], order_columns: List[str]) -> None:
    """Write strain-level order scores to *tsv_path*; creates parent dirs."""
    tsv_path.parent.mkdir(parents=True, exist_ok=True)
    header = ["Strain", "TopOrder", "TopScore"] + order_columns
    rows = [header]
    rows.extend(entry.to_tsv_row(order_columns) for entry in strains)
    with tsv_path.open("w", encoding="utf-8") as f:
        f.writelines("\t".join(row) + "\n" for row in rows)
|
||
|
||
|
||
def save_strain_json(json_path: Path, strains: List[StrainScores]) -> None:
    """Write strain scores to *json_path* as a JSON array; creates parent dirs."""
    json_path.parent.mkdir(parents=True, exist_ok=True)
    payload = [asdict(entry) for entry in strains]
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
|
||
|
||
|
||
# -----------------------------
|
||
# CLI
|
||
# -----------------------------
|
||
|
||
def _try_writable_dir(p: Path) -> bool:
|
||
try:
|
||
p.mkdir(parents=True, exist_ok=True)
|
||
except Exception:
|
||
return False
|
||
try:
|
||
probe = p / ".__shotter_write_test__"
|
||
with probe.open("w", encoding="utf-8") as f:
|
||
f.write("ok")
|
||
probe.unlink(missing_ok=True)
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
def resolve_output_dir(preferred: Path) -> Path:
    """Pick a writable output directory, preferring *preferred*.

    Fallback chain: preferred -> preferred/'Shotter' -> cwd/'shotter_outputs'.
    If none is writable, return *preferred* anyway and warn (a later write
    will most likely fail).
    """
    if _try_writable_dir(preferred):
        return preferred
    # Try the fallbacks in order, logging whichever one we settle on.
    for candidate in (preferred / "Shotter", Path.cwd() / "shotter_outputs"):
        if _try_writable_dir(candidate):
            print(f"[Shotter] Output dir not writable: {preferred}. Falling back to: {candidate}")
            return candidate
    print(f"[Shotter] WARNING: could not find writable output directory. Using {preferred} (may fail).")
    return preferred
|
||
|
||
def main():
    """CLI entry point: score strains' insecticidal potential.

    Pipeline: load the BPPRC specificity index, parse BtToxin_Digger's
    All_Toxins table, apply the requested filters, score each strain
    (optionally boosted by CRISPR associations and genome-context priors),
    then write TSV/JSON outputs to a writable directory.
    """
    ap = argparse.ArgumentParser(description="Bttoxin_Shoter: infer target orders from BtToxin_Digger outputs")
    ap.add_argument("--toxicity_csv", type=Path, default=Path("Data/toxicity-data.csv"))
    ap.add_argument("--all_toxins", type=Path, default=Path("tests/output/Results/Toxins/All_Toxins.txt"))
    ap.add_argument("--output_dir", type=Path, default=Path("tests/output/Results/Toxins"))
    # Filtering and thresholds
    ap.add_argument("--allow_unknown_families", dest="allow_unknown_families", action="store_true")
    ap.add_argument("--disallow_unknown_families", dest="allow_unknown_families", action="store_false")
    ap.set_defaults(allow_unknown_families=True)
    ap.add_argument("--require_index_hit", action="store_true", default=False,
                    help="Keep only hits that map to a known name/subfamily/family in the specificity index")
    ap.add_argument("--min_identity", type=float, default=0.0, help="Minimum identity (0-1) to keep a hit")
    ap.add_argument("--min_coverage", type=float, default=0.0, help="Minimum coverage (0-1) to keep a hit")

    # CRISPR Integration
    ap.add_argument("--crispr_results", type=Path, default=None, help="Path to CRISPR Fusion analysis results (JSON) for hit-level boost")
    ap.add_argument("--crispr_fusion", action="store_true", help="Use fusion analysis results for stronger evidence")
    ap.add_argument("--crispr_weight", type=float, default=0.1, help="Weight boost for CRISPR-associated toxins (0-1)")

    # Genome Context Priors (Strain-level)
    ap.add_argument("--context_bgc", type=Path, default=None, help="Path to BGC detection results (JSON)")
    ap.add_argument("--context_mobilome", type=Path, default=None, help="Path to Mobilome analysis results (JSON)")
    ap.add_argument("--context_crispr", type=Path, default=None, help="Path to CRISPR detection results (JSON) for strain-level prior")

    # Prior Weights (Defaults from math doc)
    ap.add_argument("--beta_z", type=float, default=0.5, help="Weight for ZWA BGC")
    ap.add_argument("--beta_t", type=float, default=0.5, help="Weight for Thu BGC")
    ap.add_argument("--beta_a", type=float, default=0.5, help="Weight for TAA BGC")
    ap.add_argument("--beta_m", type=float, default=0.5, help="Weight for Mobilome")
    ap.add_argument("--beta_c", type=float, default=0.5, help="Weight for CRISPR state")

    args = ap.parse_args()

    # Load CRISPR data if available (best-effort: a bad file only logs).
    crispr_associations = {}
    if args.crispr_results and args.crispr_results.exists():
        try:
            with open(args.crispr_results) as f:
                cdata = json.load(f)
            # If fusion analysis results (has 'associations')
            if "associations" in cdata:
                for assoc in cdata["associations"]:
                    toxin_name = assoc.get("toxin")
                    if toxin_name:
                        # Normalize name if possible or keep as is.
                        # Digger outputs might have variants, but fusion usually uses specific names.
                        crispr_associations[toxin_name] = assoc
            print(f"[Shoter] Loaded {len(crispr_associations)} CRISPR associations")
        except Exception as e:
            print(f"[Shoter] Failed to load CRISPR results: {e}")

    # Load Prior Data (strain-level genome context; empty defaults are inert)
    context_data = {
        "bgc": {},
        "mobilome": {},
        "crispr_state": 0
    }

    # BGC
    if args.context_bgc and args.context_bgc.exists():
        try:
            with open(args.context_bgc) as f:
                context_data["bgc"] = json.load(f)
            print(f"[Shoter] Loaded BGC context: {context_data['bgc']}")
        except Exception as e:
            print(f"[Shoter] Failed to load BGC context: {e}")

    # Mobilome
    if args.context_mobilome and args.context_mobilome.exists():
        try:
            with open(args.context_mobilome) as f:
                context_data["mobilome"] = json.load(f)
            print(f"[Shoter] Loaded Mobilome context: {context_data['mobilome']}")
        except Exception as e:
            print(f"[Shoter] Failed to load Mobilome context: {e}")

    # CRISPR State
    if args.context_crispr and args.context_crispr.exists():
        try:
            with open(args.context_crispr) as f:
                cdata = json.load(f)
            context_data["crispr_state"] = cdata.get("crispr_state", 0)
            print(f"[Shoter] Loaded CRISPR state: {context_data['crispr_state']}")
        except Exception as e:
            print(f"[Shoter] Failed to load CRISPR context: {e}")

    betas = {
        "z": args.beta_z,
        "t": args.beta_t,
        "a": args.beta_a,
        "m": args.beta_m,
        "c": args.beta_c
    }

    index = SpecificityIndex.from_csv(args.toxicity_csv)

    df = parse_all_toxins(args.all_toxins)
    # thresholds
    if args.min_identity > 0:
        df = df[df["identity01"].astype(float) >= float(args.min_identity)]
    if args.min_coverage > 0:
        df = df[df["coverage"].astype(float) >= float(args.min_coverage)]
    # unknown families handling
    if not args.allow_unknown_families:
        df = df[df["family_key"].astype(str) != "unknown"]
    # require index hit mapping
    if args.require_index_hit:
        def _has_index_orders(row) -> bool:
            # True when the hit resolves to any order distribution in the index.
            name_key = str(row.get("Hit_id_norm", ""))
            fam = str(row.get("family_key", ""))
            d = index.orders_for_name_or_backoff(name_key)
            if not d:
                d = index.orders_for_name_or_backoff(fam)
            return bool(d)
        df = df[df.apply(_has_index_orders, axis=1)]

    # Handle empty DataFrame - preserve columns and create empty outputs
    if df.shape[0] == 0:
        print("[Shotter] No hits passed filters, creating empty output files")
        strains: List[str] = []
    else:
        strains = sorted(df["Strain"].astype(str).unique().tolist())

    all_hits: List[ToxinHit] = []
    all_strain_scores: List[StrainScores] = []
    all_species_scores: List[StrainSpeciesScores] = []

    for strain in strains:
        sdf = df[df["Strain"].astype(str).eq(strain)].copy()
        per_hit, sscore, sspecies = score_strain(
            strain, sdf, index,
            crispr_associations=crispr_associations,
            crispr_weight=args.crispr_weight if args.crispr_fusion else 0.0,
            context_data=context_data,
            betas=betas
        )
        all_hits.extend(per_hit)
        all_strain_scores.append(sscore)
        if sspecies is not None:
            all_species_scores.append(sspecies)

    # Always include the special buckets in outputs
    order_columns = sorted({*index.all_orders, "other", "unknown"}) or ["unknown"]
    species_columns = sorted(index.all_species)

    # Resolve a writable output directory (avoid PermissionError on Docker-owned dirs)
    out_dir = resolve_output_dir(args.output_dir)

    # Save via dataclass methods
    ToxinHit.save_list_tsv(out_dir / "toxin_support.tsv", all_hits, order_columns)
    StrainScores.save_list_tsv(out_dir / "strain_target_scores.tsv", all_strain_scores, order_columns)
    StrainScores.save_list_json(out_dir / "strain_scores.json", all_strain_scores)
    if species_columns and all_species_scores:
        StrainSpeciesScores.save_list_tsv(out_dir / "strain_target_species_scores.tsv", all_species_scores, species_columns)
        StrainSpeciesScores.save_list_json(out_dir / "strain_species_scores.json", all_species_scores)

    print(f"Saved: {out_dir / 'toxin_support.tsv'}")
    print(f"Saved: {out_dir / 'strain_target_scores.tsv'}")
    print(f"Saved: {out_dir / 'strain_scores.json'}")
    if species_columns and all_species_scores:
        print(f"Saved: {out_dir / 'strain_target_species_scores.tsv'}")
        print(f"Saved: {out_dir / 'strain_species_scores.json'}")
|
||
|
||
|
||
if __name__ == "__main__":  # script entry point; safe to import as a module
    main()
|