Files
bttoxin-pipeline/scripts/bttoxin_shoter.py
zly fe353fc0bc chore: 初始版本提交 - 简化架构 + 轮询改造
- 移除 Motia Streams 实时通信,改用 3 秒轮询
- 简化前端代码,移除冗余组件
- 简化后端架构,准备 FastAPI 重构
- 更新 pixi.toml 环境配置
- 保留 bttoxin_digger_v5_repro 作为参考文档

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-13 16:50:09 +08:00

706 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Bttoxin_Shoter v1
- Read BPPRC specificity CSV (positive-only) to build name/family -> target_order distributions
- Read BtToxin_Digger All_Toxins.txt to collect hits per strain
- Compute per-hit similarity weight and order contributions
- Combine contributions to strain-level potential activity scores per insect order
- Save per-hit and per-strain results (TSV + JSON) using dataclasses with save methods
Assumptions and notes are documented in docs/shotter_workflow.md
"""
from __future__ import annotations
import argparse
import json
import re
from dataclasses import dataclass, asdict, field
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pandas as pd
# -----------------------------
# Helpers for parsing families
# -----------------------------
# Recognized toxin family prefixes (BPPRC-style nomenclature).
FAMILY_PREFIXES = (
    "Cry", "Cyt", "Vip", "Vpa", "Vpb", "Mpp", "Tpp", "Spp", "App",
    "Mcf", "Mpf", "Pra", "Prb", "Txp", "Gpp", "Mtx", "Xpp",
)
# Case-insensitive matcher: family prefix + family number + optional subclass letters.
# NOTE: keep the alternation in sync with FAMILY_PREFIXES above.
FAMILY_RE = re.compile(r"(?i)(Cry|Cyt|Vip|Vpa|Vpb|Mpp|Tpp|Spp|App|Mcf|Mpf|Pra|Prb|Txp|Gpp|Mtx|Xpp)(\d+)([A-Za-z]*)")
def normalize_hit_id(hit_id: str) -> str:
    """Normalize a raw Hit_id token.

    Trims surrounding whitespace, drops one trailing '-other' marker, and
    removes any remaining internal whitespace. Non-string input (e.g. a
    pandas NaN) yields the empty string.

    Examples: 'Spp1Aa1' -> 'Spp1Aa1'; 'Bmp1-other' -> 'Bmp1'
    """
    if not isinstance(hit_id, str):
        return ""
    token = hit_id.strip()
    if token.endswith("-other"):
        token = token[: -len("-other")]
    # Collapse away any whitespace left inside the token.
    return "".join(token.split())
def extract_family_keys(name: str) -> Tuple[Optional[str], Optional[str]]:
    """From a protein name like 'Cry1Ac1' or '5618-Cry1Ia' extract:
    - family_key: e.g., 'Cry1'
    - subfamily_key: e.g., 'Cry1A' (family plus first subclass letter, if any)

    Matching is case-insensitive but the original casing is preserved in the
    returned keys. Returns (None, None) when no family prefix is found.
    """
    if not isinstance(name, str) or not name:
        return None, None
    match = re.search(
        r"(?i)(Cry|Cyt|Vip|Vpa|Vpb|Mpp|Tpp|Spp|App|Mcf|Mpf|Pra|Prb|Txp|Gpp|Mtx|Xpp)(\d+)([A-Za-z]*)",
        name,
    )
    if match is None:
        return None, None
    prefix, number, subclass = match.group(1), match.group(2), match.group(3)
    family_key = prefix + number
    subfamily_key = (family_key + subclass[0]) if subclass else None
    return family_key, subfamily_key
# ---------------------------------
# Specificity index from CSV (BPPRC)
# ---------------------------------
@dataclass
class SpecificityIndex:
    """Lookup tables built from the BPPRC specificity CSV (positive rows only).

    Maps exact protein names — with subfamily/family back-off — to
    potency-weighted, normalized distributions over insect target orders
    (and, when available, target species).
    """

    # Distributions P(order | name) and P(order | family/subfamily)
    name_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_orders: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Set of all observed target orders
    all_orders: List[str] = field(default_factory=list)
    # Distributions P(species | name) and P(species | family/subfamily)
    name_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    subfam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    fam_to_species: Dict[str, Dict[str, float]] = field(default_factory=dict)
    # Set of all observed target species
    all_species: List[str] = field(default_factory=list)
    # Partner requirement heuristics at family level: binary toxins whose
    # activity requires both components to be present in the same strain.
    partner_pairs: List[Tuple[str, str]] = field(default_factory=lambda: [
        ("Vip1", "Vip2"),
        ("Vpa", "Vpb"),
        ("BinA", "BinB"),  # included for completeness if ever present
    ])

    @staticmethod
    def _potency_from_row(row: pd.Series) -> Optional[float]:
        """Map quantitative fields to a potency weight in [0,1]. Positive-only.

        Rules:
        - lc50: will be normalized later within unit-buckets; here return the numeric.
        - percentage_mortality: map ranges to coarse weights.
        - if both missing but activity is positive: default 0.55
        """
        perc = str(row.get("percentage_mortality") or "").strip().lower()
        lc50 = row.get("lc50")
        if pd.notnull(lc50):
            try:
                return float(lc50)  # raw, to be normalized later
            except Exception:
                # non-numeric lc50 text: fall through to the mortality buckets
                pass
        if perc:
            # coarse mappings
            if ">80%" in perc or "80-100%" in perc or "greater than 80%" in perc:
                return 0.9
            if "60-100%" in perc or "60-80%" in perc or "50-80%" in perc:
                return 0.65
            if "0-60%" in perc or "0-50%" in perc or "25%" in perc or "some" in perc or "stunting" in perc:
                return 0.25
        # default positive evidence
        return 0.55

    @classmethod
    def from_csv(cls, csv_path: Path) -> "SpecificityIndex":
        """Build the index from the BPPRC toxicity CSV.

        Steps: keep only rows with activity == 'yes'; bucket units for LC50
        normalization (ppm treated as ug/g); convert each row to a raw potency;
        rank-invert numeric LC50 values within each unit bucket (smaller LC50
        => higher potency); then aggregate potency-weighted distributions per
        exact name, per subfamily, and per family (back-off levels).

        NOTE(review): assumes the CSV has columns 'activity', 'units', 'name',
        'lc50', 'percentage_mortality', 'target_order' and 'target_species'.
        If the frame were empty after the activity filter, '_potency' would
        never be created and the pd.to_numeric line below would raise —
        confirm upstream data always contains positives.
        """
        df = pd.read_csv(csv_path)
        # Positive-only evidence
        # many rows have activity 'Yes' and non_toxic empty; use only activity==Yes
        df = df[df["activity"].astype(str).str.lower().eq("yes")]  # keep positives
        # Normalize units bucket for lc50 normalization
        # unify ppm -> ug_per_g (diet context). Others stay in own unit bucket.
        units = df["units"].astype(str).str.strip().str.lower()
        ug_per_g_mask = units.isin(["µg/g", "ug/g", "μg/g", "ppm"])  # ppm ~ ug/g in diet
        unit_bucket = units.where(~ug_per_g_mask, other="ug_per_g")
        df = df.assign(_unit_bucket=unit_bucket)
        # Compute potency: two cases: lc50 (to be inverted/normalized) vs categorical
        df["_potency_raw"] = df.apply(cls._potency_from_row, axis=1)
        # For rows with numeric raw potency (lc50), do within-bucket inverted quantile
        # NOTE(review): _potency_from_row returns a float on EVERY path, so
        # pd.to_numeric(...).notnull() is True for all rows and cat_idx is
        # effectively always empty — categorical 0.9/0.65/0.25/0.55 scores get
        # rank-inverted together with genuine LC50 values inside their unit
        # bucket. Confirm whether that is intended.
        for bucket, sub in df.groupby("_unit_bucket"):
            # separate numeric vs categorical
            is_num = pd.to_numeric(sub["_potency_raw"], errors="coerce").notnull()
            num_idx = sub.index[is_num]
            cat_idx = sub.index[~is_num]
            # smaller LC50 => stronger potency
            if len(num_idx) > 0:
                vals = sub.loc[num_idx, "_potency_raw"].astype(float)
                ranks = vals.rank(method="average", pct=True)
                inv = 1.0 - ranks  # 0..1
                df.loc[num_idx, "_potency"] = inv.values
            # categorical rows keep their 0-1 scores (already in _potency_raw)
            if len(cat_idx) > 0:
                df.loc[cat_idx, "_potency"] = sub.loc[cat_idx, "_potency_raw"].values
        df["_potency"] = pd.to_numeric(df["_potency"], errors="coerce").fillna(0.55)
        # Extract family/subfamily keys from name
        # Name fields can be like 'Spp1Aa1', 'Cry1Ac1', or '5618-Cry1Ia', etc.
        fam_keys: List[Optional[str]] = []
        subfam_keys: List[Optional[str]] = []
        for name in df["name"].astype(str).tolist():
            fam, subfam = extract_family_keys(name)
            fam_keys.append(fam)
            subfam_keys.append(subfam)
        df = df.assign(_family=fam_keys, _subfamily=subfam_keys)

        # Aggregate per protein name (exact)
        def agg_distribution(group: pd.DataFrame) -> Dict[str, float]:
            # Sum potencies per target_order, then normalize to a distribution
            s = group.groupby("target_order")["_potency"].sum()
            if s.sum() <= 0:
                return {}
            d = (s / s.sum()).to_dict()
            return d

        name_to_orders = (
            df.groupby("name", dropna=False, group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )
        # Aggregate per subfamily and family as back-off
        subfam_to_orders = (
            df.dropna(subset=["_subfamily"]).groupby("_subfamily", group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )
        fam_to_orders = (
            df.dropna(subset=["_family"]).groupby("_family", group_keys=False)
            .apply(agg_distribution, include_groups=False)
            .to_dict()
        )

        # Species distributions (optional if target_species present)
        def agg_distribution_species(group: pd.DataFrame) -> Dict[str, float]:
            s = group.groupby("target_species")["_potency"].sum()
            if s.sum() <= 0:
                return {}
            d = (s / s.sum()).to_dict()
            return d

        name_to_species = (
            df.groupby("name", dropna=False, group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )
        subfam_to_species = (
            df.dropna(subset=["_subfamily"]).groupby("_subfamily", group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )
        fam_to_species = (
            df.dropna(subset=["_family"]).groupby("_family", group_keys=False)
            .apply(agg_distribution_species, include_groups=False)
            .to_dict()
        )
        # collect all orders/species observed
        all_orders = sorted({str(x) for x in df["target_order"].dropna().unique().tolist()})
        all_species = sorted({str(x) for x in df["target_species"].dropna().unique().tolist()})
        return cls(
            name_to_orders=name_to_orders,
            subfam_to_orders=subfam_to_orders,
            fam_to_orders=fam_to_orders,
            all_orders=all_orders,
            name_to_species=name_to_species,
            subfam_to_species=subfam_to_species,
            fam_to_species=fam_to_species,
            all_species=all_species,
        )

    def partner_needed_for_family(self, family_or_subfam: str) -> Optional[str]:
        """Return the partner family required for this family, or None.

        Heuristic partner rule at family level: for known binary toxins
        (e.g. Vip1/Vip2) either component requires the other one.
        """
        for a, b in self.partner_pairs:
            if family_or_subfam.startswith(a):
                return b
            if family_or_subfam.startswith(b):
                return a
        return None

    def orders_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """Order distribution for a token, backing off name -> subfamily -> family; {} if none."""
        # Try exact name, then subfamily, then family
        if name_or_family in self.name_to_orders:
            return self.name_to_orders[name_or_family]
        fam, subfam = extract_family_keys(name_or_family)
        if subfam and subfam in self.subfam_to_orders:
            return self.subfam_to_orders[subfam]
        if fam and fam in self.fam_to_orders:
            return self.fam_to_orders[fam]
        return {}

    def species_for_name_or_backoff(self, name_or_family: str) -> Dict[str, float]:
        """Species distribution for a token, backing off name -> subfamily -> family; {} if none."""
        # Try exact name, then subfamily, then family
        if name_or_family in self.name_to_species:
            return self.name_to_species[name_or_family]
        fam, subfam = extract_family_keys(name_or_family)
        if subfam and subfam in self.subfam_to_species:
            return self.subfam_to_species[subfam]
        if fam and fam in self.fam_to_species:
            return self.fam_to_species[fam]
        return {}
# -----------------------------
# Dataclasses for outputs
# -----------------------------
@dataclass
class ToxinHit:
    """One toxin hit for a strain, with its per-order score contributions."""

    strain: str
    protein_id: str
    hit_id: str
    identity: float  # fraction, 0..1
    aln_length: int
    hit_length: int
    coverage: float  # fraction, 0..1
    hmm: bool
    family: str
    name_key: str
    partner_fulfilled: bool
    weight: float
    order_contribs: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Render this hit as TSV cells; *order_columns* fixes the trailing column order."""
        row = [
            self.strain,
            self.protein_id,
            self.hit_id,
            "%.4f" % self.identity,
            str(self.aln_length),
            str(self.hit_length),
            "%.4f" % self.coverage,
            "YES" if self.hmm else "NO",
            self.family or "unknown",
            self.name_key or "",
            "YES" if self.partner_fulfilled else "NO",
            "%.4f" % self.weight,
            self.top_order or "",
            "%.6f" % self.top_score,
        ]
        row.extend("%.6f" % self.order_contribs.get(col, 0.0) for col in order_columns)
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["ToxinHit"], order_columns: List[str]) -> None:
        """Write a header plus one TSV row per hit to *tsv_path* (parents created)."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        header = [
            "Strain", "Protein_id", "Hit_id", "Identity", "Aln_length", "Hit_length", "Coverage", "HMM",
            "Family", "NameKey", "PartnerFulfilled", "Weight", "TopOrder", "TopScore",
        ] + order_columns
        lines = ["\t".join(header)]
        lines.extend("\t".join(item.to_tsv_row(order_columns)) for item in items)
        tsv_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
@dataclass
class StrainScores:
    """Combined per-order activity scores for a single strain."""

    strain: str
    order_scores: Dict[str, float]
    top_order: str
    top_score: float

    def to_tsv_row(self, order_columns: List[str]) -> List[str]:
        """Render this record as TSV cells in *order_columns* order."""
        row = [self.strain, self.top_order or "", "%.6f" % self.top_score]
        for col in order_columns:
            row.append("%.6f" % self.order_scores.get(col, 0.0))
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainScores"], order_columns: List[str]) -> None:
        """Write a header plus one row per strain to *tsv_path* (parents created)."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        lines = ["\t".join(["Strain", "TopOrder", "TopScore"] + order_columns)]
        lines.extend("\t".join(item.to_tsv_row(order_columns)) for item in items)
        tsv_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainScores"]) -> None:
        """Dump the records as a pretty-printed JSON array of plain dicts."""
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = [asdict(item) for item in items]
        json_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
