Files
embedding_atlas/script/add_macrocycle_columns.py
2025-10-23 16:21:52 +08:00

128 lines
3.9 KiB
Python

#!/usr/bin/env python3
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
import argparse
import pathlib
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set, Tuple
import pandas as pd
from rdkit import Chem
@dataclass
class MacrolactoneRingDetector:
    """Detect macrolactone rings in SMILES strings via SMARTS patterns.

    A ring size ``n`` in ``[min_size, max_size]`` is reported only when the
    ester linkage ``X-O-C(=O)`` lies inside a single ``n``-membered ring, i.e.
    the ester is cyclic (a lactone). An exocyclic ester that merely hangs off
    a large ring (for example an acetate on cyclotetradecane) is not counted.
    """

    # Smallest ring size (inclusive) to search for.
    min_size: int = 12
    # Largest ring size (inclusive) to search for.
    max_size: int = 20
    # ring size -> compiled SMARTS query; a value is None if parsing failed
    # (such entries are skipped in ring_sizes).
    patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        # Each query matches (ring atom, ester O, carbonyl C, carbonyl O) where
        # the first atom's smallest ring has exactly ``size`` members. Whether
        # the ester O and carbonyl C share that ring is checked in _pick_ring.
        self.patterns = {
            size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
            for size in range(self.min_size, self.max_size + 1)
        }

    def ring_sizes(self, smiles: str) -> List[int]:
        """Return the sorted, de-duplicated macrolactone ring sizes in ``smiles``.

        Returns an empty list for non-string or empty input, SMILES that RDKit
        cannot parse, acyclic molecules, and molecules with no ring-embedded
        ester. Emits a ``RuntimeWarning`` naming the sizes and the SMILES when
        more than one distinct ring size is found.
        """
        # Guard against NaN / numbers coming from a DataFrame column: RDKit
        # raises on non-str arguments instead of returning None.
        if not isinstance(smiles, str) or not smiles:
            return []
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return []
        ring_atoms = mol.GetRingInfo().AtomRings()
        if not ring_atoms:
            return []

        matched_rings: Set[Tuple[int, ...]] = set()
        matched_sizes: Set[int] = set()
        for size, query in self.patterns.items():
            if query is None:
                continue
            for match in mol.GetSubstructMatches(query, uniquify=True):
                ring = self._pick_ring(match, ring_atoms, size)
                # A ring holding several ester groups (e.g. a macrodiolide)
                # is recorded only once.
                if ring and ring not in matched_rings:
                    matched_rings.add(ring)
                    matched_sizes.add(size)

        sizes = sorted(matched_sizes)
        if len(sizes) > 1:
            # Details live in the message itself so each distinct SMILES is
            # reported by the default warnings filter (no separate print).
            warnings.warn(
                f"Multiple macrolactone ring sizes {sizes} detected for SMILES: {smiles}",
                RuntimeWarning,
                stacklevel=2,
            )
        return sizes

    @staticmethod
    def _pick_ring(
        match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
    ) -> Optional[Tuple[int, ...]]:
        """Return the ``expected_size`` ring that contains the lactone linkage.

        ``match`` is (ring atom, ester O, carbonyl C, carbonyl O) from the
        per-size SMARTS. The ring must contain the first three atoms, which
        makes the ester bond part of the ring; the carbonyl O is exocyclic and
        is not checked. Returns the ring's atom indices sorted (a canonical
        key for de-duplication), or ``None`` if no ring qualifies.
        """
        lactone_atoms = match[:3]
        for ring in rings:
            if len(ring) != expected_size:
                continue
            members = set(ring)
            if all(atom in members for atom in lactone_atoms):
                return tuple(sorted(ring))
        return None
def add_columns(
    df: pd.DataFrame,
    detector: "MacrolactoneRingDetector",
    smiles_column: str = "smiles",
) -> pd.DataFrame:
    """Return a copy of ``df`` with ``generated_id`` and ``macrocycle_ring_sizes``.

    Args:
        df: Input table; it is not modified.
        detector: Object providing ``ring_sizes(smiles) -> list[int]``.
        smiles_column: Name of the column holding SMILES strings.

    Returns:
        A new DataFrame in which ``generated_id`` holds sequential IDs
        ``D000001``, ``D000002``, ... in row order (an existing column of that
        name is overwritten), and ``macrocycle_ring_sizes`` holds the detected
        sizes joined with ``";"`` (empty string when none are found).
        If ``smiles_column`` is absent, every annotation is empty and a
        ``RuntimeWarning`` is emitted.
    """
    result = df.copy()
    result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]

    if smiles_column in result.columns:
        # Missing cells become "", and everything is coerced to str so that
        # numeric-looking cells never reach the SMILES parser as non-strings.
        smiles_series = result[smiles_column].fillna("").astype(str)
    else:
        warnings.warn(
            f"Column {smiles_column!r} not found; macrocycle_ring_sizes will be empty",
            RuntimeWarning,
            stacklevel=2,
        )
        smiles_series = pd.Series([""] * len(result), index=result.index, dtype=object)

    def format_sizes(smiles: str) -> str:
        # "".join of an empty sequence yields "", matching "no rings found".
        return ";".join(str(size) for size in detector.ring_sizes(smiles))

    result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
    return result
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``input_csv`` (a ``pathlib.Path``) and ``output``
        (a ``pathlib.Path``, or ``None`` when ``--output`` is not given).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "input_csv",
        help="Path to the source CSV file",
        type=pathlib.Path,
    )
    cli.add_argument(
        "--output",
        help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
        default=None,
        type=pathlib.Path,
    )
    namespace = cli.parse_args()
    return namespace
def main() -> None:
    """Command-line entry point: read a CSV, annotate it, and write a new CSV.

    The output defaults to ``<input stem>_with_macrocycles.csv`` next to the
    input file when ``--output`` is not supplied.

    Raises:
        FileNotFoundError: If the input CSV does not exist.
        ValueError: If the output path resolves to the input file.
    """
    args = parse_args()
    source = args.input_csv
    if not source.exists():
        raise FileNotFoundError(f"Input file not found: {source}")

    destination = args.output
    if not destination:
        destination = source.with_name(f"{source.stem}_with_macrocycles.csv")

    # Refuse to clobber the input, even when reached via a different path.
    if destination.resolve() == source.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting.")

    frame = pd.read_csv(source)
    annotated = add_columns(frame, MacrolactoneRingDetector())
    annotated.to_csv(destination, index=False)
    print(f"Wrote augmented data with {len(annotated)} rows to {destination}")


if __name__ == "__main__":
    main()