187 lines
5.9 KiB
Python
187 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import pathlib
|
|
from dataclasses import dataclass, field
|
|
from typing import Iterable, List, Optional, Sequence, Tuple
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from rdkit import Chem, DataStructs
|
|
from rdkit.Chem import rdFingerprintGenerator
|
|
|
|
|
|
@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors.

    The fields mirror the knobs of RDKit's Morgan generator; the generator
    object itself is created once in ``__post_init__`` and reused for every
    molecule.
    """

    n_bits: int = 2048  # width of the fingerprint bit vector
    radius: int = 2  # Morgan radius 2 corresponds to ECFP4
    include_chirality: bool = True  # fold stereochemistry into the hashed bits
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        """Build the shared RDKit Morgan generator from the configured settings."""
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        """Return the bit vector for *smiles*, or ``None`` if it cannot be parsed.

        Empty/missing SMILES, unparsable SMILES, and any RDKit failure during
        fingerprinting all map to ``None`` (best-effort semantics).
        """
        molecule = Chem.MolFromSmiles(smiles) if smiles else None
        if molecule is None:
            return None

        try:
            return self.generator.GetFingerprint(molecule)
        except Exception:
            # Deliberate best-effort: treat any RDKit error as "no fingerprint".
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        """Serialise *fp* as a '0'/'1' string of length ``n_bits``."""
        bits = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, bits)
        return ''.join(bits.astype(str))
|
|
|
|
|
def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint.

    For each non-``None`` fingerprint, the summary lists its ``top_k`` most
    similar neighbors as ``"<id>:<score>"`` entries joined by ``';'``; rows
    whose fingerprint is ``None`` receive ``''``.
    """
    summaries = [''] * len(fingerprints)

    # Work only on rows that actually have a fingerprint, remembering
    # each one's position in the original sequence.
    valid_indices = [i for i, fp in enumerate(fingerprints) if fp is not None]
    if not valid_indices:
        return summaries
    valid_fps = [fingerprints[i] for i in valid_indices]

    for pos, (original_idx, fp) in enumerate(zip(valid_indices, valid_fps)):
        scores = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)

        # Drop the self-comparison, then rank the rest high-to-low.
        candidates = [
            (other, score) for other, score in enumerate(scores) if other != pos
        ]
        candidates.sort(key=lambda pair: pair[1], reverse=True)

        labels = [
            f"{ids[valid_indices[other]]}:{score:.4f}"
            for other, score in candidates[:top_k]
        ]
        summaries[original_idx] = ';'.join(labels)

    return summaries
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Declare the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description=__doc__)

    # Positional: the CSV to augment.
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")

    # Optional flags, declared as (flag, add_argument kwargs) pairs.
    optional_specs = [
        ("--output", dict(
            type=pathlib.Path,
            default=None,
            help="Destination file (default: <input>_with_ecfp4.parquet)",
        )),
        ("--smiles-column", dict(
            default="smiles",
            help="Name of the column containing SMILES (default: smiles)",
        )),
        ("--id-column", dict(
            default="generated_id",
            help="Column used to label Tanimoto neighbors (default: generated_id)",
        )),
        ("--top-k", dict(
            type=int,
            default=5,
            help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
        )),
        ("--format", dict(
            choices=("parquet", "csv", "auto"),
            default="parquet",
            help="Output format; defaults to parquet unless overridden or inferred from --output",
        )),
    ]
    for flag, spec in optional_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Entry point: read the input CSV, add ECFP4 + Tanimoto neighbor columns,
    and write the augmented table as parquet or CSV."""
    args = parse_args()

    # Fail fast before doing any work.
    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        """Pick the destination: explicit --output, else a derived sibling path."""
        if args.output is not None:
            return args.output

        # Derive "<stem>_with_ecfp4.<ext>" next to the input; "auto" with no
        # --output falls back to parquet.
        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        """Decide the writer to use from the output suffix, then --format."""
        # NOTE(review): the output file's suffix takes precedence over an
        # explicit --format, so "--format csv --output x.parquet" writes
        # parquet — confirm this precedence is intended.
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        # Reached only with --format auto and an unrecognised output suffix.
        # (Message, in Chinese: cannot infer the format from the output file;
        # give --output a .parquet/.csv suffix or use --format.)
        raise ValueError(
            "无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)

    # Refuse to clobber the input file in place.
    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")

    # Missing SMILES become '' so ECFP4Generator.fingerprint yields None.
    smiles_series = df[args.smiles_column].fillna('')

    # Use the existing id column when present; otherwise synthesise
    # D000001-style ids and persist them so neighbor labels stay traceable.
    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []

    # Single pass: keep the bit vectors for the pairwise Tanimoto step and
    # record a '0'/'1' string per row ('' when the SMILES did not parse).
    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()
|