79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
import joblib
|
|
import numpy as np
|
|
from rdkit import Chem
|
|
from rdkit.Chem import AllChem
|
|
from pathlib import Path
|
|
from pprint import pprint
|
|
def calculate_2dqsar_repr(smiles, radius=3, n_bits=1024):
    """Return the Morgan fingerprint bit vector of a SMILES string as a numpy array.

    This is the 2D-QSAR feature vector passed to the saved models' ``predict``.
    Presumably ``radius`` and ``n_bits`` must match the settings the models
    were trained with (confirm against the training script); the defaults are
    radius 3 and 1024 bits.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan fingerprint radius.
        n_bits: Length of the folded bit vector.

    Returns:
        1-D numpy array of length ``n_bits`` holding 0/1 values, or ``None``
        if RDKit cannot parse ``smiles`` or it contains no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None on a parse failure. An empty SMILES parses to
    # a molecule with zero atoms, whose all-zero fingerprint is not a real input.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    # NOTE(review): newer RDKit releases deprecate this call in favour of
    # rdFingerprintGenerator.GetMorganGenerator; only switch if it produces
    # bit-identical vectors to those the models were trained on.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return np.array(fp)
|
|
|
|
# Predict MIC values for every molecule in ../predict_data with every saved
# 2D-QSAR model (2d_qsar_*.pkl) found in the current working directory.
#
# Result: sdf_results[<sdf file stem>][<model file stem>] = predicted value.
# Both locations are resolved relative to the current working directory,
# not relative to this script's location.

# Sorted so the run (and the plot's ordering) is deterministic.
sdf_file_list = sorted(Path('../predict_data').glob('*.sdf'))

# Load every saved model once, up front, instead of once per SDF file.
# NOTE(review): joblib.load unpickles arbitrary objects -- only load model
# files from a trusted source.
model_file_list = sorted(Path.cwd().glob('2d_qsar_*.pkl'))
models = {model_file.stem: joblib.load(model_file) for model_file in model_file_list}

sdf_results = {}

for sdf_file in sdf_file_list:
    # Pass a str: older RDKit builds reject pathlib.Path here.
    supplier = Chem.SDMolSupplier(str(sdf_file))

    # Only the first record is used (assumes one molecule per SDF). Records
    # RDKit fails to parse come back as None; an empty file yields nothing.
    new_mol = next(iter(supplier), None)
    if new_mol is None:
        print(f"Skipping {sdf_file.name}: no readable molecule in the first record")
        continue

    smiles = Chem.MolToSmiles(new_mol)

    # Calculate the 2D-QSAR descriptors (None if the SMILES cannot be featurized).
    descriptor = calculate_2dqsar_repr(smiles)
    if descriptor is None:
        print(f"Skipping {sdf_file.name}: could not featurize SMILES {smiles!r}")
        continue

    # predict() takes a 2-D (n_samples, n_features) input; this is one sample.
    descriptor_array = np.asarray(descriptor).reshape(1, -1)

    # Predict the MIC value with each loaded model.
    sdf_results[sdf_file.stem] = {
        model_name: model.predict(descriptor_array)[0]
        for model_name, model in models.items()
    }

pprint(sdf_results)
|
|
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
# Drop negative MIC predictions from sdf_results and plot the rest as a
# grouped bar chart: one x-group per SDF file, one bar (hue) per model.
# NOTE(review): negative values are dropped on the assumption that the models
# predict MIC on a linear scale; if they predict log-MIC, negative values are
# legitimate and this filter discards real results -- confirm.

filtered_sdf_results = {
    sdf_name: {
        model_name: mic_value
        for model_name, mic_value in model_results.items()
        if mic_value >= 0  # NaN predictions also fail this comparison and are dropped
    }
    for sdf_name, model_results in sdf_results.items()
}

# Long format (one row per SDF/model pair) is what seaborn's hue= expects.
# Declaring the columns keeps them present even when no rows survive, so
# the frame never loses the 'SDF'/'Model'/'MIC' names used below.
df = pd.DataFrame(
    [
        {'SDF': sdf_name, 'Model': model_name, 'MIC': mic_value}
        for sdf_name, model_results in filtered_sdf_results.items()
        for model_name, mic_value in model_results.items()
    ],
    columns=['SDF', 'Model', 'MIC'],
)

if df.empty:
    # Nothing was predicted or everything was filtered out; an empty chart is
    # meaningless, so report it instead of plotting.
    print('No non-negative MIC predictions to plot.')
else:
    plt.figure(figsize=(12, 8))

    sns.barplot(x='SDF', y='MIC', hue='Model', data=df, palette='tab20')

    plt.title('Predicted MIC values by Model for Each SDF (Filtered)')
    plt.xlabel('SDF Files')
    plt.ylabel('Predicted MIC Values')
    # Right-align rotated labels so each one ends under its own bar group.
    plt.xticks(rotation=45, ha='right')
    plt.legend(title='Model')

    plt.tight_layout()
    plt.show()
|