Files
vina_docking_batch/scripts/analyze_qed_mw_distribution.py
2025-08-05 20:37:33 +08:00

551 lines
24 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :analyze_qed_mw_distribution.py
@Description :Analysis of QED and molecular weight distribution with KDE plots
@Date :2025/08/05
@Author :lyzeng
'''
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from rdkit import Chem
import logging
import ast
import json
import click
# Setup logging: configure the root logger at INFO level and create the
# module-level logger used by the functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_dataset(csv_file):
    """
    Read a dataset CSV and log summary statistics of its descriptors.

    Args:
        csv_file (str): Path to the CSV file. The ``qed`` and
            ``molecular_weight`` columns are read for the statistics.

    Returns:
        pd.DataFrame: The table exactly as parsed by ``pd.read_csv``.
    """
    dataset = pd.read_csv(csv_file)
    logger.info(f"Loaded {len(dataset)} records from {csv_file}")

    # Report min / max / mean of each descriptor column.
    logger.info(f"Statistics for {Path(csv_file).stem}:")
    qed = dataset['qed']
    logger.info(f"QED - Min: {qed.min():.3f}, Max: {qed.max():.3f}, Mean: {qed.mean():.3f}")
    mol_weight = dataset['molecular_weight']
    logger.info(f"Molecular Weight - Min: {mol_weight.min():.2f}, Max: {mol_weight.max():.2f}, Mean: {mol_weight.mean():.2f}")
    return dataset
def load_reference_molecules(dataset_name):
    """
    Load the reference molecules of a dataset from its QED CSV file.

    The file ``qed_values_<dataset_name>.csv`` is searched for in the current
    working directory. Reference molecules are the rows whose ``filename``
    matches ``align_..._out_converted.sdf``.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        pd.DataFrame: Rows of the CSV describing reference molecules, or an
        empty DataFrame when the CSV file does not exist.
    """
    csv_files = list(Path(".").glob(f"qed_values_{dataset_name}.csv"))
    if not csv_files:
        logger.warning(f"No CSV file found for {dataset_name}")
        return pd.DataFrame()

    df = pd.read_csv(csv_files[0])

    # Raw string: "\." in a normal literal is an invalid escape sequence and
    # triggers a SyntaxWarning on recent Python versions.
    # na=False makes rows with a missing filename count as non-matching.
    reference_pattern = r'align_.*_out_converted\.sdf'
    is_reference = df['filename'].str.contains(reference_pattern, na=False, regex=True)
    reference_df = df[is_reference]

    logger.info(f"Loaded {len(reference_df)} reference molecules for {dataset_name}")
    return reference_df
def extract_vina_scores_from_sdf(sdf_file_path):
    """
    Collect the Vina score of every conformer stored in an SDF file.

    Each molecule is expected to carry a ``meeko`` property holding JSON; the
    ``free_energy`` entry of that JSON is taken as the Vina score.

    Args:
        sdf_file_path (str): Path to the SDF file

    Returns:
        list: ``free_energy`` values in file order. Molecules that cannot be
        read or lack the data are skipped; any other error stops the scan and
        returns what was collected so far.
    """
    energies = []
    try:
        for molecule in Chem.SDMolSupplier(sdf_file_path, removeHs=False):
            # The supplier yields None for records it could not parse.
            if molecule is None:
                continue
            if not molecule.HasProp("meeko"):
                logger.warning(f"No meeko property found in molecule from {sdf_file_path}")
                continue
            try:
                docking_info = json.loads(molecule.GetProp("meeko"))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse meeko JSON for {sdf_file_path}")
                continue
            if 'free_energy' in docking_info:
                energies.append(docking_info['free_energy'])
    except Exception as err:
        # Best effort: report the problem and keep the scores gathered so far.
        logger.warning(f"Failed to extract Vina scores from {sdf_file_path}: {err}")
    return energies
def load_vina_scores_from_csv(df, max_files=1000):
    """
    Gather the Vina scores stored in the ``vina_scores`` column of a dataset.

    Each cell is expected to hold the string form of a list of scores. Rows
    whose filename contains ``.mol2`` are skipped as reference molecules.
    Rows whose scores cannot be parsed into a list are reported with a warning
    and skipped; they do not count towards ``max_files``.

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores`` columns
        max_files (int): Maximum number of rows whose scores are collected

    Returns:
        list: All Vina scores of the processed rows, concatenated in row order
    """
    all_vina_scores = []
    processed_files = 0

    for _, row in df.iterrows():
        # Stop once enough rows have contributed, to bound memory use.
        if processed_files >= max_files:
            break

        filename = row['filename']
        # A missing filename is not a string (pandas yields NaN/None), so the
        # substring test must be guarded to avoid a TypeError.
        if isinstance(filename, str) and '.mol2' in filename:
            continue

        try:
            # Parse the vina_scores string back to a list
            parsed_scores = ast.literal_eval(row['vina_scores'])
        except (ValueError, SyntaxError) as e:
            logger.warning(f"Failed to parse Vina scores for {filename}: {e}")
            continue

        # A scalar cell such as "-5.2" parses fine but cannot be extended from.
        if not isinstance(parsed_scores, (list, tuple)):
            logger.warning(
                f"Failed to parse Vina scores for {filename}: "
                f"expected a list, got {type(parsed_scores).__name__}"
            )
            continue

        all_vina_scores.extend(parsed_scores)
        processed_files += 1

    logger.info(f"Loaded {len(all_vina_scores)} Vina scores from {processed_files} files")
    return all_vina_scores
def get_min_vina_scores_length(df):
    """
    Get the minimum length of the ``vina_scores`` lists in the dataframe.

    Rows whose filename contains ``.mol2`` are skipped as reference molecules.
    Rows whose scores cannot be parsed into a list are reported with a warning
    and ignored.

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores`` columns

    Returns:
        int: Length of the shortest parsed score list, or 0 when no row
        yielded a list
    """
    lengths = []

    for _, row in df.iterrows():
        filename = row['filename']
        # A missing filename is not a string (pandas yields NaN/None), so the
        # substring test must be guarded to avoid a TypeError.
        if isinstance(filename, str) and '.mol2' in filename:
            continue

        try:
            # Parse the vina_scores string back to a list
            parsed_scores = ast.literal_eval(row['vina_scores'])
        except (ValueError, SyntaxError) as e:
            logger.warning(f"Failed to parse Vina scores for {filename}: {e}")
            continue

        # A scalar cell such as "-5.2" parses fine but has no length.
        if not isinstance(parsed_scores, (list, tuple)):
            logger.warning(
                f"Failed to parse Vina scores for {filename}: "
                f"expected a list, got {type(parsed_scores).__name__}"
            )
            continue

        lengths.append(len(parsed_scores))

    return min(lengths, default=0)
def get_reference_vina_scores(dataset_name, rank=0):
    """
    Get Vina scores for the reference molecules of a dataset.

    Every ``*_converted.sdf`` file in ``result/refence/<dataset_name>`` is read
    and the score of the conformer at position ``rank`` is stored under an
    identifier derived from the file name.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)
        rank (int): Rank of the conformation to use (0 for best/first)

    Returns:
        dict: Reference molecule identifier -> Vina score at ``rank``. Empty
        when the reference directory does not exist.

    Raises:
        ValueError: If a reference file has ``rank`` or fewer conformer scores.
    """
    reference_scores = {}

    # Use the original directory name "refence" (spelling kept as in the
    # original code).
    reference_dir = Path("result") / "refence" / dataset_name
    if not reference_dir.exists():
        logger.warning(f"Reference directory {reference_dir} does not exist")
        return reference_scores

    # Sorted so that processing (and log) order does not depend on the
    # file system's listing order.
    reference_sdf_files = sorted(reference_dir.glob("*_converted.sdf"))
    logger.info(f"Processing {len(reference_sdf_files)} reference SDF files in {reference_dir}")

    for sdf_file in reference_sdf_files:
        vina_scores = extract_vina_scores_from_sdf(str(sdf_file))
        if not vina_scores:
            continue

        # The limit checked here belongs to this one file, so the message
        # names the file instead of claiming a minimum across all molecules.
        if rank >= len(vina_scores):
            raise ValueError(
                f"Rank {rank} is out of range for reference file {sdf_file.name}, "
                f"which has {len(vina_scores)} conformer score(s). "
                f"Please choose a rank less than {len(vina_scores)}."
            )
        reference_score = vina_scores[rank]

        # Derive the identifier: drop the processing suffixes, then keep the
        # last "_"-separated part of aligned files (e.g., 9NY or 0GA).
        identifier = sdf_file.stem.replace('_out_converted', '').replace('_addH', '')
        if 'align_' in identifier:
            identifier = identifier.split('_')[-1]

        reference_scores[identifier] = reference_score
        logger.info(f"Reference Vina score for {identifier} (rank {rank}): {reference_score}")

    return reference_scores
def _min_max_scale(values, lower, upper):
    """
    Map ``values`` linearly so that ``lower`` becomes 0 and ``upper`` becomes 1.

    When ``lower`` equals ``upper`` the range is zero; the values are then
    only shifted to 0 instead of being divided by zero.
    """
    span = upper - lower
    if span == 0:
        return values - lower
    return (values - lower) / span


def plot_combined_kde_distribution_normalized(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot combined KDE distribution for QED, molecular weight, and Vina scores (normalized)

    All three quantities are min-max scaled to the range of the main dataset
    and drawn on one axis. Reference molecules are marked on the x axis and
    listed in the legend together with the KDE curves. The figure is saved as
    ``kde_distribution_<dataset_name>_normalized.png``.

    Args:
        df (pd.DataFrame): Main dataset with ``qed`` and ``molecular_weight`` columns
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional); accepted
            for interface compatibility, not used by this function
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores keyed by
            molecule identifier (optional)
    """
    def reference_identifier(filename):
        # Same derivation as the keys built in get_reference_vina_scores:
        # drop the processing suffixes, then keep the last "_"-separated part
        # of aligned files (e.g., 9NY or 0GA). Without removing
        # "_out_converted" the last part would always be "converted".
        stem = Path(filename).stem.replace('_out_converted', '').replace('_addH', '')
        return stem.split('_')[-1] if 'align_' in stem else stem

    # Create figure
    plt.figure(figsize=(15, 8))

    # Normalize the data to make them comparable on the same scale
    qed_min, qed_max = df['qed'].min(), df['qed'].max()
    mw_min, mw_max = df['molecular_weight'].min(), df['molecular_weight'].max()

    sns.kdeplot(_min_max_scale(df['qed'], qed_min, qed_max),
                label='QED (normalized)', fill=True, alpha=0.5, color='blue')
    sns.kdeplot(_min_max_scale(df['molecular_weight'], mw_min, mw_max),
                label='Molecular Weight (normalized)', fill=True, alpha=0.5, color='red')

    # Plot KDE for normalized Vina scores if available. The scores are only
    # rescaled, not negated, so lower (better) scores stay on the left.
    has_vina = vina_scores is not None and len(vina_scores) > 0
    vina_min = vina_max = None
    if has_vina:
        vina_series = pd.Series(vina_scores)
        vina_min, vina_max = vina_series.min(), vina_series.max()
        sns.kdeplot(_min_max_scale(vina_series, vina_min, vina_max),
                    label='Vina Score (normalized)', fill=True, alpha=0.5, color='green')

    # Legend entries for the reference markers (scatter points carry no label)
    legend_handles = []

    if reference_df is not None and len(reference_df) > 0:
        for _, row in reference_df.iterrows():
            identifier = reference_identifier(row['filename'])
            qed_value = row['qed']
            mw_value = row['molecular_weight']
            vina_score = reference_vina_scores.get(identifier) if reference_vina_scores else None

            # Build the annotation text
            score_text = f"{identifier}\n(QED: {qed_value:.2f}, MW: {mw_value:.2f}"
            if vina_score is not None:
                score_text += f", Vina: {vina_score:.2f}"
            score_text += ")"

            # Reference values use the scale of the main dataset
            qed_pos = _min_max_scale(qed_value, qed_min, qed_max)
            mw_pos = _min_max_scale(mw_value, mw_min, mw_max)
            plt.scatter(qed_pos, 0, color='darkblue', s=100, marker='v', zorder=5)
            plt.scatter(mw_pos, 0, color='darkred', s=100, marker='^', zorder=5)

            # The Vina marker needs the dataset's score range; without
            # dataset scores there is nothing to normalize against.
            draw_vina_marker = vina_score is not None and has_vina
            if draw_vina_marker:
                vina_pos = _min_max_scale(vina_score, vina_min, vina_max)
                plt.scatter(vina_pos, 0, color='darkgreen', s=100, marker='o', zorder=5)

            # Annotate with combined information, anchored at the molecular
            # weight marker
            plt.annotate(score_text,
                         (mw_pos, 0),
                         xytext=(10, 30),
                         textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                         fontsize=8)

            legend_handles.append(plt.Line2D([0], [0], marker='v', color='darkblue', label=f"{identifier} - QED",
                                             markerfacecolor='darkblue', markersize=8, linestyle=''))
            legend_handles.append(plt.Line2D([0], [0], marker='^', color='darkred', label=f"{identifier} - MW",
                                             markerfacecolor='darkred', markersize=8, linestyle=''))
            if draw_vina_marker:
                legend_handles.append(plt.Line2D([0], [0], marker='o', color='darkgreen', label=f"{identifier} - Vina",
                                                 markerfacecolor='darkgreen', markersize=8, linestyle=''))

    plt.title(f'Combined KDE Distribution (Normalized) - {dataset_name.upper()}', fontsize=16)
    plt.xlabel('Normalized Values (0-1)')
    plt.ylabel('Density')

    # Build the legend once. A second plt.legend() call would replace the
    # first one and drop the reference entries, so the labelled KDE artists
    # and the reference handles are combined into a single legend.
    if legend_handles:
        kde_handles, _ = plt.gca().get_legend_handles_labels()
        plt.legend(handles=kde_handles + legend_handles, loc='upper right', fontsize=10)
    else:
        plt.legend()
    plt.grid(True, alpha=0.3)

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_normalized.png', dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved combined KDE distribution plot (normalized) for {dataset_name} as kde_distribution_{dataset_name}_normalized.png")
def plot_combined_kde_distribution_actual(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot combined KDE distribution for QED, molecular weight, and Vina scores (actual values)

    Three side-by-side panels show the unscaled distributions. Every reference
    molecule is marked on the x axis of each panel and annotated with its own
    identifier and value. The figure is saved as
    ``kde_distribution_<dataset_name>_actual.png``.

    Args:
        df (pd.DataFrame): Main dataset with ``qed`` and ``molecular_weight`` columns
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional); accepted
            for interface compatibility, not used by this function
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores keyed by
            molecule identifier (optional)
    """
    def reference_identifier(filename):
        # Drop the processing suffixes, then keep the last "_"-separated part
        # of aligned files (e.g., 9NY or 0GA).
        stem = Path(filename).stem.replace('_out_converted', '').replace('_addH', '')
        return stem.split('_')[-1] if 'align_' in stem else stem

    def mark_reference(ax, value, identifier, color, marker, text_offset):
        # Draw one reference marker on the x axis and label it with the
        # molecule's identifier and value.
        ax.scatter(value, 0, color=color, s=100, marker=marker, zorder=5)
        ax.annotate(f"{identifier}\n({value:.2f})",
                    (value, 0),
                    xytext=text_offset,
                    textcoords='offset points',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                    fontsize=8)

    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(f'KDE Distribution (Actual Values) - {dataset_name.upper()}', fontsize=16)

    has_references = reference_df is not None and len(reference_df) > 0

    # Panels 1 and 2: one descriptor column each
    # (axis, column, title, x label, fill color, marker color, marker, label offset)
    descriptor_panels = (
        (axes[0], 'qed', 'QED Distribution', 'QED Value',
         'blue', 'darkblue', 'v', (10, 20)),
        (axes[1], 'molecular_weight', 'Molecular Weight Distribution', 'Molecular Weight (Daltons)',
         'red', 'darkred', '^', (10, -30)),
    )
    for ax, column, title, xlabel, fill_color, marker_color, marker, offset in descriptor_panels:
        sns.kdeplot(df[column], ax=ax, fill=True, alpha=0.5, color=fill_color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Density')
        ax.grid(True, alpha=0.3)

        if has_references:
            for _, row in reference_df.iterrows():
                # Label each marker with its own molecule, not with the first
                # reference of the table.
                mark_reference(ax, row[column], reference_identifier(row['filename']),
                               marker_color, marker, offset)

    # Panel 3: Vina scores distribution (left empty without dataset scores)
    if vina_scores is not None and len(vina_scores) > 0:
        sns.kdeplot(pd.Series(vina_scores), ax=axes[2], fill=True, alpha=0.5, color='green')
        axes[2].set_title('Vina Score Distribution')
        axes[2].set_xlabel('Vina Score (kcal/mol)')
        axes[2].set_ylabel('Density')
        axes[2].grid(True, alpha=0.3)

        if reference_vina_scores:
            # The dictionary key is the identifier of the scored molecule.
            for identifier, vina_score in reference_vina_scores.items():
                mark_reference(axes[2], vina_score, identifier, 'darkgreen', 'o', (10, -60))

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_actual.png', dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved combined KDE distribution plot (actual values) for {dataset_name} as kde_distribution_{dataset_name}_actual.png")
def analyze_dataset(csv_file, dataset_name, reference_scores=None, rank=0):
    """
    Run the full analysis of one dataset and write both KDE figures.

    Args:
        csv_file (str): Path to the CSV file
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_scores (dict): Reference scores for this dataset; passed
            through to the plotting functions
        rank (int): Rank of the conformation to use for reference Vina scores
            (0 for best/first)

    Raises:
        ValueError: If ``rank`` is not smaller than the shortest list of Vina
            scores found in the dataset.
    """
    dataset = load_dataset(csv_file)

    # The requested rank must exist for every molecule of the dataset.
    shortest = get_min_vina_scores_length(dataset)
    if rank >= shortest:
        raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {shortest}. Please choose a rank less than {shortest}.")

    # Gather everything the plots need: reference rows, dataset scores and
    # the reference scores at the requested rank.
    references = load_reference_molecules(dataset_name)
    dataset_vina = load_vina_scores_from_csv(dataset)
    reference_vina = get_reference_vina_scores(dataset_name, rank)

    # Both figures take the same inputs: normalized first, actual values second.
    plot_inputs = (dataset, dataset_name, references, reference_scores, dataset_vina, reference_vina)
    for make_plot in (plot_combined_kde_distribution_normalized, plot_combined_kde_distribution_actual):
        make_plot(*plot_inputs)
def get_default_reference_scores():
    """
    Provide the built-in reference scores (values noted as coming from README.md).

    Returns:
        dict: Dataset name -> {reference molecule identifier: score}; a new
        dictionary is built on every call.
    """
    defaults = {}
    defaults['fgbar'] = {'9NY': -5.268}  # From README.md
    defaults['trpe'] = {'0GA': -6.531}  # From README.md
    return defaults
@click.command()
@click.argument('csv_files', type=click.Path(exists=True), nargs=-1)
@click.option('--dataset-names', '-d', multiple=True, help='Names of the datasets corresponding to CSV files')
@click.option('--reference-scores', '-r', type=str, help='Reference scores in JSON format')
@click.option('--rank', '-k', default=0, type=int, help='Rank of conformation to use for reference Vina scores (default: 0 for best/first)')
def main_cli(csv_files, dataset_names, reference_scores, rank):
    """
    Analyze QED and molecular weight distributions and generate KDE plots
    CSV_FILES: Paths to the CSV files with QED and molecular weight data
    """
    # Nothing to analyze without input files: report and leave.
    if not csv_files:
        logger.error("At least one CSV file must be provided")
        return

    # No -d option given means the names are derived later from the file names.
    names = list(dataset_names) or None

    # Reference scores come from the -r JSON text, else from the defaults.
    if not reference_scores:
        scores = get_default_reference_scores()
    else:
        try:
            scores = json.loads(reference_scores)
        except json.JSONDecodeError:
            logger.error("Invalid JSON format for reference scores")
            return

    # Run the analysis; failures are logged here and then propagated.
    try:
        main_api(csv_files, names, scores, rank)
    except Exception as err:
        logger.error(f"Analysis failed: {err}")
        raise
def main_api(csv_files, dataset_names=None, reference_scores=None, rank=0):
    """
    Main function for API usage: analyze each CSV file under its dataset name.

    Args:
        csv_files (list): List of CSV files to analyze
        dataset_names (list): List of dataset names corresponding to CSV files.
            When omitted, each name is the file stem with ``qed_values_`` removed.
        reference_scores (dict): Reference scores for each dataset; defaults to
            the built-in reference scores
        rank (int): Rank of the conformation to use for reference Vina scores
            (0 for best/first)

    Raises:
        ValueError: If the number of dataset names differs from the number of
            CSV files.
        Exception: Any error raised while analyzing a dataset is logged and
            re-raised.
    """
    csv_files = list(csv_files)
    if dataset_names is None:
        dataset_names = [Path(f).stem.replace('qed_values_', '') for f in csv_files]
    else:
        dataset_names = list(dataset_names)

    # zip() would silently drop the surplus files or names, so a mismatch is
    # rejected up front instead of analyzing only part of the input.
    if len(dataset_names) != len(csv_files):
        raise ValueError(
            f"Got {len(dataset_names)} dataset name(s) for {len(csv_files)} CSV file(s); "
            f"provide exactly one name per file."
        )

    if reference_scores is None:
        reference_scores = get_default_reference_scores()

    for csv_file, dataset_name in zip(csv_files, dataset_names):
        try:
            logger.info(f"Analyzing dataset: {dataset_name}")
            # Only the scores of the current dataset are passed on.
            analyze_dataset(csv_file, dataset_name, reference_scores.get(dataset_name, {}), rank)
        except Exception as e:
            logger.error(f"Error analyzing {dataset_name}: {e}")
            raise
# Script entry point: run the click command when the file is executed
# directly (not when it is imported).
if __name__ == "__main__":
    main_cli()