@dataclass
class StrainSpeciesScores:
    """Combined per-target-species activity scores for a single strain."""

    strain: str
    species_scores: Dict[str, float]
    top_species: str
    top_species_score: float

    def to_tsv_row(self, species_columns: List[str]) -> List[str]:
        """Render this record as TSV cells in *species_columns* order."""
        row = [self.strain, self.top_species or "", "%.6f" % self.top_species_score]
        for col in species_columns:
            row.append("%.6f" % self.species_scores.get(col, 0.0))
        return row

    @staticmethod
    def save_list_tsv(tsv_path: Path, items: List["StrainSpeciesScores"], species_columns: List[str]) -> None:
        """Write a header plus one row per strain to *tsv_path* (parents created)."""
        tsv_path.parent.mkdir(parents=True, exist_ok=True)
        lines = ["\t".join(["Strain", "TopSpecies", "TopSpeciesScore"] + species_columns)]
        lines.extend("\t".join(item.to_tsv_row(species_columns)) for item in items)
        tsv_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

    @staticmethod
    def save_list_json(json_path: Path, items: List["StrainSpeciesScores"]) -> None:
        """Dump the records as a pretty-printed JSON array of plain dicts."""
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = [asdict(item) for item in items]
        json_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
# -----------------------------
# Parser and scoring
# -----------------------------
def compute_similarity_weight(identity: float, coverage: float, hmm: bool) -> float:
    """Map identity/coverage (both fractions in [0,1]) and an HMM flag to a
    similarity weight bounded to [0, 1].

    Piecewise base on identity:
    - identity >= 0.78          -> base = 1.0
    - 0.45 <= identity < 0.78   -> base ramps linearly from 0 to 1
    - identity < 0.45           -> base = 0.0
    Final weight: base * clamp(coverage, 0, 1), plus a 0.1 bonus when the
    HMM model also matched, all clamped to [0, 1].

    Fix: the previous implementation granted base = 1.0 only when
    identity >= 0.78 AND coverage >= 0.8, otherwise fell through to base = 0
    for high-identity hits — making the weight non-monotonic in identity
    (identity 0.90 / coverage 0.75 scored 0 while 0.77 / 0.75 scored ~0.73).
    Coverage already multiplies into the final weight, so the extra coverage
    gate was removed.
    """
    if identity >= 0.78:
        base = 1.0
    elif identity >= 0.45:
        # Linear ramp between the two identity thresholds, clamped for safety.
        base = max(0.0, min(1.0, (identity - 0.45) / (0.78 - 0.45)))
    else:
        base = 0.0
    w = base * max(0.0, min(1.0, coverage))
    if hmm:
        w = min(1.0, w + 0.1)
    return float(max(0.0, min(1.0, w)))
