重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
186
script/data_processing/add_ecfp4_tanimoto.py
Normal file
186
script/data_processing/add_ecfp4_tanimoto.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from rdkit import Chem, DataStructs
|
||||
from rdkit.Chem import rdFingerprintGenerator
|
||||
|
||||
|
||||
@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""

    # Fingerprint width in bits; 2048 is the common ECFP4 default.
    n_bits: int = 2048
    # Morgan radius 2 == ECFP4 by convention.
    radius: int = 2
    include_chirality: bool = True
    # Built once in __post_init__ and reused for every molecule.
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        # Constructing the generator is relatively expensive, so do it once
        # per ECFP4Generator instance rather than per fingerprint call.
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        """Return the bit-vector fingerprint for *smiles*, or None on failure.

        Empty strings, unparseable SMILES, and RDKit errors all map to None
        so callers can treat missing fingerprints uniformly.
        """
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        if mol is None:
            return None
        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            # Best-effort: a molecule RDKit cannot fingerprint is skipped.
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        """Render *fp* as a fixed-width string of '0'/'1' characters."""
        bits = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, bits)
        return ''.join(bits.astype(str))
|
||||
|
||||
|
||||
def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint.

    Each valid entry lists the ``top_k`` most similar *other* molecules as
    ``id:score`` pairs (score rounded to 4 decimals); rows whose fingerprint
    is None yield an empty string.
    """
    summaries = [''] * len(fingerprints)

    # Keep only rows that actually have a fingerprint, remembering their
    # original positions so each summary lands in the right output slot.
    valid_indices = [i for i, fp in enumerate(fingerprints) if fp is not None]
    if not valid_indices:
        return summaries
    valid_fps = [fingerprints[i] for i in valid_indices]

    for pos, fp in enumerate(valid_fps):
        scores = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        # Drop the self-comparison, then rank the remainder best-first.
        others = [(j, score) for j, score in enumerate(scores) if j != pos]
        others.sort(key=lambda pair: pair[1], reverse=True)
        parts = [f"{ids[valid_indices[j]]}:{score:.4f}" for j, score in others[:top_k]]
        summaries[valid_indices[pos]] = ';'.join(parts)

    return summaries
|
||||
|
||||
|
||||
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the fingerprint-augmentation script.

    Parameters
    ----------
    argv:
        Optional explicit argument list. When None (the default) argparse
        falls back to ``sys.argv[1:]``, preserving the original CLI
        behavior; passing a list makes the parser unit-testable.

    Returns
    -------
    argparse.Namespace with input_csv, output, smiles_column, id_column,
    top_k, and format attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    parser.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    parser.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    parser.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: read a CSV of SMILES, append ECFP4 binary-fingerprint
    and Tanimoto nearest-neighbor columns, and write the result as parquet
    or CSV."""
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        # An explicit --output always wins; otherwise derive
        # <stem>_with_ecfp4.<ext> next to the input, picking the extension
        # from --format (auto behaves like parquet here).
        if args.output is not None:
            return args.output

        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        # A recognized output suffix takes precedence over --format; fall
        # back to the explicit --format choice when the suffix is unknown.
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        # Only reachable with --format auto and an unrecognized suffix.
        # User-facing message (Chinese): "cannot infer the format from the
        # output file; give --output a .parquet/.csv suffix or use --format".
        raise ValueError(
            "无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)

    # Refuse to clobber the source data in place.
    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")

    # NaN SMILES become '' so ECFP4Generator.fingerprint returns None for them.
    smiles_series = df[args.smiles_column].fillna('')

    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        # No id column present: synthesize 1-based labels (D000001, ...) and
        # persist them so the neighbor summaries remain resolvable.
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []

    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        # Rows without a valid fingerprint get an empty binary string.
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        # NOTE(review): to_parquet requires pyarrow or fastparquet to be
        # installed — confirm the environment provides one.
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user