551 lines
24 KiB
Python
551 lines
24 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
'''
|
|
@file :analyze_qed_mw_distribution.py
|
|
@Description :Analysis of QED and molecular weight distribution with KDE plots
|
|
@Date :2025/08/05
|
|
@Author :lyzeng
|
|
'''
|
|
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from rdkit import Chem
|
|
import logging
|
|
import ast
|
|
import json
|
|
import click
|
|
|
|
# Setup logging
# NOTE: this runs at import time and configures the root logger at INFO level
# for the whole process, not just for this module.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
def load_dataset(csv_file):
    """
    Read a dataset CSV and log a short statistical summary of it.

    Args:
        csv_file (str): Path to the CSV file. The file must provide ``qed``
            and ``molecular_weight`` columns (both are summarized).

    Returns:
        pd.DataFrame: The dataset exactly as returned by ``pd.read_csv``.
    """
    frame = pd.read_csv(csv_file)
    logger.info(f"Loaded {len(frame)} records from {csv_file}")

    # Summary block: min / max / mean of each property column.
    logger.info(f"Statistics for {Path(csv_file).stem}:")
    qed = frame['qed']
    logger.info(f"QED - Min: {qed.min():.3f}, Max: {qed.max():.3f}, Mean: {qed.mean():.3f}")
    weight = frame['molecular_weight']
    logger.info(f"Molecular Weight - Min: {weight.min():.2f}, Max: {weight.max():.2f}, Mean: {weight.mean():.2f}")

    return frame
|
|
|
|
def load_reference_molecules(dataset_name):
    """
    Load the reference-molecule rows for a dataset from its QED CSV.

    The file ``qed_values_<dataset_name>.csv`` is looked up in the current
    working directory. Reference rows are those whose ``filename`` contains a
    match for the regex ``align_.*_out_converted\\.sdf``.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        pd.DataFrame: The matching rows (with all of the CSV's columns, e.g.
        ``qed`` and ``molecular_weight``), or an empty DataFrame when the
        CSV file does not exist.
    """
    # The file name is fixed, so check it directly instead of globbing: a
    # dataset name containing '[', '*' or '?' must not be treated as a pattern.
    csv_path = Path(f"qed_values_{dataset_name}.csv")
    if not csv_path.is_file():
        logger.warning(f"No CSV file found for {dataset_name}")
        return pd.DataFrame()

    df = pd.read_csv(csv_path)

    # Raw string: '\.' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The regex itself is unchanged.
    reference_mask = df['filename'].str.contains(r'align_.*_out_converted\.sdf', na=False, regex=True)
    reference_df = df[reference_mask]

    logger.info(f"Loaded {len(reference_df)} reference molecules for {dataset_name}")
    return reference_df
|
|
|
|
def extract_vina_scores_from_sdf(sdf_file_path):
    """
    Collect the ``free_energy`` (Vina score) of each conformer in an SDF file.

    Each molecule's ``meeko`` property is parsed as JSON. Molecules that RDKit
    cannot read, that lack the property, or whose property is not valid JSON
    are skipped with a warning.

    Args:
        sdf_file_path (str): Path to the SDF file

    Returns:
        list: The ``free_energy`` values in file order. Any error while reading
        the file is logged, and the scores gathered so far are returned (which
        may be an empty list).
    """
    collected = []
    try:
        for molecule in Chem.SDMolSupplier(sdf_file_path, removeHs=False):
            if molecule is None:
                continue

            if not molecule.HasProp("meeko"):
                logger.warning(f"No meeko property found in molecule from {sdf_file_path}")
                continue

            try:
                payload = json.loads(molecule.GetProp("meeko"))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse meeko JSON for {sdf_file_path}")
                continue

            if 'free_energy' in payload:
                collected.append(payload['free_energy'])
    except Exception as exc:
        logger.warning(f"Failed to extract Vina scores from {sdf_file_path}: {exc}")

    return collected
