Files
embedding_atlas/script/visualization/ecfp4_umap_embedding_optimized.py
lingyuzeng bbf1746046 重构项目结构并更新README.md
1. 重构目录结构:
   - 创建src/visualization模块用于存放可视化相关功能
   - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py
   - 创建src/visualization/__init__.py导出主要函数
   - 整理script目录,按功能分类存放脚本文件

2. 更新README.md:
   - 添加CSV文件比较可视化部分
   - 提供Python API和命令行使用方法说明
   - 描述功能特点和使用示例

3. 更新模块引用:
   - 修正comparison.py中的模块引用路径
   - 更新命令行帮助信息中的使用示例
2025-10-23 17:55:36 +08:00

439 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules
This script processes SMILES data to:
1. Generate ECFP4 fingerprints using RDKit
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
3. Generate unique IDs for molecules without existing IDs
4. Perform UMAP dimensionality reduction with Tanimoto distance
5. Prepare data for embedding-atlas visualization
Optimized for large datasets with progress tracking and memory efficiency.
"""
import os
import sys
import argparse
import subprocess
from typing import Optional, List
# RDKit imports
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, DataStructs
from rdkit.Chem.MolStandardize import rdMolStandardize
# Data processing
import pandas as pd
import numpy as np
# UMAP and visualization
import umap
import matplotlib.pyplot as plt
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Progress bar
try:
from tqdm import tqdm
HAS_TQDM = True
except ImportError:
HAS_TQDM = False
class MacrolactoneProcessor:
    """Process macrolactone molecules for embedding visualization.

    The processor turns SMILES strings into ECFP-style Morgan bit
    fingerprints, labels each molecule with a macrolactone ring size found
    via SMARTS matching, projects the fingerprints to 2-D with UMAP and
    computes nearest neighbours by Tanimoto similarity.
    """

    def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
        """
        Initialize processor with ECFP4 parameters.

        Args:
            n_bits: Number of fingerprint bits (default: 2048)
            radius: Morgan fingerprint radius (default: 2 for ECFP4)
            chirality: Include chirality information (default: True)
        """
        self.n_bits = n_bits
        self.radius = radius
        self.chirality = chirality
        # Standardizer for molecule preprocessing
        self.standardizer = rdMolStandardize.MetalDisconnector()
        # SMARTS patterns for different ring sizes (12-20 membered rings):
        # an atom carrying the ring-size primitive, bonded to O-C(=O).
        # NOTE(review): only the first atom carries a ring constraint; the
        # ester O and C are not themselves required to be ring members.
        self.ring_smarts = {
            size: f'[r{size}][#8][#6](=[#8])' for size in range(12, 21)
        }
        # Lazily built Morgan generator plus the parameters it was built for.
        self._fp_generator = None
        self._fp_generator_key = None
        # SMARTS string -> parsed query, so patterns are parsed only once.
        self._smarts_cache = {}

    @staticmethod
    def _progress(iterable, **tqdm_kwargs):
        """Wrap *iterable* in a tqdm progress bar when tqdm is installed.

        Falls back to returning *iterable* unchanged when tqdm is missing.
        """
        try:
            from tqdm import tqdm
        except ImportError:
            return iterable
        return tqdm(iterable, **tqdm_kwargs)

    def _morgan_generator(self):
        """Return a Morgan generator matching the current fingerprint settings.

        The generator is cached and only rebuilt when ``radius``, ``n_bits``
        or ``chirality`` have changed since it was created.
        """
        key = (self.radius, self.n_bits, self.chirality)
        if self._fp_generator is None or self._fp_generator_key != key:
            # Newer generator API, used to avoid deprecation warnings.
            from rdkit.Chem import rdFingerprintGenerator
            self._fp_generator = rdFingerprintGenerator.GetMorganGenerator(
                radius=self.radius,
                fpSize=self.n_bits,
                includeChirality=self.chirality,
            )
            self._fp_generator_key = key
        return self._fp_generator

    def _compiled_smarts(self, smarts: str):
        """Return the parsed query for *smarts* (``None`` if it does not parse)."""
        if smarts not in self._smarts_cache:
            self._smarts_cache[smarts] = Chem.MolFromSmarts(smarts)
        return self._smarts_cache[smarts]

    def standardize_molecule(self, mol: "Chem.Mol") -> "Optional[Chem.Mol]":
        """Standardize molecule using RDKit standardization.

        Applies, in order: metal disconnection, normalization, fragment
        parent selection and charge parent selection.

        Returns:
            The standardized molecule, or ``None`` if any step raises.
        """
        try:
            mol = self.standardizer.Disconnect(mol)
            mol = rdMolStandardize.Normalize(mol)
            mol = rdMolStandardize.FragmentParent(mol)
            mol = rdMolStandardize.ChargeParent(mol)
            return mol
        except Exception:
            # Best effort: a molecule that cannot be standardized is skipped.
            # ``Exception`` (not a bare except) lets Ctrl-C propagate.
            return None

    def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
        """Generate ECFP4 fingerprint from SMILES string using newer RDKit API.

        Returns:
            A ``uint8`` array of length ``n_bits``, or ``None`` when the
            SMILES cannot be parsed, standardized or fingerprinted.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None
            mol = self.standardize_molecule(mol)
            if mol is None:
                return None
            bv = self._morgan_generator().GetFingerprint(mol)
            arr = np.zeros((self.n_bits,), dtype=np.uint8)
            DataStructs.ConvertToNumpyArray(bv, arr)
            return arr
        except Exception as e:
            # str() so that a non-string value cannot raise inside the handler.
            print(f"Error processing SMILES {str(smiles)[:50]}...: {e}")
            return None

    def detect_ring_number(self, smiles: str) -> int:
        """Detect the ring number in macrolactone molecule using SMARTS patterns.

        Patterns in ``ring_smarts`` are tried in dictionary order and the
        first matching ring size is returned.

        Returns:
            The matched ring size, or ``0`` when the SMILES cannot be parsed,
            nothing matches, or an error occurs.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return 0
            # Check each ring size pattern
            for ring_size, smarts in self.ring_smarts.items():
                query = self._compiled_smarts(smarts)
                if query is not None and mol.HasSubstructMatch(query):
                    return ring_size
            # Alternative: check for any large ring with lactone.
            # NOTE(review): this range query looks redundant with the
            # per-size patterns above; kept to preserve behaviour.
            generic_pattern = self._compiled_smarts('[r{12-20}][#8][#6](=[#8])')
            if generic_pattern is not None:
                matches = mol.GetSubstructMatches(generic_pattern)
                if matches:
                    rings = mol.GetRingInfo().AtomRings()
                    for match in matches:
                        for atom_idx in match:
                            for ring in rings:
                                if atom_idx in ring and 12 <= len(ring) <= 20:
                                    return len(ring)
            return 0
        except Exception as e:
            print(f"Error detecting ring number for {smiles}: {e}")
            return 0

    def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
        """Generate unique ID for molecule.

        Args:
            index: Row position used to build a fallback ID.
            existing_id: ID taken from the input data, if any.

        Returns:
            ``str(existing_id)`` when it is present (not ``None``, not NaN and
            not an empty string); otherwise ``"D"`` followed by the
            zero-padded 7-digit index. A falsy but valid ID such as ``0`` is
            kept rather than replaced.
        """
        if (existing_id is not None and pd.notna(existing_id)
                and str(existing_id) != ''):
            return str(existing_id)
        return f"D{index:07d}"

    def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
        """Calculate Tanimoto similarity between two fingerprints.

        Returns ``0.0`` when neither fingerprint has any bit set.
        """
        bit_count1 = np.sum(fp1)
        bit_count2 = np.sum(fp2)
        common_bits = np.sum(fp1 & fp2)
        union_bits = bit_count1 + bit_count2 - common_bits
        if union_bits == 0:
            return 0.0
        return float(common_bits / union_bits)

    def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
        """Find k nearest neighbors for each molecule based on Tanimoto similarity.

        Similarities are computed one block of ``batch_size`` query rows at a
        time with a matrix product, so the temporary similarity block has
        shape ``(batch_size, n_samples)``.

        Args:
            X: 2-D fingerprint matrix, one row per molecule. Non-zero entries
                are treated as set bits.
            k: Maximum number of neighbours to report per molecule.
            batch_size: Number of query rows processed per block.

        Returns:
            One string per row of ``X``: the comma-separated row indices of
            its most similar molecules (itself excluded), most similar first.
            Ties are ordered by ascending row index.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        X = np.asarray(X)
        n_samples = X.shape[0]
        if n_samples == 0:
            return []

        # float32 lets the matrix product use the fast floating-point path;
        # the intersection counts are small integers and stay exact.
        bits = (X != 0).astype(np.float32)
        bit_counts = bits.sum(axis=1, dtype=np.float64)

        neighbors = []
        block_starts = self._progress(
            range(0, n_samples, batch_size), desc="Finding neighbors", unit="batch"
        )
        for start in block_starts:
            stop = min(start + batch_size, n_samples)
            common = bits[start:stop] @ bits.T
            union = bit_counts[start:stop, None] + bit_counts[None, :] - common
            # Pairs with an empty union (both fingerprints empty) score 0.0.
            sims = np.divide(common, union, out=np.zeros_like(union), where=union > 0)
            for offset, row in enumerate(sims):
                # Stable sort on the negated scores: descending similarity,
                # ties kept in ascending index order.
                order = np.argsort(-row, kind='stable')
                order = order[order != start + offset][:k]
                neighbors.append(','.join(str(idx) for idx in order.tolist()))
        return neighbors

    def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
                     min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
        """Perform UMAP dimensionality reduction.

        Uses a fixed ``random_state`` of 42 and returns the result of
        ``fit_transform`` on ``X``.
        """
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            metric=metric,
            random_state=42,
        )
        return reducer.fit_transform(X)

    def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
                          id_col: Optional[str] = None,
                          max_molecules: Optional[int] = None) -> pd.DataFrame:
        """Process dataframe with SMILES strings.

        Args:
            df: Input table; it is not modified.
            smiles_col: Name of the column holding SMILES strings.
            id_col: Optional column with existing molecule IDs.
            max_molecules: If set (and non-zero), only the first
                ``max_molecules`` rows are processed.

        Returns:
            A copy of the rows whose fingerprint could be generated, with a
            reset index and the added columns ``molecule_id``, ``ring_num``,
            ``projection_x``, ``projection_y``, ``neighbors`` and
            ``fingerprint_bits``. Indices in ``neighbors`` refer to row
            positions of the returned frame.

        Raises:
            ValueError: If ``smiles_col`` is missing or no fingerprint could
                be generated.
        """
        if smiles_col not in df.columns:
            raise ValueError(f"Column '{smiles_col}' not found in dataframe")
        # Limit molecules if requested
        if max_molecules:
            df = df.head(max_molecules)
            print(f"Limited to {max_molecules} molecules")
        print(f"Processing {len(df)} molecules...")

        # Create a working copy
        result_df = df.copy()

        # Generate unique IDs, reusing existing ones where available
        if id_col and id_col in df.columns:
            existing_ids = list(result_df[id_col])
        else:
            existing_ids = [None] * len(result_df)
        result_df['molecule_id'] = [
            self.generate_unique_id(i, existing_id)
            for i, existing_id in enumerate(existing_ids)
        ]

        # Process fingerprints
        print("Generating ECFP4 fingerprints...")
        fingerprints = []
        valid_indices = []
        rows = self._progress(
            enumerate(result_df[smiles_col]),
            total=len(result_df),
            desc="Processing fingerprints",
        )
        for idx, smiles in rows:
            if pd.isna(smiles) or smiles == '':
                print(f"Invalid SMILES at index {idx}")
                continue
            fp = self.ecfp4_fingerprint(smiles)
            if fp is None:
                # str() guards against non-string cell values.
                print(f"Failed to generate fingerprint for index {idx}: {str(smiles)[:50]}...")
                continue
            fingerprints.append(fp)
            valid_indices.append(idx)

        if not fingerprints:
            raise ValueError("No valid fingerprints generated")

        # Filter dataframe to valid molecules only
        result_df = result_df.iloc[valid_indices].reset_index(drop=True)
        X = np.array(fingerprints)
        print(f"Generated fingerprints for {len(fingerprints)} molecules")

        # Detect ring numbers
        print("Detecting ring numbers...")
        smiles_iter = self._progress(result_df[smiles_col], desc="Detecting rings")
        result_df['ring_num'] = [self.detect_ring_number(s) for s in smiles_iter]

        # Perform UMAP
        print("Performing UMAP dimensionality reduction...")
        embedding = self.perform_umap(X)
        result_df['projection_x'] = embedding[:, 0]
        result_df['projection_y'] = embedding[:, 1]

        # Find neighbors for embedding-atlas
        print("Finding nearest neighbors...")
        result_df['neighbors'] = self.find_neighbors(X, k=15)

        # Add fingerprint information
        result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]
        return result_df

    def create_visualization(self, df: pd.DataFrame, output_path: str):
        """Create visualization of the UMAP embedding.

        Draws a scatter plot of ``projection_x`` / ``projection_y`` coloured
        by ``ring_num``, labels the mean position of every non-zero ring
        size, and saves the figure to ``output_path``.
        """
        fig = plt.figure(figsize=(12, 8))
        try:
            # Color by ring number
            scatter = plt.scatter(df['projection_x'], df['projection_y'],
                                  c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)
            plt.colorbar(scatter, label='Ring Number')
            plt.xlabel('UMAP 1')
            plt.ylabel('UMAP 2')
            plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')
            # Add some annotations for ring numbers
            for ring_num in sorted(df['ring_num'].unique()):
                if ring_num > 0:
                    subset = df[df['ring_num'] == ring_num]
                    if len(subset) > 0:
                        center_x = subset['projection_x'].mean()
                        center_y = subset['projection_y'].mean()
                        plt.annotate(f'{ring_num} ring', (center_x, center_y),
                                     fontsize=10, fontweight='bold')
            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
        print(f"Visualization saved to {output_path}")
