#!/usr/bin/env python3
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""

import argparse
import pathlib
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set, Tuple

import pandas as pd
from rdkit import Chem


@dataclass
class MacrolactoneRingDetector:
    """Detect macrolactone rings in SMILES strings via SMARTS patterns.

    A ring of ``min_size``..``max_size`` atoms (inclusive) is reported when it
    contains an ester linkage whose ring oxygen, alcohol-side carbon and
    carbonyl carbon all lie in that same ring.

    Attributes:
        min_size: Smallest ring size (number of ring atoms) to report.
        max_size: Largest ring size (number of ring atoms) to report.
        patterns: Compiled SMARTS query per ring size, built in
            ``__post_init__``; a value is ``None`` if RDKit could not parse it.
    """

    min_size: int = 12
    max_size: int = 20
    patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        # Anchor each query on the ester oxygen: it must be in a ring whose
        # smallest size (RDKit ``r<n>`` primitive) is ``size``, and be joined by
        # ring bonds (``@``) to both carbon neighbours, one of which carries the
        # carbonyl. Anchoring on the oxygen instead of the alcohol-side carbon
        # (which may also sit in a smaller fused ring) and requiring ring bonds
        # rejects exocyclic esters such as cyclododecyl acetate.
        # Match order: (ring O, alcohol-side C, carbonyl C, carbonyl O).
        self.patterns = {
            size: Chem.MolFromSmarts(f"[#8;r{size}](@[#6])@[#6]=[#8]")
            for size in range(self.min_size, self.max_size + 1)
        }

    def ring_sizes(self, smiles: str) -> List[int]:
        """Return the sorted, distinct macrolactone ring sizes found in ``smiles``.

        Empty, non-string or unparseable input yields ``[]``. When more than one
        distinct size is found, a ``RuntimeWarning`` is issued and the sizes are
        printed together with the SMILES.
        """
        # Non-string cells (e.g. numbers read from a CSV) would make RDKit raise.
        if not isinstance(smiles, str) or not smiles:
            return []
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return []
        ring_atoms = mol.GetRingInfo().AtomRings()
        if not ring_atoms:
            return []

        matched_rings: Set[Tuple[int, ...]] = set()
        matched_sizes: Set[int] = set()
        for size, query in self.patterns.items():
            if query is None:
                continue
            for match in mol.GetSubstructMatches(query, uniquify=True):
                # The ring must hold the whole ester linkage (O, alcohol C,
                # carbonyl C), not merely one atom next to it.
                ring = self._pick_ring(match[:3], ring_atoms, size)
                if ring and ring not in matched_rings:
                    matched_rings.add(ring)
                    matched_sizes.add(size)

        sizes = sorted(matched_sizes)
        if len(sizes) > 1:
            warnings.warn(
                "Multiple macrolactone ring sizes detected",
                RuntimeWarning,
                stacklevel=2,
            )
            print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
        return sizes

    @staticmethod
    def _pick_ring(
        match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
    ) -> Optional[Tuple[int, ...]]:
        """Return the first ring of ``expected_size`` atoms containing every atom in ``match``.

        The ring is returned as a sorted tuple of atom indices so that the same
        ring reached through different matches compares equal. Returns ``None``
        when no ring qualifies.
        """
        for ring in rings:
            if len(ring) == expected_size and all(atom in ring for atom in match):
                return tuple(sorted(ring))
        return None


def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
    """Return a copy of ``df`` with ``generated_id`` and ``macrocycle_ring_sizes`` columns.

    ``generated_id`` numbers rows ``D000001``, ``D000002``, ... in row order.
    ``macrocycle_ring_sizes`` holds the detected ring sizes joined with ``;``.
    It is empty when nothing is found, the SMILES is missing, or ``df`` has no
    ``smiles`` column. ``df`` itself is not modified.
    """
    result = df.copy()
    result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]
    if "smiles" in result.columns:
        smiles_series = result["smiles"].fillna("")
    else:
        smiles_series = pd.Series([""] * len(result), index=result.index)

    def format_sizes(smiles: str) -> str:
        return ";".join(str(size) for size in detector.ring_sizes(smiles))

    result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
    return result


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of this script."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help=(
            "Destination for the augmented CSV "
            "(default: <input stem>_with_macrocycles.csv next to the input)"
        ),
    )
    return parser.parse_args()


def main() -> None:
    """Read the input CSV, add the annotation columns and write a new CSV.

    Raises:
        FileNotFoundError: if the input CSV does not exist.
        ValueError: if the output path resolves to the input path.
    """
    args = parse_args()
    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")
    # Default destination sits beside the input as "<stem>_with_macrocycles.csv".
    output_path = args.output or args.input_csv.with_name(
        f"{args.input_csv.stem}_with_macrocycles.csv"
    )
    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting.")

    df = pd.read_csv(args.input_csv)
    detector = MacrolactoneRingDetector()
    augmented = add_columns(df, detector)
    augmented.to_csv(output_path, index=False)
    print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")


if __name__ == "__main__":
    main()