def parse_all_toxins(tsv_path: Path) -> pd.DataFrame:
    """Load BtToxin_Digger's All_Toxins.txt into a DataFrame and derive the
    columns used for scoring.

    Derived columns:
    - coverage:    Aln_length / Hit_length (0.0 when Hit_length is 0;
                   may exceed 1.0 — the weight formula clamps it later)
    - identity01:  the Identity percentage rescaled to [0, 1]
    - HMM:         True when the HMM column equals 'YES' (case-insensitive)
    - Hit_id_norm: normalized hit token (see normalize_hit_id)
    - family_key:  parsed family such as 'Cry1', or 'unknown' when unparsable

    NOTE(review): assumes the TSV carries at least the columns 'Identity',
    'Aln_length', 'Hit_length', 'HMM' and 'Hit_id' — confirm against the
    BtToxin_Digger output format.
    """
    df = pd.read_csv(tsv_path, sep="\t", dtype=str, engine="python")
    # Coerce needed fields from the all-string load
    for col in ["Identity", "Aln_length", "Hit_length"]:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df["Identity"] = df["Identity"].astype(float)
    df["Aln_length"] = df["Aln_length"].fillna(0).astype(int)
    df["Hit_length"] = df["Hit_length"].fillna(0).astype(int)
    # Replace a zero Hit_length with NA before dividing, then fill the
    # resulting NA coverage with 0.0 (avoids division by zero).
    df["coverage"] = (df["Aln_length"].clip(lower=0) / df["Hit_length"].replace(0, pd.NA)).fillna(0.0)
    df["identity01"] = (df["Identity"].fillna(0.0) / 100.0).clip(lower=0.0, upper=1.0)
    df["HMM"] = df["HMM"].astype(str).str.upper().eq("YES")
    # Normalize hit_id and family
    df["Hit_id_norm"] = df["Hit_id"].astype(str).map(normalize_hit_id)
    fams = []
    for hit in df["Hit_id_norm"].tolist():
        fam, _ = extract_family_keys(hit)
        fams.append(fam or "unknown")
    df["family_key"] = fams
    return df