|
|
|
|
def load_vina_scores_from_csv(df, max_files=1000):
    """
    Pool every Vina score stored in the ``vina_scores`` column of ``df``.

    Each ``vina_scores`` cell is a Python-literal list (parsed with
    ``ast.literal_eval``). Rows whose ``filename`` contains ``.mol2`` are
    skipped. Rows that fail to parse are logged and skipped.

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores`` columns
        max_files (int): Stop after this many rows have been parsed successfully
            (bounds memory use)

    Returns:
        list: All scores from the parsed rows, concatenated in row order.
    """
    pooled = []
    n_parsed = 0

    for _, record in df.iterrows():
        if n_parsed >= max_files:
            break

        name = record['filename']
        if '.mol2' in name:
            # mol2 rows are the reference molecules, not docking results.
            continue

        try:
            parsed = ast.literal_eval(record['vina_scores'])
            pooled.extend(parsed)
            n_parsed += 1
        except (ValueError, SyntaxError) as err:
            logger.warning(f"Failed to parse Vina scores for {name}: {err}")

    logger.info(f"Loaded {len(pooled)} Vina scores from {n_parsed} files")
    return pooled
|
|
|
|
def get_min_vina_scores_length(df):
    """
    Return the shortest ``vina_scores`` list length among non-reference rows.

    Rows whose ``filename`` contains ``.mol2`` are ignored. Rows whose
    ``vina_scores`` cell cannot be parsed are logged and ignored.

    Args:
        df (pd.DataFrame): DataFrame with ``filename`` and ``vina_scores`` columns

    Returns:
        int: The minimum list length, or 0 when no row contributed a length.
    """
    lengths = []
    for _, record in df.iterrows():
        if '.mol2' in record['filename']:
            continue
        try:
            lengths.append(len(ast.literal_eval(record['vina_scores'])))
        except (ValueError, SyntaxError) as err:
            logger.warning(f"Failed to parse Vina scores for {record['filename']}: {err}")

    return min(lengths) if lengths else 0
|
|
|
|
def get_reference_vina_scores(dataset_name, rank=0):
    """
    Collect the Vina score at a given conformer rank for each reference molecule.

    Scans ``result/refence/<dataset_name>/*_converted.sdf`` and reads each
    file's per-conformer scores with ``extract_vina_scores_from_sdf``. The
    directory name "refence" is the spelling used on disk and is kept
    deliberately. Files yielding no scores are skipped.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)
        rank (int): 0-based conformer index to use (0 for best/first)

    Returns:
        dict: Reference identifier (e.g. ``9NY``) -> Vina score at ``rank``.
        Empty when the reference directory does not exist.

    Raises:
        ValueError: If ``rank`` is negative, or if a reference file has
            ``rank`` or fewer conformers.
    """
    # A negative index would silently pick conformers from the end of the list.
    if rank < 0:
        raise ValueError(f"Rank {rank} is invalid; it must be a non-negative conformer index.")

    reference_scores = {}

    # Use the original directory name "refence" (sic), as it exists on disk.
    reference_dir = Path("result") / "refence" / dataset_name
    if not reference_dir.exists():
        logger.warning(f"Reference directory {reference_dir} does not exist")
        return reference_scores

    # sorted() makes processing order (and which file wins on a duplicate
    # identifier) deterministic across filesystems.
    reference_sdf_files = sorted(reference_dir.glob("*_converted.sdf"))
    logger.info(f"Processing {len(reference_sdf_files)} reference SDF files in {reference_dir}")

    for sdf_file in reference_sdf_files:
        vina_scores = extract_vina_scores_from_sdf(str(sdf_file))
        if not vina_scores:
            continue

        if rank >= len(vina_scores):
            raise ValueError(
                f"Rank {rank} is out of range for reference file {sdf_file.name}: "
                f"it contains only {len(vina_scores)} conformer(s). "
                f"Please choose a rank less than {len(vina_scores)}."
            )
        reference_score = vina_scores[rank]

        # Identifier: drop '_out_converted' / '_addH', then for aligned files
        # keep the last '_'-separated token (e.g. 9NY or 0GA).
        identifier = sdf_file.stem.replace('_out_converted', '').replace('_addH', '')
        if 'align_' in identifier:
            identifier = identifier.split('_')[-1]

        reference_scores[identifier] = reference_score
        logger.info(f"Reference Vina score for {identifier} (rank {rank}): {reference_score}")

    return reference_scores
