update
This commit is contained in:
186
script/add_ecfp4_tanimoto.py
Normal file
186
script/add_ecfp4_tanimoto.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from rdkit import Chem, DataStructs
|
||||
from rdkit.Chem import rdFingerprintGenerator
|
||||
|
||||
|
||||
@dataclass
|
||||
class ECFP4Generator:
|
||||
"""Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""
|
||||
|
||||
n_bits: int = 2048
|
||||
radius: int = 2
|
||||
include_chirality: bool = True
|
||||
generator: rdFingerprintGenerator.MorganGenerator = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.generator = rdFingerprintGenerator.GetMorganGenerator(
|
||||
radius=self.radius,
|
||||
fpSize=self.n_bits,
|
||||
includeChirality=self.include_chirality,
|
||||
)
|
||||
|
||||
def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
|
||||
if not smiles:
|
||||
return None
|
||||
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return self.generator.GetFingerprint(mol)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
|
||||
arr = np.zeros((self.n_bits,), dtype=np.uint8)
|
||||
DataStructs.ConvertToNumpyArray(fp, arr)
|
||||
return ''.join(arr.astype(str))
|
||||
|
||||
|
||||
def tanimoto_top_k(
|
||||
fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
|
||||
ids: Sequence[str],
|
||||
top_k: int = 5,
|
||||
) -> List[str]:
|
||||
"""Return semicolon-delimited Tanimoto summaries for each fingerprint."""
|
||||
valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
|
||||
valid_fps = [fingerprints[idx] for idx in valid_indices]
|
||||
|
||||
index_lookup = {pos: original for pos, original in enumerate(valid_indices)}
|
||||
summaries = [''] * len(fingerprints)
|
||||
|
||||
if not valid_fps:
|
||||
return summaries
|
||||
|
||||
for pos, fp in enumerate(valid_fps):
|
||||
sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
|
||||
ranked: List[Tuple[int, float]] = []
|
||||
for other_pos, score in enumerate(sims):
|
||||
if other_pos == pos:
|
||||
continue
|
||||
ranked.append((other_pos, score))
|
||||
|
||||
ranked.sort(key=lambda item: item[1], reverse=True)
|
||||
top_entries = []
|
||||
for other_pos, score in ranked[:top_k]:
|
||||
original_idx = index_lookup[other_pos]
|
||||
top_entries.append(f"{ids[original_idx]}:{score:.4f}")
|
||||
|
||||
summaries[index_lookup[pos]] = ';'.join(top_entries)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=pathlib.Path,
|
||||
default=None,
|
||||
help="Destination file (default: <input>_with_ecfp4.parquet)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smiles-column",
|
||||
default="smiles",
|
||||
help="Name of the column containing SMILES (default: smiles)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--id-column",
|
||||
default="generated_id",
|
||||
help="Column used to label Tanimoto neighbors (default: generated_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
choices=("parquet", "csv", "auto"),
|
||||
default="parquet",
|
||||
help="Output format; defaults to parquet unless overridden or inferred from --output",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if not args.input_csv.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
|
||||
|
||||
def resolve_output_path() -> pathlib.Path:
|
||||
if args.output is not None:
|
||||
return args.output
|
||||
|
||||
suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
|
||||
return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")
|
||||
|
||||
def resolve_format(path: pathlib.Path) -> str:
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in {".parquet", ".pq"}:
|
||||
return "parquet"
|
||||
if suffix == ".csv":
|
||||
return "csv"
|
||||
if args.format == "parquet":
|
||||
return "parquet"
|
||||
if args.format == "csv":
|
||||
return "csv"
|
||||
raise ValueError(
|
||||
"无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format",
|
||||
)
|
||||
|
||||
output_path = resolve_output_path()
|
||||
output_format = resolve_format(output_path)
|
||||
|
||||
if args.input_csv.resolve() == output_path.resolve():
|
||||
raise ValueError("Output path must differ from input path to avoid overwriting input.")
|
||||
|
||||
df = pd.read_csv(args.input_csv)
|
||||
if args.smiles_column not in df.columns:
|
||||
raise ValueError(f"Column '{args.smiles_column}' not found in input data")
|
||||
|
||||
smiles_series = df[args.smiles_column].fillna('')
|
||||
|
||||
if args.id_column in df.columns:
|
||||
ids = df[args.id_column].astype(str).tolist()
|
||||
else:
|
||||
ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
|
||||
df[args.id_column] = ids
|
||||
|
||||
generator = ECFP4Generator()
|
||||
fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
|
||||
binary_repr: List[str] = []
|
||||
|
||||
for smiles in smiles_series:
|
||||
fp = generator.fingerprint(smiles)
|
||||
fingerprints.append(fp)
|
||||
binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')
|
||||
|
||||
df['ecfp4_binary'] = binary_repr
|
||||
df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)
|
||||
|
||||
if output_format == "parquet":
|
||||
df.to_parquet(output_path, index=False)
|
||||
else:
|
||||
df.to_csv(output_path, index=False)
|
||||
|
||||
print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
script/add_macrocycle_columns.py
Normal file
127
script/add_macrocycle_columns.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacrolactoneRingDetector:
|
||||
"""Detect macrolactone rings in SMILES strings via SMARTS patterns."""
|
||||
|
||||
min_size: int = 12
|
||||
max_size: int = 20
|
||||
patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.patterns = {
|
||||
size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
|
||||
for size in range(self.min_size, self.max_size + 1)
|
||||
}
|
||||
|
||||
def ring_sizes(self, smiles: str) -> List[int]:
|
||||
"""Return a sorted list of macrolactone ring sizes present in the SMILES."""
|
||||
if not smiles:
|
||||
return []
|
||||
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return []
|
||||
|
||||
ring_atoms = mol.GetRingInfo().AtomRings()
|
||||
if not ring_atoms:
|
||||
return []
|
||||
|
||||
matched_rings: Set[Tuple[int, ...]] = set()
|
||||
matched_sizes: Set[int] = set()
|
||||
|
||||
for size, query in self.patterns.items():
|
||||
if query is None:
|
||||
continue
|
||||
|
||||
for match in mol.GetSubstructMatches(query, uniquify=True):
|
||||
ring = self._pick_ring(match, ring_atoms, size)
|
||||
if ring and ring not in matched_rings:
|
||||
matched_rings.add(ring)
|
||||
matched_sizes.add(size)
|
||||
|
||||
sizes = sorted(matched_sizes)
|
||||
|
||||
if len(sizes) > 1:
|
||||
warnings.warn(
|
||||
"Multiple macrolactone ring sizes detected",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
|
||||
|
||||
return sizes
|
||||
|
||||
@staticmethod
|
||||
def _pick_ring(
|
||||
match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
|
||||
) -> Optional[Tuple[int, ...]]:
|
||||
ring_atom = match[0]
|
||||
for ring in rings:
|
||||
if len(ring) == expected_size and ring_atom in ring:
|
||||
return tuple(sorted(ring))
|
||||
return None
|
||||
|
||||
|
||||
def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
|
||||
result = df.copy()
|
||||
result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]
|
||||
|
||||
smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
|
||||
[""] * len(result), index=result.index
|
||||
)
|
||||
|
||||
def format_sizes(smiles: str) -> str:
|
||||
sizes = detector.ring_sizes(smiles)
|
||||
return ";".join(str(size) for size in sizes) if sizes else ""
|
||||
|
||||
result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
|
||||
return result
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=pathlib.Path,
|
||||
default=None,
|
||||
help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if not args.input_csv.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
|
||||
|
||||
output_path = args.output or args.input_csv.with_name(
|
||||
f"{args.input_csv.stem}_with_macrocycles.csv"
|
||||
)
|
||||
|
||||
if args.input_csv.resolve() == output_path.resolve():
|
||||
raise ValueError("Output path must differ from input path to avoid overwriting.")
|
||||
|
||||
df = pd.read_csv(args.input_csv)
|
||||
detector = MacrolactoneRingDetector()
|
||||
augmented = add_columns(df, detector)
|
||||
augmented.to_csv(output_path, index=False)
|
||||
|
||||
print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
439
script/ecfp4_umap_embedding_optimized.py
Normal file
439
script/ecfp4_umap_embedding_optimized.py
Normal file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules
|
||||
|
||||
This script processes SMILES data to:
|
||||
1. Generate ECFP4 fingerprints using RDKit
|
||||
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
|
||||
3. Generate unique IDs for molecules without existing IDs
|
||||
4. Perform UMAP dimensionality reduction with Tanimoto distance
|
||||
5. Prepare data for embedding-atlas visualization
|
||||
|
||||
Optimized for large datasets with progress tracking and memory efficiency.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
from typing import Optional, List
|
||||
|
||||
# RDKit imports
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import rdMolDescriptors, DataStructs
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
|
||||
# Data processing
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# UMAP and visualization
|
||||
import umap
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Suppress warnings
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Progress bar
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
HAS_TQDM = True
|
||||
except ImportError:
|
||||
HAS_TQDM = False
|
||||
|
||||
class MacrolactoneProcessor:
|
||||
"""Process macrolactone molecules for embedding visualization."""
|
||||
|
||||
def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
|
||||
"""
|
||||
Initialize processor with ECFP4 parameters.
|
||||
|
||||
Args:
|
||||
n_bits: Number of fingerprint bits (default: 2048)
|
||||
radius: Morgan fingerprint radius (default: 2 for ECFP4)
|
||||
chirality: Include chirality information (default: True)
|
||||
"""
|
||||
self.n_bits = n_bits
|
||||
self.radius = radius
|
||||
self.chirality = chirality
|
||||
|
||||
# Standardizer for molecule preprocessing
|
||||
self.standardizer = rdMolStandardize.MetalDisconnector()
|
||||
|
||||
# SMARTS patterns for different ring sizes (12-20 membered rings)
|
||||
self.ring_smarts = {
|
||||
12: '[r12][#8][#6](=[#8])', # 12-membered ring with lactone
|
||||
13: '[r13][#8][#6](=[#8])', # 13-membered ring with lactone
|
||||
14: '[r14][#8][#6](=[#8])', # 14-membered ring with lactone
|
||||
15: '[r15][#8][#6](=[#8])', # 15-membered ring with lactone
|
||||
16: '[r16][#8][#6](=[#8])', # 16-membered ring with lactone
|
||||
17: '[r17][#8][#6](=[#8])', # 17-membered ring with lactone
|
||||
18: '[r18][#8][#6](=[#8])', # 18-membered ring with lactone
|
||||
19: '[r19][#8][#6](=[#8])', # 19-membered ring with lactone
|
||||
20: '[r20][#8][#6](=[#8])', # 20-membered ring with lactone
|
||||
}
|
||||
|
||||
def standardize_molecule(self, mol: Chem.Mol) -> Optional[Chem.Mol]:
|
||||
"""Standardize molecule using RDKit standardization."""
|
||||
try:
|
||||
# Remove metals
|
||||
mol = self.standardizer.Disconnect(mol)
|
||||
# Normalize
|
||||
mol = rdMolStandardize.Normalize(mol)
|
||||
# Remove fragments
|
||||
mol = rdMolStandardize.FragmentParent(mol)
|
||||
# Neutralize charges
|
||||
mol = rdMolStandardize.ChargeParent(mol)
|
||||
return mol
|
||||
except:
|
||||
return None
|
||||
|
||||
def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
|
||||
"""Generate ECFP4 fingerprint from SMILES string using newer RDKit API."""
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return None
|
||||
|
||||
# Standardize molecule
|
||||
mol = self.standardize_molecule(mol)
|
||||
if mol is None:
|
||||
return None
|
||||
|
||||
# Generate Morgan fingerprint using the newer API to avoid deprecation warnings
|
||||
from rdkit.Chem import rdFingerprintGenerator
|
||||
generator = rdFingerprintGenerator.GetMorganGenerator(
|
||||
radius=self.radius,
|
||||
fpSize=self.n_bits,
|
||||
includeChirality=self.chirality
|
||||
)
|
||||
bv = generator.GetFingerprint(mol)
|
||||
|
||||
# Convert to numpy array
|
||||
arr = np.zeros((self.n_bits,), dtype=np.uint8)
|
||||
DataStructs.ConvertToNumpyArray(bv, arr)
|
||||
return arr
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing SMILES {smiles[:50]}...: {e}")
|
||||
return None
|
||||
|
||||
def detect_ring_number(self, smiles: str) -> int:
|
||||
"""Detect the ring number in macrolactone molecule using SMARTS patterns."""
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return 0
|
||||
|
||||
# Check each ring size pattern
|
||||
for ring_size, smarts in self.ring_smarts.items():
|
||||
query = Chem.MolFromSmarts(smarts)
|
||||
if query:
|
||||
matches = mol.GetSubstructMatches(query)
|
||||
if matches:
|
||||
return ring_size
|
||||
|
||||
# Alternative: check for any large ring with lactone
|
||||
generic_pattern = Chem.MolFromSmarts('[r{12-20}][#8][#6](=[#8])')
|
||||
if generic_pattern:
|
||||
matches = mol.GetSubstructMatches(generic_pattern)
|
||||
if matches:
|
||||
# Try to determine ring size from the first match
|
||||
for match in matches:
|
||||
# Get the ring atoms
|
||||
for atom_idx in match:
|
||||
atom = mol.GetAtomWithIdx(atom_idx)
|
||||
if atom.IsInRing():
|
||||
# Find the ring size
|
||||
for ring in atom.GetOwningMol().GetRingInfo().AtomRings():
|
||||
if atom_idx in ring:
|
||||
ring_size = len(ring)
|
||||
if 12 <= ring_size <= 20:
|
||||
return ring_size
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error detecting ring number for {smiles}: {e}")
|
||||
return 0
|
||||
|
||||
def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
|
||||
"""Generate unique ID for molecule."""
|
||||
if existing_id and pd.notna(existing_id) and existing_id != '':
|
||||
return str(existing_id)
|
||||
else:
|
||||
return f"D{index:07d}"
|
||||
|
||||
def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
|
||||
"""Calculate Tanimoto similarity between two fingerprints."""
|
||||
# Bit count
|
||||
bit_count1 = np.sum(fp1)
|
||||
bit_count2 = np.sum(fp2)
|
||||
common_bits = np.sum(fp1 & fp2)
|
||||
|
||||
if bit_count1 + bit_count2 - common_bits == 0:
|
||||
return 0.0
|
||||
|
||||
return common_bits / (bit_count1 + bit_count2 - common_bits)
|
||||
|
||||
def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
|
||||
"""Find k nearest neighbors for each molecule based on Tanimoto similarity."""
|
||||
n_samples = X.shape[0]
|
||||
neighbors = []
|
||||
|
||||
# Progress bar
|
||||
if HAS_TQDM:
|
||||
pbar = tqdm(total=n_samples, desc="Finding neighbors")
|
||||
|
||||
for i in range(n_samples):
|
||||
similarities = []
|
||||
|
||||
# Batch processing for memory efficiency
|
||||
for j in range(0, n_samples, batch_size):
|
||||
end_j = min(j + batch_size, n_samples)
|
||||
batch_X = X[j:end_j]
|
||||
|
||||
# Calculate similarities for this batch
|
||||
for batch_idx, fp in enumerate(batch_X):
|
||||
orig_idx = j + batch_idx
|
||||
if i != orig_idx:
|
||||
sim = self.tanimoto_similarity(X[i], fp)
|
||||
similarities.append((orig_idx, sim))
|
||||
|
||||
# Sort by similarity (descending)
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get top k neighbors
|
||||
top_neighbors = [str(idx) for idx, _ in similarities[:k]]
|
||||
neighbors.append(','.join(top_neighbors))
|
||||
|
||||
if HAS_TQDM:
|
||||
pbar.update(1)
|
||||
|
||||
if HAS_TQDM:
|
||||
pbar.close()
|
||||
|
||||
return neighbors
|
||||
|
||||
def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
|
||||
min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
|
||||
"""Perform UMAP dimensionality reduction."""
|
||||
reducer = umap.UMAP(
|
||||
n_neighbors=n_neighbors,
|
||||
min_dist=min_dist,
|
||||
metric=metric,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
return reducer.fit_transform(X)
|
||||
|
||||
def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
|
||||
id_col: Optional[str] = None, max_molecules: Optional[int] = None) -> pd.DataFrame:
|
||||
"""Process dataframe with SMILES strings."""
|
||||
print(f"Processing {len(df)} molecules...")
|
||||
|
||||
# Limit molecules if requested
|
||||
if max_molecules:
|
||||
df = df.head(max_molecules)
|
||||
print(f"Limited to {max_molecules} molecules")
|
||||
|
||||
# Ensure we have a smiles column
|
||||
if smiles_col not in df.columns:
|
||||
raise ValueError(f"Column '{smiles_col}' not found in dataframe")
|
||||
|
||||
# Create a working copy
|
||||
result_df = df.copy()
|
||||
|
||||
# Generate unique IDs if needed
|
||||
if id_col and id_col in df.columns:
|
||||
result_df['molecule_id'] = [self.generate_unique_id(i, existing_id)
|
||||
for i, existing_id in enumerate(result_df[id_col])]
|
||||
else:
|
||||
result_df['molecule_id'] = [self.generate_unique_id(i)
|
||||
for i in range(len(result_df))]
|
||||
|
||||
# Process fingerprints
|
||||
print("Generating ECFP4 fingerprints...")
|
||||
fingerprints = []
|
||||
valid_indices = []
|
||||
|
||||
# Progress tracking
|
||||
iterator = enumerate(result_df[smiles_col])
|
||||
if HAS_TQDM:
|
||||
iterator = tqdm(iterator, total=len(result_df), desc="Processing fingerprints")
|
||||
|
||||
for idx, smiles in iterator:
|
||||
if pd.notna(smiles) and smiles != '':
|
||||
fp = self.ecfp4_fingerprint(smiles)
|
||||
if fp is not None:
|
||||
fingerprints.append(fp)
|
||||
valid_indices.append(idx)
|
||||
else:
|
||||
print(f"Failed to generate fingerprint for index {idx}: {smiles[:50]}...")
|
||||
else:
|
||||
print(f"Invalid SMILES at index {idx}")
|
||||
|
||||
# Filter dataframe to valid molecules only
|
||||
result_df = result_df.iloc[valid_indices].reset_index(drop=True)
|
||||
|
||||
if not fingerprints:
|
||||
raise ValueError("No valid fingerprints generated")
|
||||
|
||||
# Convert fingerprints to numpy array
|
||||
X = np.array(fingerprints)
|
||||
print(f"Generated fingerprints for {len(fingerprints)} molecules")
|
||||
|
||||
# Detect ring numbers
|
||||
print("Detecting ring numbers...")
|
||||
ring_numbers = []
|
||||
|
||||
iterator = result_df[smiles_col]
|
||||
if HAS_TQDM:
|
||||
iterator = tqdm(iterator, desc="Detecting rings")
|
||||
|
||||
for smiles in iterator:
|
||||
ring_num = self.detect_ring_number(smiles)
|
||||
ring_numbers.append(ring_num)
|
||||
|
||||
result_df['ring_num'] = ring_numbers
|
||||
|
||||
# Perform UMAP
|
||||
print("Performing UMAP dimensionality reduction...")
|
||||
embedding = self.perform_umap(X)
|
||||
result_df['projection_x'] = embedding[:, 0]
|
||||
result_df['projection_y'] = embedding[:, 1]
|
||||
|
||||
# Find neighbors for embedding-atlas
|
||||
print("Finding nearest neighbors...")
|
||||
neighbors = self.find_neighbors(X, k=15)
|
||||
result_df['neighbors'] = neighbors
|
||||
|
||||
# Add fingerprint information
|
||||
result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]
|
||||
|
||||
return result_df
|
||||
|
||||
def create_visualization(self, df: pd.DataFrame, output_path: str):
|
||||
"""Create visualization of the UMAP embedding."""
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Color by ring number
|
||||
scatter = plt.scatter(df['projection_x'], df['projection_y'],
|
||||
c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)
|
||||
|
||||
plt.colorbar(scatter, label='Ring Number')
|
||||
plt.xlabel('UMAP 1')
|
||||
plt.ylabel('UMAP 2')
|
||||
plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')
|
||||
|
||||
# Add some annotations for ring numbers
|
||||
for ring_num in sorted(df['ring_num'].unique()):
|
||||
if ring_num > 0:
|
||||
subset = df[df['ring_num'] == ring_num]
|
||||
if len(subset) > 0:
|
||||
center_x = subset['projection_x'].mean()
|
||||
center_y = subset['projection_y'].mean()
|
||||
plt.annotate(f'{ring_num} ring', (center_x, center_y),
|
||||
fontsize=10, fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"Visualization saved to {output_path}")
|
||||
|
||||
def main():
|
||||
"""Main function to run the processing pipeline."""
|
||||
|
||||
parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
|
||||
parser.add_argument('--input', '-i', required=True,
|
||||
help='Input CSV file path')
|
||||
parser.add_argument('--output', '-o', required=True,
|
||||
help='Output CSV file path')
|
||||
parser.add_argument('--smiles-col', default='smiles',
|
||||
help='Name of SMILES column (default: smiles)')
|
||||
parser.add_argument('--id-col', default=None,
|
||||
help='Name of ID column (optional)')
|
||||
parser.add_argument('--visualization', '-v', default='umap_visualization.png',
|
||||
help='Output visualization file path')
|
||||
parser.add_argument('--max-molecules', type=int, default=None,
|
||||
help='Maximum number of molecules to process (for testing)')
|
||||
parser.add_argument('--launch-atlas', action='store_true',
|
||||
help='Launch embedding-atlas process')
|
||||
parser.add_argument('--atlas-port', type=int, default=8080,
|
||||
help='Port for embedding-atlas server')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize processor
|
||||
processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)
|
||||
|
||||
# Load data
|
||||
print(f"Loading data from {args.input}")
|
||||
try:
|
||||
df = pd.read_csv(args.input)
|
||||
print(f"Loaded {len(df)} molecules")
|
||||
print(f"Columns: {list(df.columns)}")
|
||||
except Exception as e:
|
||||
print(f"Error loading data: {e}")
|
||||
return 1
|
||||
|
||||
# Process dataframe
|
||||
try:
|
||||
processed_df = processor.process_dataframe(df,
|
||||
smiles_col=args.smiles_col,
|
||||
id_col=args.id_col,
|
||||
max_molecules=args.max_molecules)
|
||||
print(f"Successfully processed {len(processed_df)} molecules")
|
||||
except Exception as e:
|
||||
print(f"Error processing data: {e}")
|
||||
return 1
|
||||
|
||||
# Save results
|
||||
try:
|
||||
processed_df.to_csv(args.output, index=False)
|
||||
print(f"Results saved to {args.output}")
|
||||
except Exception as e:
|
||||
print(f"Error saving results: {e}")
|
||||
return 1
|
||||
|
||||
# Create visualization
|
||||
try:
|
||||
processor.create_visualization(processed_df, args.visualization)
|
||||
except Exception as e:
|
||||
print(f"Error creating visualization: {e}")
|
||||
|
||||
# Launch embedding-atlas if requested
|
||||
if args.launch_atlas:
|
||||
print("Launching embedding-atlas process...")
|
||||
try:
|
||||
# Prepare command for embedding-atlas
|
||||
cmd = [
|
||||
'embedding-atlas', 'data', args.output,
|
||||
'--text', args.smiles_col,
|
||||
'--port', str(args.atlas_port),
|
||||
'--neighbors', 'neighbors',
|
||||
'--x', 'projection_x',
|
||||
'--y', 'projection_y'
|
||||
]
|
||||
|
||||
print(f"Running command: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("Embedding-atlas process launched successfully")
|
||||
print(f"Access the visualization at: http://localhost:{args.atlas_port}")
|
||||
else:
|
||||
print(f"Error launching embedding-atlas: {result.stderr}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print("embedding-atlas command not found. Please install it first.")
|
||||
print("You can install it with: pip install embedding-atlas")
|
||||
except Exception as e:
|
||||
print(f"Error launching embedding-atlas: {e}")
|
||||
|
||||
print("Processing complete!")
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user