def partner_fulfilled_for_hit(hit_family: str, strain_df: pd.DataFrame, index: SpecificityIndex) -> bool:
    """Check the binary-toxin partner requirement for one hit's family.

    Returns True when the family needs no partner, or when some other hit in
    the same strain belongs to the required partner family with a similarity
    weight of at least 0.3. Otherwise returns False.
    """
    required = index.partner_needed_for_family(hit_family or "")
    if not required:
        return True
    for _, candidate in strain_df.iterrows():
        candidate_family = candidate.get("family_key", "") or ""
        if not candidate_family.startswith(required):
            continue
        weight = compute_similarity_weight(
            float(candidate.get("identity01", 0.0)),
            float(candidate.get("coverage", 0.0)),
            bool(candidate.get("HMM", False)),
        )
        # A reasonably confident partner hit satisfies the requirement.
        if weight >= 0.3:
            return True
    return False
def score_strain(strain: str, sdf: pd.DataFrame, index: SpecificityIndex) -> Tuple[List[ToxinHit], StrainScores, Optional[StrainSpeciesScores]]:
    """Score one strain: build per-hit ToxinHit records and combine them into
    strain-level order (and optionally species) scores.

    Combination rule: independent-evidence union, score_o = 1 - prod(1 - w*p_o)
    over all hits, so multiple weak hits on the same order reinforce each
    other without ever exceeding 1.

    Args:
        strain: strain identifier (used only to label the outputs).
        sdf: rows of the parsed All_Toxins table belonging to this strain.
        index: specificity index providing P(order|name) / P(species|name).

    Returns:
        (per-hit records, strain order scores, strain species scores or None
        when the index carries no species information).
    """
    # Include special buckets per requirements
    order_set = sorted({*index.all_orders, "other", "unknown"})
    per_hit: List[ToxinHit] = []
    # Collect contributions per order by combining across hits
    # We'll accumulate 1 - Π(1 - contrib)
    one_minus = {o: 1.0 for o in order_set}
    # Species accumulation if available
    species_set = sorted(index.all_species) if index.all_species else []
    sp_one_minus = {sp: 1.0 for sp in species_set}
    for _, row in sdf.iterrows():
        hit_id = str(row.get("Hit_id_norm", ""))
        protein_id = str(row.get("Protein_id", ""))
        identity01 = float(row.get("identity01", 0.0))
        coverage = float(row.get("coverage", 0.0))
        hmm = bool(row.get("HMM", False))
        family = str(row.get("family_key", "")) or "unknown"
        # Choose name backoff distribution
        name_key = hit_id
        order_dist = index.orders_for_name_or_backoff(name_key)
        if not order_dist:
            # backoff to family
            order_dist = index.orders_for_name_or_backoff(family)
        if not order_dist:
            # Route by whether we have a parsable family
            # - parsable family but no evidence -> 'other'
            # - unparseable family (family == 'unknown') -> 'unknown'
            if family != "unknown":
                order_dist = {"other": 1.0}
            else:
                order_dist = {"unknown": 1.0}
        # Compute weight with partner handling
        w = compute_similarity_weight(identity01, coverage, hmm)
        fulfilled = partner_fulfilled_for_hit(family, sdf, index)
        if not fulfilled:
            w *= 0.2  # partner penalty if required but not present
        contribs: Dict[str, float] = {}
        for o in order_set:
            p = float(order_dist.get(o, 0.0))
            c = w * p
            if c > 0:
                one_minus[o] *= (1.0 - c)
            contribs[o] = c
        # species contributions (only for known distributions; no other/unknown buckets)
        if species_set:
            sp_dist = index.species_for_name_or_backoff(name_key)
            if not sp_dist:
                sp_dist = index.species_for_name_or_backoff(family)
            if sp_dist:
                for sp in species_set:
                    psp = float(sp_dist.get(sp, 0.0))
                    csp = w * psp
                    if csp > 0:
                        sp_one_minus[sp] *= (1.0 - csp)
        # top for this hit (allow other/unknown if that's all we have)
        if contribs:
            hit_top_order, hit_top_score = max(contribs.items(), key=lambda kv: kv[1])
        else:
            hit_top_order, hit_top_score = "", 0.0
        per_hit.append(
            ToxinHit(
                strain=strain,
                protein_id=protein_id,
                hit_id=hit_id,
                identity=identity01,
                aln_length=int(row.get("Aln_length", 0)),
                hit_length=int(row.get("Hit_length", 0)),
                coverage=coverage,
                hmm=hmm,
                family=family,
                name_key=name_key,
                partner_fulfilled=fulfilled,
                weight=w,
                order_contribs=contribs,
                top_order=hit_top_order,
                top_score=float(hit_top_score),
            )
        )
    # Fold the accumulated products into final union scores
    order_scores = {o: (1.0 - one_minus[o]) for o in order_set}
    # choose top order excluding unresolved buckets if possible
    preferred = [o for o in index.all_orders if o in order_scores]
    if not preferred:
        preferred = [o for o in order_set if o not in ("other", "unknown")]
    if preferred:
        top_o = max(preferred, key=lambda o: order_scores.get(o, 0.0))
        top_s = float(order_scores.get(top_o, 0.0))
    else:
        top_o, top_s = "", 0.0
    strain_scores = StrainScores(strain=strain, order_scores=order_scores, top_order=top_o, top_score=top_s)
    species_scores_obj: Optional[StrainSpeciesScores] = None
    if species_set:
        species_scores = {sp: (1.0 - sp_one_minus[sp]) for sp in species_set}
        if species_scores:
            top_sp = max(species_set, key=lambda sp: species_scores.get(sp, 0.0))
            top_sp_score = float(species_scores.get(top_sp, 0.0))
        else:
            top_sp, top_sp_score = "", 0.0
        species_scores_obj = StrainSpeciesScores(
            strain=strain,
            species_scores=species_scores,
            top_species=top_sp,
            top_species_score=top_sp_score,
        )
    return per_hit, strain_scores, species_scores_obj