|
|
|
|
def plot_combined_kde_distribution_normalized(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot KDEs of min-max normalized QED, molecular weight and Vina scores on one axis.

    Each series is rescaled to [0, 1] with its own min/max so the distributions
    can share an x-axis. Reference molecules are marked at y=0 using the main
    dataset's scale (and the pooled docking scores' scale for Vina), so their
    markers may fall outside [0, 1].

    Args:
        df (pd.DataFrame): Main dataset with ``qed`` and ``molecular_weight`` columns
        dataset_name (str): Name of the dataset (fgbar or trpe); used in the title
            and the output filename
        reference_df (pd.DataFrame): Reference molecules with ``filename``, ``qed``
            and ``molecular_weight`` columns (optional)
        reference_scores (dict): Reference molecule scores (optional; not used here)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference Vina scores keyed by molecule
            identifier, e.g. ``9NY`` (optional)

    Side effects:
        Writes ``kde_distribution_{dataset_name}_normalized.png`` to the current
        working directory.
    """
    def _reference_id(filename):
        # Same identifier scheme as get_reference_vina_scores: drop the
        # '_out_converted' / '_addH' suffixes, then keep the last '_' token of
        # aligned files ('align_x_9NY_out_converted' -> '9NY'). Without
        # stripping '_out_converted' the id would be 'converted' and never
        # match reference_vina_scores.
        stem = Path(filename).stem.replace('_out_converted', '').replace('_addH', '')
        if 'align_' in stem:
            stem = stem.split('_')[-1]
        return stem

    def _scaler(series):
        # Min-max scaler fitted on `series`. A zero range maps every value to 0
        # instead of producing NaN from 0/0.
        lo, hi = series.min(), series.max()
        span = (hi - lo) or 1.0
        return lambda x: (x - lo) / span

    fig = plt.figure(figsize=(15, 8))

    qed_scale = _scaler(df['qed'])
    mw_scale = _scaler(df['molecular_weight'])

    sns.kdeplot(qed_scale(df['qed']), label='QED (normalized)', fill=True, alpha=0.5, color='blue')
    sns.kdeplot(mw_scale(df['molecular_weight']), label='Molecular Weight (normalized)', fill=True, alpha=0.5, color='red')

    vina_scale = None
    if vina_scores is not None and len(vina_scores) > 0:
        # Scores are only rescaled, not negated: lower (more negative) still
        # means better binding and maps toward 0.
        vina_series = pd.Series(vina_scores)
        vina_scale = _scaler(vina_series)
        sns.kdeplot(vina_scale(vina_series), label='Vina Score (normalized)', fill=True, alpha=0.5, color='green')

    legend_handles = []
    if reference_df is not None and len(reference_df) > 0:
        for _, row in reference_df.iterrows():
            ref_id = _reference_id(row['filename'])
            qed_value = row['qed']
            mw_value = row['molecular_weight']

            ref_vina = None
            if reference_vina_scores and ref_id in reference_vina_scores:
                ref_vina = reference_vina_scores[ref_id]

            score_text = f"{ref_id}\n(QED: {qed_value:.2f}, MW: {mw_value:.2f}"
            if ref_vina is not None:
                score_text += f", Vina: {ref_vina:.2f}"
            score_text += ")"

            plt.scatter(qed_scale(qed_value), 0, color='darkblue', s=100, marker='v', zorder=5)
            mw_x = mw_scale(mw_value)
            plt.scatter(mw_x, 0, color='darkred', s=100, marker='^', zorder=5)

            # A Vina marker needs the pooled docking scores to define its scale.
            draw_vina = ref_vina is not None and vina_scale is not None
            if draw_vina:
                plt.scatter(vina_scale(ref_vina), 0, color='darkgreen', s=100, marker='o', zorder=5)

            # The combined annotation is anchored at the molecular-weight marker.
            plt.annotate(score_text,
                         (mw_x, 0),
                         xytext=(10, 30),
                         textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                         fontsize=8)

            legend_handles.append(plt.Line2D([0], [0], marker='v', color='darkblue', label=f"{ref_id} - QED",
                                             markerfacecolor='darkblue', markersize=8, linestyle=''))
            legend_handles.append(plt.Line2D([0], [0], marker='^', color='darkred', label=f"{ref_id} - MW",
                                             markerfacecolor='darkred', markersize=8, linestyle=''))
            if draw_vina:
                legend_handles.append(plt.Line2D([0], [0], marker='o', color='darkgreen', label=f"{ref_id} - Vina",
                                                 markerfacecolor='darkgreen', markersize=8, linestyle=''))

    plt.title(f'Combined KDE Distribution (Normalized) - {dataset_name.upper()}', fontsize=16)
    plt.xlabel('Normalized Values (0-1)')
    plt.ylabel('Density')

    # Build ONE legend holding both the KDE curves and the reference markers.
    # Calling plt.legend() a second time would replace the first legend.
    kde_handles, _ = plt.gca().get_legend_handles_labels()
    plt.legend(handles=kde_handles + legend_handles, loc='upper right', fontsize=9)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_normalized.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Saved combined KDE distribution plot (normalized) for {dataset_name} as kde_distribution_{dataset_name}_normalized.png")
|
|
|
|
def plot_combined_kde_distribution_actual(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot side-by-side KDEs of QED, molecular weight and Vina scores in their own units.

    Three panels are produced: QED, molecular weight (Daltons) and Vina score
    (kcal/mol). Each reference molecule is marked at y=0 on the relevant panel
    and labelled with its own identifier and value.

    Args:
        df (pd.DataFrame): Main dataset with ``qed`` and ``molecular_weight`` columns
        dataset_name (str): Name of the dataset (fgbar or trpe); used in the title
            and the output filename
        reference_df (pd.DataFrame): Reference molecules with ``filename``, ``qed``
            and ``molecular_weight`` columns (optional)
        reference_scores (dict): Reference molecule scores (optional; not used here)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference Vina scores keyed by molecule
            identifier, e.g. ``9NY`` (optional)

    Side effects:
        Writes ``kde_distribution_{dataset_name}_actual.png`` to the current
        working directory.
    """
    def _reference_id(filename):
        # Same identifier scheme as get_reference_vina_scores: drop the
        # '_out_converted' / '_addH' suffixes, then keep the last '_' token of
        # aligned files ('align_x_9NY_out_converted' -> '9NY').
        stem = Path(filename).stem.replace('_out_converted', '').replace('_addH', '')
        if 'align_' in stem:
            stem = stem.split('_')[-1]
        return stem

    def _mark(ax, x, text, color, marker, y_offset):
        # Place a marker on the x-axis at `x` and attach a boxed label to it.
        ax.scatter(x, 0, color=color, s=100, marker=marker, zorder=5)
        ax.annotate(text,
                    (x, 0),
                    xytext=(10, y_offset),
                    textcoords='offset points',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                    fontsize=8)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(f'KDE Distribution (Actual Values) - {dataset_name.upper()}', fontsize=16)
    qed_ax, mw_ax, vina_ax = axes

    # Label every reference row with ITS OWN identifier. Previously the first
    # row's id was reused for all rows.
    ref_rows = []
    if reference_df is not None and len(reference_df) > 0:
        ref_rows = [(_reference_id(row['filename']), row) for _, row in reference_df.iterrows()]

    # Panel 1: QED distribution
    sns.kdeplot(df['qed'], ax=qed_ax, fill=True, alpha=0.5, color='blue')
    qed_ax.set_title('QED Distribution')
    qed_ax.set_xlabel('QED Value')
    qed_ax.set_ylabel('Density')
    qed_ax.grid(True, alpha=0.3)
    for ref_id, row in ref_rows:
        _mark(qed_ax, row['qed'], f"{ref_id}\n({row['qed']:.2f})", 'darkblue', 'v', 20)

    # Panel 2: Molecular weight distribution
    sns.kdeplot(df['molecular_weight'], ax=mw_ax, fill=True, alpha=0.5, color='red')
    mw_ax.set_title('Molecular Weight Distribution')
    mw_ax.set_xlabel('Molecular Weight (Daltons)')
    mw_ax.set_ylabel('Density')
    mw_ax.grid(True, alpha=0.3)
    for ref_id, row in ref_rows:
        _mark(mw_ax, row['molecular_weight'], f"{ref_id}\n({row['molecular_weight']:.2f})", 'darkred', '^', -30)

    # Panel 3: Vina score distribution
    if vina_scores is not None and len(vina_scores) > 0:
        sns.kdeplot(pd.Series(vina_scores), ax=vina_ax, fill=True, alpha=0.5, color='green')
        if reference_vina_scores:
            # The dictionary key is the reference identifier.
            for ref_id, vina_score in reference_vina_scores.items():
                _mark(vina_ax, vina_score, f"{ref_id}\n({vina_score:.2f})", 'darkgreen', 'o', -60)
    else:
        # Keep the empty panel self-explanatory instead of leaving it blank.
        vina_ax.text(0.5, 0.5, 'No Vina scores available', ha='center', va='center', transform=vina_ax.transAxes)
    vina_ax.set_title('Vina Score Distribution')
    vina_ax.set_xlabel('Vina Score (kcal/mol)')
    vina_ax.set_ylabel('Density')
    vina_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    fig.savefig(f'kde_distribution_{dataset_name}_actual.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Saved combined KDE distribution plot (actual values) for {dataset_name} as kde_distribution_{dataset_name}_actual.png")
|
|
|
|
def analyze_dataset(csv_file, dataset_name, reference_scores=None, rank=0):
    """
    Analyze a dataset and write its normalized and actual-value KDE plots.

    Args:
        csv_file (str): Path to the CSV file
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_scores (dict): Reference scores for this dataset (passed
            through to the plotting functions)
        rank (int): 0-based rank of the conformation to use for reference Vina
            scores (0 for best/first)

    Raises:
        ValueError: If ``rank`` is negative, if the dataset has no usable Vina
            score lists, or if ``rank`` is not below the shortest list length.
    """
    # Validate cheaply before touching any files. A negative index would
    # silently select conformers from the end of each list.
    if rank < 0:
        raise ValueError(f"Rank {rank} is invalid; it must be a non-negative conformer index.")

    df = load_dataset(csv_file)

    # The requested rank must exist for every docked molecule.
    min_vina_length = get_min_vina_scores_length(df)
    if min_vina_length == 0:
        raise ValueError(
            f"No usable Vina scores in {csv_file}: a molecule has an empty score list, "
            f"or no score list could be parsed; cannot select conformer rank {rank}."
        )
    if rank >= min_vina_length:
        raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {min_vina_length}. Please choose a rank less than {min_vina_length}.")

    reference_df = load_reference_molecules(dataset_name)
    vina_scores = load_vina_scores_from_csv(df)
    reference_vina_scores = get_reference_vina_scores(dataset_name, rank)

    plot_combined_kde_distribution_normalized(df, dataset_name, reference_df, reference_scores, vina_scores, reference_vina_scores)
    plot_combined_kde_distribution_actual(df, dataset_name, reference_df, reference_scores, vina_scores, reference_vina_scores)
|
|
|
|
def get_default_reference_scores():
    """
    Return the built-in reference Vina scores (values recorded in README.md).

    Returns:
        dict: Mapping of dataset name -> {reference ligand id: Vina score}.
        A new dict is built on every call, so callers may mutate it.
    """
    # Values as listed in README.md.
    defaults = dict(
        fgbar={'9NY': -5.268},
        trpe={'0GA': -6.531},
    )
    return defaults
|
|
|
|
@click.command()
@click.argument('csv_files', type=click.Path(exists=True), nargs=-1)
@click.option('--dataset-names', '-d', multiple=True, help='Names of the datasets corresponding to CSV files')
@click.option('--reference-scores', '-r', type=str, help='Reference scores in JSON format')
@click.option('--rank', '-k', default=0, type=int, help='Rank of conformation to use for reference Vina scores (default: 0 for best/first)')
def main_cli(csv_files, dataset_names, reference_scores, rank):
    """
    Analyze QED and molecular weight distributions and generate KDE plots

    CSV_FILES: Paths to the CSV files with QED and molecular weight data
    """
    # NOTE: the docstring above is the click --help text; keep it user-facing.
    # Input errors are raised as click exceptions so the process exits with a
    # non-zero status. Previously they were only logged, and the exit status was 0.
    if not csv_files:
        raise click.UsageError("At least one CSV file must be provided")

    dataset_names_list = list(dataset_names) if dataset_names else None
    if dataset_names_list is not None and len(dataset_names_list) != len(csv_files):
        # zip() in main_api would otherwise silently drop the unmatched files.
        raise click.BadParameter(
            f"got {len(dataset_names_list)} dataset name(s) for {len(csv_files)} CSV file(s)",
            param_hint="'--dataset-names'",
        )

    if reference_scores:
        try:
            reference_scores_dict = json.loads(reference_scores)
        except json.JSONDecodeError as err:
            raise click.BadParameter(
                f"Invalid JSON format for reference scores: {err}",
                param_hint="'--reference-scores'",
            ) from err
        # main_api looks scores up with .get(dataset_name), so an object is required.
        if not isinstance(reference_scores_dict, dict):
            raise click.BadParameter(
                "Reference scores must be a JSON object mapping dataset names to scores",
                param_hint="'--reference-scores'",
            )
    else:
        reference_scores_dict = get_default_reference_scores()

    try:
        main_api(csv_files, dataset_names_list, reference_scores_dict, rank)
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        raise
|
|
|
|
def main_api(csv_files, dataset_names=None, reference_scores=None, rank=0):
    """
    Main function for API usage: analyze each CSV file and write its KDE plots.

    Args:
        csv_files (list): CSV files to analyze (any iterable of paths)
        dataset_names (list): Dataset names, one per CSV file. Defaults to each
            file's stem with ``qed_values_`` removed.
        reference_scores (dict): Mapping of dataset name -> reference scores.
            Defaults to get_default_reference_scores().
        rank (int): Rank of the conformation to use for reference Vina scores (0 for best/first)

    Raises:
        ValueError: If ``dataset_names`` does not contain exactly one name per CSV file.
        Exception: Any error raised while analyzing a dataset is logged and re-raised.
    """
    # Materialize inputs so their lengths can be compared (generators allowed).
    csv_files = list(csv_files)
    if dataset_names is None:
        dataset_names = [Path(f).stem.replace('qed_values_', '') for f in csv_files]
    else:
        dataset_names = list(dataset_names)

    # zip() would otherwise silently skip files or names without a partner.
    if len(dataset_names) != len(csv_files):
        raise ValueError(
            f"Expected one dataset name per CSV file, got {len(dataset_names)} name(s) "
            f"for {len(csv_files)} file(s)"
        )

    if reference_scores is None:
        reference_scores = get_default_reference_scores()

    for csv_file, dataset_name in zip(csv_files, dataset_names):
        try:
            logger.info(f"Analyzing dataset: {dataset_name}")
            analyze_dataset(csv_file, dataset_name, reference_scores.get(dataset_name, {}), rank)
        except Exception as e:
            logger.error(f"Error analyzing {dataset_name}: {e}")
            raise
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the click CLI, which takes its arguments from the command line.
    main_cli()