def _launch_embedding_atlas(data_path: str, text_col: str, port: int) -> None:
    """Run the embedding-atlas CLI on *data_path* in the foreground.

    The child process inherits this process's stdout/stderr so its output is
    visible while it runs; the call returns when the process exits.
    Failures are reported with a message and never raised.
    """
    # NOTE(review): 'data' is passed as the first positional argument,
    # before the CSV path; confirm this matches the embedding-atlas CLI.
    cmd = [
        'embedding-atlas', 'data', data_path,
        '--text', text_col,
        '--port', str(port),
        '--neighbors', 'neighbors',
        '--x', 'projection_x',
        '--y', 'projection_y'
    ]
    print(f"Running command: {' '.join(cmd)}")
    # Printed up front because the call below blocks until the process exits.
    print(f"Access the visualization at: http://localhost:{port}")
    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        print("embedding-atlas command not found. Please install it first.")
        print("You can install it with: pip install embedding-atlas")
    except KeyboardInterrupt:
        print("embedding-atlas stopped")
    except Exception as e:
        print(f"Error launching embedding-atlas: {e}")
    else:
        if result.returncode != 0:
            print(f"Error launching embedding-atlas: exit code {result.returncode}")


def main(argv: Optional[List[str]] = None) -> int:
    """Main function to run the processing pipeline.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv``.

    Returns:
        ``0`` on success, ``1`` when loading, processing or saving fails.
        A failure while drawing the visualization or running
        embedding-atlas is reported but does not change the return value.
    """
    parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
    parser.add_argument('--input', '-i', required=True,
                        help='Input CSV file path')
    parser.add_argument('--output', '-o', required=True,
                        help='Output CSV file path')
    parser.add_argument('--smiles-col', default='smiles',
                        help='Name of SMILES column (default: smiles)')
    parser.add_argument('--id-col', default=None,
                        help='Name of ID column (optional)')
    parser.add_argument('--visualization', '-v', default='umap_visualization.png',
                        help='Output visualization file path')
    parser.add_argument('--max-molecules', type=int, default=None,
                        help='Maximum number of molecules to process (for testing)')
    parser.add_argument('--launch-atlas', action='store_true',
                        help='Launch embedding-atlas process')
    parser.add_argument('--atlas-port', type=int, default=8080,
                        help='Port for embedding-atlas server')
    args = parser.parse_args(argv)

    # Load data first so an unreadable input fails before any other setup.
    print(f"Loading data from {args.input}")
    try:
        df = pd.read_csv(args.input)
    except Exception as e:
        print(f"Error loading data: {e}")
        return 1
    print(f"Loaded {len(df)} molecules")
    print(f"Columns: {list(df.columns)}")

    # Initialize processor
    processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)

    # Process dataframe
    try:
        processed_df = processor.process_dataframe(df,
                                                   smiles_col=args.smiles_col,
                                                   id_col=args.id_col,
                                                   max_molecules=args.max_molecules)
        print(f"Successfully processed {len(processed_df)} molecules")
    except Exception as e:
        print(f"Error processing data: {e}")
        return 1

    # Save results
    try:
        processed_df.to_csv(args.output, index=False)
        print(f"Results saved to {args.output}")
    except Exception as e:
        print(f"Error saving results: {e}")
        return 1

    # Create visualization (best effort: a plotting failure is not fatal)
    try:
        processor.create_visualization(processed_df, args.visualization)
    except Exception as e:
        print(f"Error creating visualization: {e}")

    # Launch embedding-atlas if requested
    if args.launch_atlas:
        print("Launching embedding-atlas process...")
        _launch_embedding_atlas(args.output, args.smiles_col, args.atlas_port)

    print("Processing complete!")
    return 0
# Script entry point: run the pipeline and use main()'s return value as the
# process exit status.
if __name__ == '__main__':
    raise SystemExit(main())