# -----------------------------
# Saving utilities
# -----------------------------
def save_toxin_support(tsv_path: Path, hits: List[ToxinHit], order_columns: List[str]) -> None:
    """Write per-hit support rows (header + one TSV row per hit) to *tsv_path*.

    Parent directories are created as needed. Mirrors ToxinHit.save_list_tsv.
    """
    tsv_path.parent.mkdir(parents=True, exist_ok=True)
    columns = [
        "Strain", "Protein_id", "Hit_id", "Identity", "Aln_length", "Hit_length", "Coverage", "HMM",
        "Family", "NameKey", "PartnerFulfilled", "Weight", "TopOrder", "TopScore",
    ] + order_columns
    rows = ["\t".join(columns)]
    rows.extend("\t".join(hit.to_tsv_row(order_columns)) for hit in hits)
    tsv_path.write_text("\n".join(rows) + "\n", encoding="utf-8")
def save_strain_targets(tsv_path: Path, strains: List[StrainScores], order_columns: List[str]) -> None:
    """Write strain-level order scores (header + one TSV row per strain).

    Parent directories are created as needed. Mirrors StrainScores.save_list_tsv.
    """
    tsv_path.parent.mkdir(parents=True, exist_ok=True)
    rows = ["\t".join(["Strain", "TopOrder", "TopScore"] + order_columns)]
    rows.extend("\t".join(record.to_tsv_row(order_columns)) for record in strains)
    tsv_path.write_text("\n".join(rows) + "\n", encoding="utf-8")
