Files
embedding_atlas/script/data_processing/add_ecfp4_tanimoto.py
lingyuzeng bbf1746046 重构项目结构并更新README.md
1. 重构目录结构:
   - 创建src/visualization模块用于存放可视化相关功能
   - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py
   - 创建src/visualization/__init__.py导出主要函数
   - 整理script目录,按功能分类存放脚本文件

2. 更新README.md:
   - 添加CSV文件比较可视化部分
   - 提供Python API和命令行使用方法说明
   - 描述功能特点和使用示例

3. 更新模块引用:
   - 修正comparison.py中的模块引用路径
   - 更新命令行帮助信息中的使用示例
2025-10-23 17:55:36 +08:00

187 lines
5.9 KiB
Python

#!/usr/bin/env python3
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
from __future__ import annotations

import argparse
import heapq
import pathlib
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors.

    Wraps RDKit's Morgan fingerprint generator with the parameters fixed at
    construction time, and offers helpers to fingerprint a SMILES string and
    to render a fingerprint as a '0'/'1' character string.
    """

    # Fingerprint length in bits.
    n_bits: int = 2048
    # Morgan radius 2 corresponds to the classic ECFP4 definition.
    radius: int = 2
    # Whether stereochemistry contributes to the hashed environments.
    include_chirality: bool = True
    # Built in __post_init__ from the fields above; not a constructor argument.
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        """Instantiate the RDKit Morgan generator from the configured fields."""
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        """Return the bit-vector fingerprint for *smiles*, or None on failure.

        Failure covers an empty string, an unparseable SMILES, or any error
        raised by RDKit while fingerprinting.
        """
        # Empty input short-circuits MolFromSmiles entirely.
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        if mol is None:
            return None
        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            # RDKit can raise for exotic molecules; treat as "no fingerprint".
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        """Render *fp* as a string of n_bits '0'/'1' characters."""
        bits = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, bits)
        return ''.join(map(str, bits))
def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint.

    For every non-None fingerprint, the summary lists its *top_k* most similar
    neighbors (self excluded) as ``id:score`` pairs joined with ``;``, scores
    formatted to four decimals. Positions whose fingerprint is None — and every
    position when no valid fingerprints exist — receive an empty string.

    Args:
        fingerprints: Bit vectors aligned with *ids*; entries may be None.
        ids: Labels used to name neighbors in the summaries.
        top_k: Maximum number of neighbors reported per row.

    Returns:
        One summary string per input fingerprint, in the same order.
    """
    valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
    valid_fps = [fingerprints[idx] for idx in valid_indices]
    summaries = [''] * len(fingerprints)
    if not valid_fps:
        return summaries
    for pos, fp in enumerate(valid_fps):
        # One bulk call per row; RDKit computes all pairwise scores in C++.
        sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        candidates = (
            (other_pos, score)
            for other_pos, score in enumerate(sims)
            if other_pos != pos  # skip self-similarity (always 1.0)
        )
        # nlargest is O(n log k) vs. the O(n log n) full sort, and preserves
        # the same tie order as a stable reverse sort on the score.
        best = heapq.nlargest(top_k, candidates, key=lambda item: item[1])
        # valid_indices maps compact positions back to original row numbers,
        # so the summary lands on (and names) the right rows.
        summaries[valid_indices[pos]] = ';'.join(
            f"{ids[valid_indices[other_pos]]}:{score:.4f}" for other_pos, score in best
        )
    return summaries
def parse_args() -> argparse.Namespace:
    """Define and evaluate the command-line interface for this script."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    cli.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    cli.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    cli.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    cli.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    cli.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return cli.parse_args()
def main() -> None:
    """CLI entry point: read a CSV, append fingerprint columns, write the result."""
    args = parse_args()
    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def default_destination() -> pathlib.Path:
        """Pick the output path: --output if given, else a sibling of the input."""
        if args.output is not None:
            return args.output
        ext = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{ext}")

    def infer_format(path: pathlib.Path) -> str:
        """Resolve the output format; an explicit file suffix wins over --format."""
        by_suffix = {".parquet": "parquet", ".pq": "parquet", ".csv": "csv"}
        inferred = by_suffix.get(path.suffix.lower())
        if inferred is not None:
            return inferred
        if args.format in ("parquet", "csv"):
            return args.format
        # --format auto with an unrecognized suffix: nothing to infer from.
        raise ValueError(
            "无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format",
        )

    destination = default_destination()
    destination_format = infer_format(destination)
    if args.input_csv.resolve() == destination.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")
    smiles_series = df[args.smiles_column].fillna('')

    # Reuse an existing id column when present; otherwise synthesize D000001-style
    # labels and persist them so the neighbor summaries remain resolvable.
    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{row_num:06d}" for row_num in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = [
        generator.fingerprint(smiles) for smiles in smiles_series
    ]
    # Rows whose SMILES failed to fingerprint get an empty binary string.
    df['ecfp4_binary'] = [
        generator.to_binary_string(fp) if fp is not None else '' for fp in fingerprints
    ]
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if destination_format == "parquet":
        df.to_parquet(destination, index=False)
    else:
        df.to_csv(destination, index=False)
    print(f"Wrote augmented data with {len(df)} rows to {destination} ({destination_format})")


if __name__ == "__main__":
    main()