重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
439
script/visualization/ecfp4_umap_embedding_optimized.py
Normal file
439
script/visualization/ecfp4_umap_embedding_optimized.py
Normal file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules
|
||||
|
||||
This script processes SMILES data to:
|
||||
1. Generate ECFP4 fingerprints using RDKit
|
||||
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
|
||||
3. Generate unique IDs for molecules without existing IDs
|
||||
4. Perform UMAP dimensionality reduction with Tanimoto distance
|
||||
5. Prepare data for embedding-atlas visualization
|
||||
|
||||
Optimized for large datasets with progress tracking and memory efficiency.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
from typing import Optional, List
|
||||
|
||||
# RDKit imports
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import rdMolDescriptors, DataStructs
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
|
||||
# Data processing
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# UMAP and visualization
|
||||
import umap
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Suppress warnings
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Progress bar
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
HAS_TQDM = True
|
||||
except ImportError:
|
||||
HAS_TQDM = False
|
||||
|
||||
class MacrolactoneProcessor:
|
||||
"""Process macrolactone molecules for embedding visualization."""
|
||||
|
||||
def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
|
||||
"""
|
||||
Initialize processor with ECFP4 parameters.
|
||||
|
||||
Args:
|
||||
n_bits: Number of fingerprint bits (default: 2048)
|
||||
radius: Morgan fingerprint radius (default: 2 for ECFP4)
|
||||
chirality: Include chirality information (default: True)
|
||||
"""
|
||||
self.n_bits = n_bits
|
||||
self.radius = radius
|
||||
self.chirality = chirality
|
||||
|
||||
# Standardizer for molecule preprocessing
|
||||
self.standardizer = rdMolStandardize.MetalDisconnector()
|
||||
|
||||
# SMARTS patterns for different ring sizes (12-20 membered rings)
|
||||
self.ring_smarts = {
|
||||
12: '[r12][#8][#6](=[#8])', # 12-membered ring with lactone
|
||||
13: '[r13][#8][#6](=[#8])', # 13-membered ring with lactone
|
||||
14: '[r14][#8][#6](=[#8])', # 14-membered ring with lactone
|
||||
15: '[r15][#8][#6](=[#8])', # 15-membered ring with lactone
|
||||
16: '[r16][#8][#6](=[#8])', # 16-membered ring with lactone
|
||||
17: '[r17][#8][#6](=[#8])', # 17-membered ring with lactone
|
||||
18: '[r18][#8][#6](=[#8])', # 18-membered ring with lactone
|
||||
19: '[r19][#8][#6](=[#8])', # 19-membered ring with lactone
|
||||
20: '[r20][#8][#6](=[#8])', # 20-membered ring with lactone
|
||||
}
|
||||
|
||||
def standardize_molecule(self, mol: Chem.Mol) -> Optional[Chem.Mol]:
|
||||
"""Standardize molecule using RDKit standardization."""
|
||||
try:
|
||||
# Remove metals
|
||||
mol = self.standardizer.Disconnect(mol)
|
||||
# Normalize
|
||||
mol = rdMolStandardize.Normalize(mol)
|
||||
# Remove fragments
|
||||
mol = rdMolStandardize.FragmentParent(mol)
|
||||
# Neutralize charges
|
||||
mol = rdMolStandardize.ChargeParent(mol)
|
||||
return mol
|
||||
except:
|
||||
return None
|
||||
|
||||
def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
|
||||
"""Generate ECFP4 fingerprint from SMILES string using newer RDKit API."""
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return None
|
||||
|
||||
# Standardize molecule
|
||||
mol = self.standardize_molecule(mol)
|
||||
if mol is None:
|
||||
return None
|
||||
|
||||
# Generate Morgan fingerprint using the newer API to avoid deprecation warnings
|
||||
from rdkit.Chem import rdFingerprintGenerator
|
||||
generator = rdFingerprintGenerator.GetMorganGenerator(
|
||||
radius=self.radius,
|
||||
fpSize=self.n_bits,
|
||||
includeChirality=self.chirality
|
||||
)
|
||||
bv = generator.GetFingerprint(mol)
|
||||
|
||||
# Convert to numpy array
|
||||
arr = np.zeros((self.n_bits,), dtype=np.uint8)
|
||||
DataStructs.ConvertToNumpyArray(bv, arr)
|
||||
return arr
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing SMILES {smiles[:50]}...: {e}")
|
||||
return None
|
||||
|
||||
def detect_ring_number(self, smiles: str) -> int:
|
||||
"""Detect the ring number in macrolactone molecule using SMARTS patterns."""
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return 0
|
||||
|
||||
# Check each ring size pattern
|
||||
for ring_size, smarts in self.ring_smarts.items():
|
||||
query = Chem.MolFromSmarts(smarts)
|
||||
if query:
|
||||
matches = mol.GetSubstructMatches(query)
|
||||
if matches:
|
||||
return ring_size
|
||||
|
||||
# Alternative: check for any large ring with lactone
|
||||
generic_pattern = Chem.MolFromSmarts('[r{12-20}][#8][#6](=[#8])')
|
||||
if generic_pattern:
|
||||
matches = mol.GetSubstructMatches(generic_pattern)
|
||||
if matches:
|
||||
# Try to determine ring size from the first match
|
||||
for match in matches:
|
||||
# Get the ring atoms
|
||||
for atom_idx in match:
|
||||
atom = mol.GetAtomWithIdx(atom_idx)
|
||||
if atom.IsInRing():
|
||||
# Find the ring size
|
||||
for ring in atom.GetOwningMol().GetRingInfo().AtomRings():
|
||||
if atom_idx in ring:
|
||||
ring_size = len(ring)
|
||||
if 12 <= ring_size <= 20:
|
||||
return ring_size
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error detecting ring number for {smiles}: {e}")
|
||||
return 0
|
||||
|
||||
def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
|
||||
"""Generate unique ID for molecule."""
|
||||
if existing_id and pd.notna(existing_id) and existing_id != '':
|
||||
return str(existing_id)
|
||||
else:
|
||||
return f"D{index:07d}"
|
||||
|
||||
def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
|
||||
"""Calculate Tanimoto similarity between two fingerprints."""
|
||||
# Bit count
|
||||
bit_count1 = np.sum(fp1)
|
||||
bit_count2 = np.sum(fp2)
|
||||
common_bits = np.sum(fp1 & fp2)
|
||||
|
||||
if bit_count1 + bit_count2 - common_bits == 0:
|
||||
return 0.0
|
||||
|
||||
return common_bits / (bit_count1 + bit_count2 - common_bits)
|
||||
|
||||
def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
|
||||
"""Find k nearest neighbors for each molecule based on Tanimoto similarity."""
|
||||
n_samples = X.shape[0]
|
||||
neighbors = []
|
||||
|
||||
# Progress bar
|
||||
if HAS_TQDM:
|
||||
pbar = tqdm(total=n_samples, desc="Finding neighbors")
|
||||
|
||||
for i in range(n_samples):
|
||||
similarities = []
|
||||
|
||||
# Batch processing for memory efficiency
|
||||
for j in range(0, n_samples, batch_size):
|
||||
end_j = min(j + batch_size, n_samples)
|
||||
batch_X = X[j:end_j]
|
||||
|
||||
# Calculate similarities for this batch
|
||||
for batch_idx, fp in enumerate(batch_X):
|
||||
orig_idx = j + batch_idx
|
||||
if i != orig_idx:
|
||||
sim = self.tanimoto_similarity(X[i], fp)
|
||||
similarities.append((orig_idx, sim))
|
||||
|
||||
# Sort by similarity (descending)
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get top k neighbors
|
||||
top_neighbors = [str(idx) for idx, _ in similarities[:k]]
|
||||
neighbors.append(','.join(top_neighbors))
|
||||
|
||||
if HAS_TQDM:
|
||||
pbar.update(1)
|
||||
|
||||
if HAS_TQDM:
|
||||
pbar.close()
|
||||
|
||||
return neighbors
|
||||
|
||||
def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
|
||||
min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
|
||||
"""Perform UMAP dimensionality reduction."""
|
||||
reducer = umap.UMAP(
|
||||
n_neighbors=n_neighbors,
|
||||
min_dist=min_dist,
|
||||
metric=metric,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
return reducer.fit_transform(X)
|
||||
|
||||
def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
|
||||
id_col: Optional[str] = None, max_molecules: Optional[int] = None) -> pd.DataFrame:
|
||||
"""Process dataframe with SMILES strings."""
|
||||
print(f"Processing {len(df)} molecules...")
|
||||
|
||||
# Limit molecules if requested
|
||||
if max_molecules:
|
||||
df = df.head(max_molecules)
|
||||
print(f"Limited to {max_molecules} molecules")
|
||||
|
||||
# Ensure we have a smiles column
|
||||
if smiles_col not in df.columns:
|
||||
raise ValueError(f"Column '{smiles_col}' not found in dataframe")
|
||||
|
||||
# Create a working copy
|
||||
result_df = df.copy()
|
||||
|
||||
# Generate unique IDs if needed
|
||||
if id_col and id_col in df.columns:
|
||||
result_df['molecule_id'] = [self.generate_unique_id(i, existing_id)
|
||||
for i, existing_id in enumerate(result_df[id_col])]
|
||||
else:
|
||||
result_df['molecule_id'] = [self.generate_unique_id(i)
|
||||
for i in range(len(result_df))]
|
||||
|
||||
# Process fingerprints
|
||||
print("Generating ECFP4 fingerprints...")
|
||||
fingerprints = []
|
||||
valid_indices = []
|
||||
|
||||
# Progress tracking
|
||||
iterator = enumerate(result_df[smiles_col])
|
||||
if HAS_TQDM:
|
||||
iterator = tqdm(iterator, total=len(result_df), desc="Processing fingerprints")
|
||||
|
||||
for idx, smiles in iterator:
|
||||
if pd.notna(smiles) and smiles != '':
|
||||
fp = self.ecfp4_fingerprint(smiles)
|
||||
if fp is not None:
|
||||
fingerprints.append(fp)
|
||||
valid_indices.append(idx)
|
||||
else:
|
||||
print(f"Failed to generate fingerprint for index {idx}: {smiles[:50]}...")
|
||||
else:
|
||||
print(f"Invalid SMILES at index {idx}")
|
||||
|
||||
# Filter dataframe to valid molecules only
|
||||
result_df = result_df.iloc[valid_indices].reset_index(drop=True)
|
||||
|
||||
if not fingerprints:
|
||||
raise ValueError("No valid fingerprints generated")
|
||||
|
||||
# Convert fingerprints to numpy array
|
||||
X = np.array(fingerprints)
|
||||
print(f"Generated fingerprints for {len(fingerprints)} molecules")
|
||||
|
||||
# Detect ring numbers
|
||||
print("Detecting ring numbers...")
|
||||
ring_numbers = []
|
||||
|
||||
iterator = result_df[smiles_col]
|
||||
if HAS_TQDM:
|
||||
iterator = tqdm(iterator, desc="Detecting rings")
|
||||
|
||||
for smiles in iterator:
|
||||
ring_num = self.detect_ring_number(smiles)
|
||||
ring_numbers.append(ring_num)
|
||||
|
||||
result_df['ring_num'] = ring_numbers
|
||||
|
||||
# Perform UMAP
|
||||
print("Performing UMAP dimensionality reduction...")
|
||||
embedding = self.perform_umap(X)
|
||||
result_df['projection_x'] = embedding[:, 0]
|
||||
result_df['projection_y'] = embedding[:, 1]
|
||||
|
||||
# Find neighbors for embedding-atlas
|
||||
print("Finding nearest neighbors...")
|
||||
neighbors = self.find_neighbors(X, k=15)
|
||||
result_df['neighbors'] = neighbors
|
||||
|
||||
# Add fingerprint information
|
||||
result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]
|
||||
|
||||
return result_df
|
||||
|
||||
def create_visualization(self, df: pd.DataFrame, output_path: str):
|
||||
"""Create visualization of the UMAP embedding."""
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Color by ring number
|
||||
scatter = plt.scatter(df['projection_x'], df['projection_y'],
|
||||
c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)
|
||||
|
||||
plt.colorbar(scatter, label='Ring Number')
|
||||
plt.xlabel('UMAP 1')
|
||||
plt.ylabel('UMAP 2')
|
||||
plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')
|
||||
|
||||
# Add some annotations for ring numbers
|
||||
for ring_num in sorted(df['ring_num'].unique()):
|
||||
if ring_num > 0:
|
||||
subset = df[df['ring_num'] == ring_num]
|
||||
if len(subset) > 0:
|
||||
center_x = subset['projection_x'].mean()
|
||||
center_y = subset['projection_y'].mean()
|
||||
plt.annotate(f'{ring_num} ring', (center_x, center_y),
|
||||
fontsize=10, fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"Visualization saved to {output_path}")
|
||||
|
||||
def main():
|
||||
"""Main function to run the processing pipeline."""
|
||||
|
||||
parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
|
||||
parser.add_argument('--input', '-i', required=True,
|
||||
help='Input CSV file path')
|
||||
parser.add_argument('--output', '-o', required=True,
|
||||
help='Output CSV file path')
|
||||
parser.add_argument('--smiles-col', default='smiles',
|
||||
help='Name of SMILES column (default: smiles)')
|
||||
parser.add_argument('--id-col', default=None,
|
||||
help='Name of ID column (optional)')
|
||||
parser.add_argument('--visualization', '-v', default='umap_visualization.png',
|
||||
help='Output visualization file path')
|
||||
parser.add_argument('--max-molecules', type=int, default=None,
|
||||
help='Maximum number of molecules to process (for testing)')
|
||||
parser.add_argument('--launch-atlas', action='store_true',
|
||||
help='Launch embedding-atlas process')
|
||||
parser.add_argument('--atlas-port', type=int, default=8080,
|
||||
help='Port for embedding-atlas server')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize processor
|
||||
processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)
|
||||
|
||||
# Load data
|
||||
print(f"Loading data from {args.input}")
|
||||
try:
|
||||
df = pd.read_csv(args.input)
|
||||
print(f"Loaded {len(df)} molecules")
|
||||
print(f"Columns: {list(df.columns)}")
|
||||
except Exception as e:
|
||||
print(f"Error loading data: {e}")
|
||||
return 1
|
||||
|
||||
# Process dataframe
|
||||
try:
|
||||
processed_df = processor.process_dataframe(df,
|
||||
smiles_col=args.smiles_col,
|
||||
id_col=args.id_col,
|
||||
max_molecules=args.max_molecules)
|
||||
print(f"Successfully processed {len(processed_df)} molecules")
|
||||
except Exception as e:
|
||||
print(f"Error processing data: {e}")
|
||||
return 1
|
||||
|
||||
# Save results
|
||||
try:
|
||||
processed_df.to_csv(args.output, index=False)
|
||||
print(f"Results saved to {args.output}")
|
||||
except Exception as e:
|
||||
print(f"Error saving results: {e}")
|
||||
return 1
|
||||
|
||||
# Create visualization
|
||||
try:
|
||||
processor.create_visualization(processed_df, args.visualization)
|
||||
except Exception as e:
|
||||
print(f"Error creating visualization: {e}")
|
||||
|
||||
# Launch embedding-atlas if requested
|
||||
if args.launch_atlas:
|
||||
print("Launching embedding-atlas process...")
|
||||
try:
|
||||
# Prepare command for embedding-atlas
|
||||
cmd = [
|
||||
'embedding-atlas', 'data', args.output,
|
||||
'--text', args.smiles_col,
|
||||
'--port', str(args.atlas_port),
|
||||
'--neighbors', 'neighbors',
|
||||
'--x', 'projection_x',
|
||||
'--y', 'projection_y'
|
||||
]
|
||||
|
||||
print(f"Running command: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("Embedding-atlas process launched successfully")
|
||||
print(f"Access the visualization at: http://localhost:{args.atlas_port}")
|
||||
else:
|
||||
print(f"Error launching embedding-atlas: {result.stderr}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print("embedding-atlas command not found. Please install it first.")
|
||||
print("You can install it with: pip install embedding-atlas")
|
||||
except Exception as e:
|
||||
print(f"Error launching embedding-atlas: {e}")
|
||||
|
||||
print("Processing complete!")
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user