def save_strain_json(json_path: Path, strains: List[StrainScores]) -> None:
    """Serialize strain score records to pretty-printed UTF-8 JSON.

    Parent directories are created as needed. Mirrors StrainScores.save_list_json.
    """
    json_path.parent.mkdir(parents=True, exist_ok=True)
    payload = [asdict(record) for record in strains]
    json_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
# -----------------------------
# CLI
# -----------------------------
def _try_writable_dir(p: Path) -> bool:
try:
p.mkdir(parents=True, exist_ok=True)
except Exception:
return False
try:
probe = p / ".__shotter_write_test__"
with probe.open("w", encoding="utf-8") as f:
f.write("ok")
probe.unlink(missing_ok=True)
return True
except Exception:
return False
def resolve_output_dir(preferred: Path) -> Path:
    """Pick a writable output directory, trying fallbacks in order.

    Order: *preferred*, then preferred/'Shotter', then cwd/'shotter_outputs';
    as a last resort *preferred* is returned anyway (with a warning), even
    though writing there will likely fail.
    """
    if _try_writable_dir(preferred):
        return preferred
    fallbacks = (preferred / "Shotter", Path.cwd() / "shotter_outputs")
    for candidate in fallbacks:
        if _try_writable_dir(candidate):
            print(f"[Shotter] Output dir not writable: {preferred}. Falling back to: {candidate}")
            return candidate
    print(f"[Shotter] WARNING: could not find writable output directory. Using {preferred} (may fail).")
    return preferred
