#!/usr/bin/env python3 """Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries.""" from __future__ import annotations import argparse import pathlib from dataclasses import dataclass, field from typing import Iterable, List, Optional, Sequence, Tuple import numpy as np import pandas as pd from rdkit import Chem, DataStructs from rdkit.Chem import rdFingerprintGenerator @dataclass class ECFP4Generator: """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors.""" n_bits: int = 2048 radius: int = 2 include_chirality: bool = True generator: rdFingerprintGenerator.MorganGenerator = field(init=False) def __post_init__(self) -> None: self.generator = rdFingerprintGenerator.GetMorganGenerator( radius=self.radius, fpSize=self.n_bits, includeChirality=self.include_chirality, ) def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]: if not smiles: return None mol = Chem.MolFromSmiles(smiles) if mol is None: return None try: return self.generator.GetFingerprint(mol) except Exception: return None def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str: arr = np.zeros((self.n_bits,), dtype=np.uint8) DataStructs.ConvertToNumpyArray(fp, arr) return ''.join(arr.astype(str)) def tanimoto_top_k( fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]], ids: Sequence[str], top_k: int = 5, ) -> List[str]: """Return semicolon-delimited Tanimoto summaries for each fingerprint.""" valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None] valid_fps = [fingerprints[idx] for idx in valid_indices] index_lookup = {pos: original for pos, original in enumerate(valid_indices)} summaries = [''] * len(fingerprints) if not valid_fps: return summaries for pos, fp in enumerate(valid_fps): sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps) ranked: List[Tuple[int, float]] = [] for other_pos, score in enumerate(sims): if other_pos == pos: continue ranked.append((other_pos, score)) ranked.sort(key=lambda item: item[1], reverse=True) top_entries = [] for other_pos, score in ranked[:top_k]: original_idx = index_lookup[other_pos] top_entries.append(f"{ids[original_idx]}:{score:.4f}") summaries[index_lookup[pos]] = ';'.join(top_entries) return summaries def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES") parser.add_argument( "--output", type=pathlib.Path, default=None, help="Destination file (default: _with_ecfp4.parquet)", ) parser.add_argument( "--smiles-column", default="smiles", help="Name of the column containing SMILES (default: smiles)", ) parser.add_argument( "--id-column", default="generated_id", help="Column used to label Tanimoto neighbors (default: generated_id)", ) parser.add_argument( "--top-k", type=int, default=5, help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)", ) parser.add_argument( "--format", choices=("parquet", "csv", "auto"), default="parquet", help="Output format; defaults to parquet unless overridden or inferred from --output", ) return parser.parse_args() def main() -> None: args = parse_args() if not args.input_csv.exists(): raise FileNotFoundError(f"Input file not found: {args.input_csv}") def resolve_output_path() -> pathlib.Path: if args.output is not None: return args.output suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv" return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}") def resolve_format(path: pathlib.Path) -> str: suffix = path.suffix.lower() if suffix in {".parquet", ".pq"}: return "parquet" if suffix == ".csv": return "csv" if args.format == "parquet": return "parquet" if args.format == "csv": return "csv" raise ValueError( "无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format", ) output_path = resolve_output_path() output_format = resolve_format(output_path) if args.input_csv.resolve() == output_path.resolve(): raise ValueError("Output path must differ from input path to avoid overwriting input.") df = pd.read_csv(args.input_csv) if args.smiles_column not in df.columns: raise ValueError(f"Column '{args.smiles_column}' not found in input data") smiles_series = df[args.smiles_column].fillna('') if args.id_column in df.columns: ids = df[args.id_column].astype(str).tolist() else: ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)] df[args.id_column] = ids generator = ECFP4Generator() fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = [] binary_repr: List[str] = [] for smiles in smiles_series: fp = generator.fingerprint(smiles) fingerprints.append(fp) binary_repr.append(generator.to_binary_string(fp) if fp is not None else '') df['ecfp4_binary'] = binary_repr df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k) if output_format == "parquet": df.to_parquet(output_path, index=False) else: df.to_csv(output_path, index=False) print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})") if __name__ == "__main__": main()