#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file        :analyze_qed_mw_distribution.py
@Description :Analysis of QED and molecular weight distribution with KDE plots
@Date        :2025/08/05
@Author      :lyzeng
'''
import ast
import json
import logging
from pathlib import Path

import click
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rdkit import Chem

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _reference_identifier(filename):
    """
    Derive the short identifier of a reference molecule from its file name

    The ``_out_converted`` and ``_addH`` markers are removed from the file
    stem; if the stem contains ``align_`` only the last ``_``-separated token
    is kept (e.g. ``9NY`` or ``0GA``). Every function that labels or looks up
    a reference molecule goes through this helper so the keys always agree.

    Args:
        filename (str or Path): File name or path of the reference molecule

    Returns:
        str: Identifier of the reference molecule
    """
    stem = Path(filename).stem
    stem = stem.replace('_out_converted', '').replace('_addH', '')
    if 'align_' in stem:
        stem = stem.split('_')[-1]
    return stem


def _min_max_normalize(values, lower, upper):
    """
    Rescale values with the affine map that sends [lower, upper] to [0, 1]

    Args:
        values (pd.Series or float): Value(s) to rescale
        lower (float): Value mapped to 0
        upper (float): Value mapped to 1

    Returns:
        pd.Series or float: Rescaled value(s); all zeros when lower == upper
    """
    span = upper - lower
    if span == 0:
        # Degenerate range: avoid a division by zero and collapse onto 0.
        return (values - lower) * 0.0
    return (values - lower) / span


def _mark_reference(ax, x_value, text, color, marker, text_offset):
    """
    Draw one reference marker on the x-axis of ``ax`` and annotate it

    Args:
        ax (matplotlib.axes.Axes): Axes to draw on
        x_value (float): Position of the marker on the x-axis
        text (str): Annotation text
        color (str): Marker colour
        marker (str): Matplotlib marker symbol
        text_offset (tuple): Annotation offset in points, relative to the marker
    """
    ax.scatter(x_value, 0, color=color, s=100, marker=marker, zorder=5)
    ax.annotate(text, (x_value, 0), xytext=text_offset, textcoords='offset points',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                fontsize=8)


def load_dataset(csv_file):
    """
    Load dataset from CSV file and log basic statistics

    Args:
        csv_file (str): Path to the CSV file; it must provide the columns
            ``qed`` and ``molecular_weight``

    Returns:
        pd.DataFrame: Loaded dataset
    """
    df = pd.read_csv(csv_file)
    logger.info(f"Loaded {len(df)} records from {csv_file}")

    # Print basic statistics
    qed = df['qed']
    mw = df['molecular_weight']
    logger.info(f"Statistics for {Path(csv_file).stem}:")
    logger.info(f"QED - Min: {qed.min():.3f}, Max: {qed.max():.3f}, Mean: {qed.mean():.3f}")
    logger.info(f"Molecular Weight - Min: {mw.min():.2f}, Max: {mw.max():.2f}, Mean: {mw.mean():.2f}")

    return df


def load_reference_molecules(dataset_name):
    """
    Load reference molecules from ``qed_values_<dataset_name>.csv``

    The CSV file is looked up in the current working directory. Reference
    molecules are the rows whose ``filename`` matches
    ``align_.*_out_converted\\.sdf``.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        pd.DataFrame: Reference molecules with QED and molecular weight;
            an empty DataFrame when no CSV file is found
    """
    csv_files = list(Path(".").glob(f"qed_values_{dataset_name}.csv"))
    if not csv_files:
        logger.warning(f"No CSV file found for {dataset_name}")
        return pd.DataFrame()

    df = pd.read_csv(csv_files[0])

    # Raw string: "\." in a normal string literal is an invalid escape sequence.
    is_reference = df['filename'].str.contains(r'align_.*_out_converted\.sdf', na=False, regex=True)
    reference_df = df[is_reference]
    logger.info(f"Loaded {len(reference_df)} reference molecules for {dataset_name}")
    return reference_df


def extract_vina_scores_from_sdf(sdf_file_path):
    """
    Extract Vina scores from all conformers in an SDF file

    Each molecule is expected to carry a ``meeko`` property holding a JSON
    object; its ``free_energy`` entry is collected as the Vina score.

    Args:
        sdf_file_path (str): Path to the SDF file

    Returns:
        list: List of Vina scores (free_energy values) or empty list if failed
    """
    scores = []

    try:
        supplier = Chem.SDMolSupplier(sdf_file_path, removeHs=False)
        for mol in supplier:
            if mol is None:
                continue

            # The meeko property contains the docking information
            if not mol.HasProp("meeko"):
                logger.warning(f"No meeko property found in molecule from {sdf_file_path}")
                continue

            meeko_raw = mol.GetProp("meeko")
            try:
                meeko_dict = json.loads(meeko_raw)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse meeko JSON for {sdf_file_path}")
                continue

            # Extract free energy (Vina score)
            if 'free_energy' in meeko_dict:
                scores.append(meeko_dict['free_energy'])
    except Exception as e:
        # Deliberate best-effort: an unreadable file yields the scores
        # gathered so far instead of aborting the whole analysis.
        logger.warning(f"Failed to extract Vina scores from {sdf_file_path}: {e}")

    return scores


def load_vina_scores_from_csv(df, max_files=1000):
    """
    Load Vina scores from the ``vina_scores`` column of the dataset

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores``
            columns; ``vina_scores`` holds the string form of a list
        max_files (int): Maximum number of rows to process

    Returns:
        list: List of all Vina scores from all molecules
    """
    all_vina_scores = []

    # Process only up to max_files to avoid memory issues
    processed_files = 0

    for filename, raw_scores in zip(df['filename'], df['vina_scores']):
        if processed_files >= max_files:
            break

        # Skip reference molecules (those with mol2 extension); str() keeps a
        # missing (NaN) file name from raising TypeError.
        if '.mol2' in str(filename):
            continue

        try:
            # Parse the vina_scores string back to a list
            all_vina_scores.extend(ast.literal_eval(raw_scores))
            processed_files += 1
        except (ValueError, SyntaxError, TypeError) as e:
            logger.warning(f"Failed to parse Vina scores for {filename}: {e}")

    logger.info(f"Loaded {len(all_vina_scores)} Vina scores from {processed_files} files")
    return all_vina_scores


def get_min_vina_scores_length(df):
    """
    Get the minimum length of vina_scores lists in the dataframe

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores`` columns

    Returns:
        int: Minimum length of vina_scores lists, 0 when none could be parsed
    """
    lengths = []

    for filename, raw_scores in zip(df['filename'], df['vina_scores']):
        # Skip reference molecules (those with mol2 extension)
        if '.mol2' in str(filename):
            continue

        try:
            # Parse the vina_scores string back to a list
            lengths.append(len(ast.literal_eval(raw_scores)))
        except (ValueError, SyntaxError, TypeError) as e:
            logger.warning(f"Failed to parse Vina scores for {filename}: {e}")

    return min(lengths, default=0)


def get_reference_vina_scores(dataset_name, rank=0):
    """
    Get Vina scores for reference molecules

    Reads every ``*_converted.sdf`` file below ``result/refence/<dataset_name>``
    and keeps the score found at position ``rank`` of each file.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)
        rank (int): Rank of the conformation to use (0 for best/first)

    Returns:
        dict: Dictionary with reference molecule identifiers and their Vina scores

    Raises:
        ValueError: If ``rank`` is not smaller than the number of scores in a file
    """
    reference_scores = {}

    # Use the original directory name "refence" (sic) as it exists on disk
    reference_dir = Path("result") / "refence" / dataset_name

    if not reference_dir.exists():
        logger.warning(f"Reference directory {reference_dir} does not exist")
        return reference_scores

    # Find reference SDF files
    reference_sdf_files = list(reference_dir.glob("*_converted.sdf"))
    logger.info(f"Processing {len(reference_sdf_files)} reference SDF files in {reference_dir}")

    for sdf_file in reference_sdf_files:
        vina_scores = extract_vina_scores_from_sdf(str(sdf_file))
        if not vina_scores:
            continue

        # Check if rank is valid
        if rank >= len(vina_scores):
            raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {len(vina_scores)}. Please choose a rank less than {len(vina_scores)}.")

        # Get the score at the specified rank
        reference_score = vina_scores[rank]

        identifier = _reference_identifier(sdf_file)
        reference_scores[identifier] = reference_score
        logger.info(f"Reference Vina score for {identifier} (rank {rank}): {reference_score}")

    return reference_scores


def plot_combined_kde_distribution_normalized(df, dataset_name, reference_df=None, reference_scores=None,
                                              vina_scores=None, reference_vina_scores=None):
    """
    Plot combined KDE distribution for QED, molecular weight, and Vina scores (normalized)

    All quantities are min-max normalized with the range of the main dataset
    so they share one axis. The figure is saved as
    ``kde_distribution_<dataset_name>_normalized.png``.

    Args:
        df (pd.DataFrame): Main dataset
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional, currently unused)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores (optional)
    """
    # Create figure
    plt.figure(figsize=(15, 8))

    # Ranges of the main dataset; reference molecules reuse them so that they
    # land on the same scale as the distributions.
    qed_min, qed_max = df['qed'].min(), df['qed'].max()
    mw_min, mw_max = df['molecular_weight'].min(), df['molecular_weight'].max()

    # Plot KDE for normalized QED
    sns.kdeplot(_min_max_normalize(df['qed'], qed_min, qed_max),
                label='QED (normalized)', fill=True, alpha=0.5, color='blue')

    # Plot KDE for normalized molecular weight
    sns.kdeplot(_min_max_normalize(df['molecular_weight'], mw_min, mw_max),
                label='Molecular Weight (normalized)', fill=True, alpha=0.5, color='red')

    # Plot KDE for normalized Vina scores if available
    has_vina = vina_scores is not None and len(vina_scores) > 0
    vina_min = vina_max = None
    if has_vina:
        # Plain min-max scaling (no sign flip): lower, i.e. better, scores
        # end up closer to 0.
        vina_series = pd.Series(vina_scores)
        vina_min, vina_max = vina_series.min(), vina_series.max()
        sns.kdeplot(_min_max_normalize(vina_series, vina_min, vina_max),
                    label='Vina Score (normalized)', fill=True, alpha=0.5, color='green')

    # Legend entries for the reference markers
    legend_handles = []

    # Mark reference molecules if provided
    if reference_df is not None and len(reference_df) > 0:
        ref_qed_normalized = _min_max_normalize(reference_df['qed'], qed_min, qed_max)
        ref_mw_normalized = _min_max_normalize(reference_df['molecular_weight'], mw_min, mw_max)

        for i, (_, row) in enumerate(reference_df.iterrows()):
            identifier = _reference_identifier(row['filename'])

            qed_value = row['qed']
            mw_value = row['molecular_weight']
            vina_value = reference_vina_scores.get(identifier) if reference_vina_scores else None

            # Build score text
            score_text = f"{identifier}\n(QED: {qed_value:.2f}, MW: {mw_value:.2f}"
            if vina_value is not None:
                score_text += f", Vina: {vina_value:.2f}"
            score_text += ")"

            # QED marker
            plt.scatter(ref_qed_normalized.iloc[i], 0, color='darkblue', s=100, marker='v', zorder=5)

            # Molecular weight marker
            mw_x = ref_mw_normalized.iloc[i]
            plt.scatter(mw_x, 0, color='darkred', s=100, marker='^', zorder=5)

            # Vina marker: needs the range of the main Vina distribution
            draw_vina = vina_value is not None and has_vina
            if draw_vina:
                plt.scatter(_min_max_normalize(vina_value, vina_min, vina_max), 0,
                            color='darkgreen', s=100, marker='o', zorder=5)

            # Annotate with combined information (anchored at the MW marker)
            plt.annotate(score_text, (mw_x, 0), xytext=(10, 30), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                         fontsize=8)

            # Add to legend
            legend_handles.append(plt.Line2D([0], [0], marker='v', color='darkblue',
                                             label=f"{identifier} - QED",
                                             markerfacecolor='darkblue', markersize=8, linestyle=''))
            legend_handles.append(plt.Line2D([0], [0], marker='^', color='darkred',
                                             label=f"{identifier} - MW",
                                             markerfacecolor='darkred', markersize=8, linestyle=''))
            if draw_vina:
                legend_handles.append(plt.Line2D([0], [0], marker='o', color='darkgreen',
                                                 label=f"{identifier} - Vina",
                                                 markerfacecolor='darkgreen', markersize=8, linestyle=''))

    plt.title(f'Combined KDE Distribution (Normalized) - {dataset_name.upper()}', fontsize=16)
    plt.xlabel('Normalized Values (0-1)')
    plt.ylabel('Density')

    # One legend holding both the KDE curves and the reference markers; a
    # second plt.legend() call would replace the first one.
    kde_handles, _ = plt.gca().get_legend_handles_labels()
    plt.legend(handles=kde_handles + legend_handles, loc='upper right', fontsize=10)
    plt.grid(True, alpha=0.3)

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_normalized.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"Saved combined KDE distribution plot (normalized) for {dataset_name} as kde_distribution_{dataset_name}_normalized.png")


def plot_combined_kde_distribution_actual(df, dataset_name, reference_df=None, reference_scores=None,
                                          vina_scores=None, reference_vina_scores=None):
    """
    Plot combined KDE distribution for QED, molecular weight, and Vina scores (actual values)

    Three side-by-side panels are drawn and saved as
    ``kde_distribution_<dataset_name>_actual.png``. Each reference molecule is
    marked on the x-axis and labelled with its own identifier.

    Args:
        df (pd.DataFrame): Main dataset
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional, currently unused)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores (optional)
    """
    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(f'KDE Distribution (Actual Values) - {dataset_name.upper()}', fontsize=16)

    has_references = reference_df is not None and len(reference_df) > 0

    # (axes, column, title, x label, fill colour, marker colour, marker, text offset)
    column_panels = (
        (axes[0], 'qed', 'QED Distribution', 'QED Value',
         'blue', 'darkblue', 'v', (10, 20)),
        (axes[1], 'molecular_weight', 'Molecular Weight Distribution', 'Molecular Weight (Daltons)',
         'red', 'darkred', '^', (10, -30)),
    )

    # Plot 1 and 2: QED and molecular weight distributions
    for ax, column, title, xlabel, fill_color, mark_color, marker, offset in column_panels:
        sns.kdeplot(df[column], ax=ax, fill=True, alpha=0.5, color=fill_color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Density')
        ax.grid(True, alpha=0.3)

        if not has_references:
            continue

        # Mark every reference molecule with its own identifier
        for _, row in reference_df.iterrows():
            identifier = _reference_identifier(row['filename'])
            value = row[column]
            _mark_reference(ax, value, f"{identifier}\n({value:.2f})", mark_color, marker, offset)

    # Plot 3: Vina scores distribution
    if vina_scores is not None and len(vina_scores) > 0:
        vina_ax = axes[2]
        sns.kdeplot(pd.Series(vina_scores), ax=vina_ax, fill=True, alpha=0.5, color='green')
        vina_ax.set_title('Vina Score Distribution')
        vina_ax.set_xlabel('Vina Score (kcal/mol)')
        vina_ax.set_ylabel('Density')
        vina_ax.grid(True, alpha=0.3)

        # Mark reference molecules for Vina scores
        if reference_vina_scores:
            for identifier, vina_score in reference_vina_scores.items():
                _mark_reference(vina_ax, vina_score, f"{identifier}\n({vina_score:.2f})",
                                'darkgreen', 'o', (10, -60))

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_actual.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"Saved combined KDE distribution plot (actual values) for {dataset_name} as kde_distribution_{dataset_name}_actual.png")


def analyze_dataset(csv_file, dataset_name, reference_scores=None, rank=0):
    """
    Analyze a dataset and generate KDE plots

    Args:
        csv_file (str): Path to the CSV file
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_scores (dict): Reference scores for each dataset
        rank (int): Rank of the conformation to use for reference Vina scores (0 for best/first)

    Raises:
        ValueError: If ``rank`` is not smaller than the shortest vina_scores list
    """
    # Load main dataset
    df = load_dataset(csv_file)

    # Check minimum vina scores length
    min_vina_length = get_min_vina_scores_length(df)
    if rank >= min_vina_length:
        raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {min_vina_length}. Please choose a rank less than {min_vina_length}.")

    # Load reference molecules
    reference_df = load_reference_molecules(dataset_name)

    # Load Vina scores from CSV
    vina_scores = load_vina_scores_from_csv(df)

    # Get reference Vina scores
    reference_vina_scores = get_reference_vina_scores(dataset_name, rank)

    # Plot combined KDE distributions (normalized)
    plot_combined_kde_distribution_normalized(df, dataset_name, reference_df, reference_scores,
                                              vina_scores, reference_vina_scores)

    # Plot combined KDE distributions (actual values)
    plot_combined_kde_distribution_actual(df, dataset_name, reference_df, reference_scores,
                                          vina_scores, reference_vina_scores)


def get_default_reference_scores():
    """
    Get default reference scores from README.md

    Returns:
        dict: Dictionary mapping dataset name to ``{identifier: score}``
    """
    # Default reference scores from README.md
    return {
        'fgbar': {
            '9NY': -5.268  # From README.md
        },
        'trpe': {
            '0GA': -6.531  # From README.md
        }
    }


@click.command()
@click.argument('csv_files', type=click.Path(exists=True), nargs=-1)
@click.option('--dataset-names', '-d', multiple=True, help='Names of the datasets corresponding to CSV files')
@click.option('--reference-scores', '-r', type=str, help='Reference scores in JSON format')
@click.option('--rank', '-k', default=0, type=int, help='Rank of conformation to use for reference Vina scores (default: 0 for best/first)')
def main_cli(csv_files, dataset_names, reference_scores, rank):
    """
    Analyze QED and molecular weight distributions and generate KDE plots

    CSV_FILES: Paths to the CSV files with QED and molecular weight data
    """
    if not csv_files:
        logger.error("At least one CSV file must be provided")
        return

    # Convert dataset names to list
    dataset_names_list = list(dataset_names) if dataset_names else None

    # Parse reference scores if provided
    if reference_scores:
        try:
            reference_scores_dict = json.loads(reference_scores)
        except json.JSONDecodeError:
            logger.error("Invalid JSON format for reference scores")
            return
    else:
        reference_scores_dict = get_default_reference_scores()

    # Run main analysis
    try:
        main_api(csv_files, dataset_names_list, reference_scores_dict, rank)
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        raise


def main_api(csv_files, dataset_names=None, reference_scores=None, rank=0):
    """
    Main function for API usage

    Args:
        csv_files (list): List of CSV files to analyze
        dataset_names (list): List of dataset names corresponding to CSV files;
            derived from the file stems (minus ``qed_values_``) when omitted
        reference_scores (dict): Reference scores for each dataset
        rank (int): Rank of the conformation to use for reference Vina scores (0 for best/first)

    Raises:
        ValueError: If the number of dataset names differs from the number of CSV files
    """
    if dataset_names is None:
        dataset_names = [Path(f).stem.replace('qed_values_', '') for f in csv_files]

    # zip() would silently skip files or names on a length mismatch.
    if len(dataset_names) != len(csv_files):
        raise ValueError(f"Got {len(dataset_names)} dataset names for {len(csv_files)} CSV files; provide exactly one name per file.")

    if reference_scores is None:
        reference_scores = get_default_reference_scores()

    for csv_file, dataset_name in zip(csv_files, dataset_names):
        try:
            logger.info(f"Analyzing dataset: {dataset_name}")
            analyze_dataset(csv_file, dataset_name, reference_scores.get(dataset_name, {}), rank)
        except Exception as e:
            logger.error(f"Error analyzing {dataset_name}: {e}")
            raise


if __name__ == "__main__":
    main_cli()