def main():
    """CLI entry point: build the specificity index, filter and score every
    strain found in All_Toxins.txt, and write TSV/JSON reports to a writable
    output directory."""
    ap = argparse.ArgumentParser(description="Bttoxin_Shoter: infer target orders from BtToxin_Digger outputs")
    ap.add_argument("--toxicity_csv", type=Path, default=Path("Data/toxicity-data.csv"))
    ap.add_argument("--all_toxins", type=Path, default=Path("tests/output/Results/Toxins/All_Toxins.txt"))
    ap.add_argument("--output_dir", type=Path, default=Path("tests/output/Results/Toxins"))
    # Filtering and thresholds (paired on/off flags share one dest; default on)
    ap.add_argument("--allow_unknown_families", dest="allow_unknown_families", action="store_true")
    ap.add_argument("--disallow_unknown_families", dest="allow_unknown_families", action="store_false")
    ap.set_defaults(allow_unknown_families=True)
    ap.add_argument("--require_index_hit", action="store_true", default=False,
                    help="Keep only hits that map to a known name/subfamily/family in the specificity index")
    ap.add_argument("--min_identity", type=float, default=0.0, help="Minimum identity (0-1) to keep a hit")
    ap.add_argument("--min_coverage", type=float, default=0.0, help="Minimum coverage (0-1) to keep a hit")
    args = ap.parse_args()
    index = SpecificityIndex.from_csv(args.toxicity_csv)
    df = parse_all_toxins(args.all_toxins)
    # thresholds
    if args.min_identity > 0:
        df = df[df["identity01"].astype(float) >= float(args.min_identity)]
    if args.min_coverage > 0:
        df = df[df["coverage"].astype(float) >= float(args.min_coverage)]
    # unknown families handling
    if not args.allow_unknown_families:
        df = df[df["family_key"].astype(str) != "unknown"]
    # require index hit mapping
    if args.require_index_hit:
        def _has_index_orders(row) -> bool:
            # True when the hit maps to any order distribution, by exact
            # name or via family back-off.
            name_key = str(row.get("Hit_id_norm", ""))
            fam = str(row.get("family_key", ""))
            d = index.orders_for_name_or_backoff(name_key)
            if not d:
                d = index.orders_for_name_or_backoff(fam)
            return bool(d)
        df = df[df.apply(_has_index_orders, axis=1)]
    # Handle empty DataFrame - preserve columns and create empty outputs
    if df.shape[0] == 0:
        print("[Shotter] No hits passed filters, creating empty output files")
        strains: List[str] = []
    else:
        strains = sorted(df["Strain"].astype(str).unique().tolist())
    all_hits: List[ToxinHit] = []
    all_strain_scores: List[StrainScores] = []
    all_species_scores: List[StrainSpeciesScores] = []
    for strain in strains:
        sdf = df[df["Strain"].astype(str).eq(strain)].copy()
        per_hit, sscore, sspecies = score_strain(strain, sdf, index)
        all_hits.extend(per_hit)
        all_strain_scores.append(sscore)
        if sspecies is not None:
            all_species_scores.append(sspecies)
    # Always include the special buckets in outputs
    order_columns = sorted({*index.all_orders, "other", "unknown"}) or ["unknown"]
    species_columns = sorted(index.all_species)
    # Resolve a writable output directory (avoid PermissionError on Docker-owned dirs)
    out_dir = resolve_output_dir(args.output_dir)
    # Save via dataclass methods
    ToxinHit.save_list_tsv(out_dir / "toxin_support.tsv", all_hits, order_columns)
    StrainScores.save_list_tsv(out_dir / "strain_target_scores.tsv", all_strain_scores, order_columns)
    StrainScores.save_list_json(out_dir / "strain_scores.json", all_strain_scores)
    if species_columns and all_species_scores:
        StrainSpeciesScores.save_list_tsv(out_dir / "strain_target_species_scores.tsv", all_species_scores, species_columns)
        StrainSpeciesScores.save_list_json(out_dir / "strain_species_scores.json", all_species_scores)
    print(f"Saved: {out_dir / 'toxin_support.tsv'}")
    print(f"Saved: {out_dir / 'strain_target_scores.tsv'}")
    print(f"Saved: {out_dir / 'strain_scores.json'}")
    if species_columns and all_species_scores:
        print(f"Saved: {out_dir / 'strain_target_species_scores.tsv'}")
        print(f"Saved: {out_dir / 'strain_species_scores.json'}")


if __name__ == "__main__":
    main()