129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
"""
MolE Representation Module

This module provides functions to generate MolE molecular representations.
"""
|
|
|
|
import os
|
|
import yaml
|
|
import torch
|
|
import pandas as pd
|
|
from rdkit import Chem
|
|
from rdkit import RDLogger
|
|
|
|
from .dataset_representation import batch_representation
|
|
from .ginet_concat import GINet
|
|
|
|
# Module-level side effect on import: disable every RDKit log channel matching
# 'rdApp.*'. Presumably this keeps Chem.MolFromSmiles failures in read_smiles
# from printing parse errors for each invalid SMILES — confirm intent.
RDLogger.DisableLog('rdApp.*')
|
|
|
|
|
|
def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to a tab- or comma-separated file
      (a seekable file-like object also works), or a DataFrame containing
      SMILES data. A DataFrame argument is copied and never modified.
    - smile_col (str, optional): Name of the column containing SMILES strings.
      Matched case-insensitively.
    - id_col (str, optional): Name of the column containing molecule IDs.
      Matched case-insensitively; if absent, IDs "mol1", "mol2", ... are
      generated in row order.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with exactly the columns
      [smile_col, id_col] (renamed to the requested spelling), IDs as str,
      and only the rows whose SMILES RDKit can parse. The original index is
      kept (not reset). An input with no valid rows yields an empty frame that
      still has both columns.

    Raises:
    - ValueError: If the SMILES column is not present in the data.
    """
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        smile_df = _read_table(data_path)

    # Map lower-cased column labels back to their real spelling so requested
    # names match case-insensitively; str() tolerates non-string labels.
    columns_lower = {str(col).lower(): col for col in smile_df.columns}
    smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
    id_col_actual = columns_lower.get(id_col.lower(), id_col)

    if smile_col_actual not in smile_df.columns:
        raise ValueError(f"Column '{smile_col}' not found in data. Available columns: {list(smile_df.columns)}")

    # .copy() makes the subset an independent frame, so assigning the generated
    # ID column below does not raise SettingWithCopyWarning.
    if id_col_actual in smile_df.columns:
        smile_df = smile_df[[smile_col_actual, id_col_actual]].copy()
        smile_df.columns = [smile_col, id_col]
    else:
        smile_df = smile_df[[smile_col_actual]].copy()
        smile_df.columns = [smile_col]
        smile_df[id_col] = [f"mol{i + 1}" for i in range(len(smile_df))]

    # IDs are converted to str before dropna, so a missing ID becomes the
    # string "nan" instead of dropping its row; only missing SMILES are dropped.
    smile_df[id_col] = smile_df[id_col].astype(str)
    smile_df = smile_df.dropna()

    # Keep only SMILES RDKit can parse. Non-string cells (e.g. numbers) count
    # as invalid instead of crashing MolFromSmiles. astype(bool) guarantees a
    # boolean mask even when the frame is empty (an empty object-dtype Series
    # would otherwise be treated as a list of column labels).
    is_valid = smile_df[smile_col].map(
        lambda smiles: isinstance(smiles, str) and Chem.MolFromSmiles(smiles) is not None
    ).astype(bool)
    return smile_df[is_valid]


def _read_table(data_path):
    """
    Read a delimited text file, preferring tab separation, falling back to commas.

    A comma-separated file parsed with a tab separator usually does not raise:
    it silently yields a single column named after the whole header line. A
    single-column tab parse is therefore re-read as CSV, and the CSV result is
    kept only if it actually splits the data into more than one column.

    Parameters:
    - data_path: Path (or seekable file-like object) accepted by pd.read_csv.

    Returns:
    - pandas.DataFrame with the parsed table.
    """
    def rewind():
        # File-like inputs are consumed by a read; rewind before re-reading.
        if hasattr(data_path, "seek"):
            data_path.seek(0)

    try:
        tab_df = pd.read_csv(data_path, sep="\t")
    except (pd.errors.ParserError, UnicodeDecodeError):
        rewind()
        return pd.read_csv(data_path)

    if tab_df.shape[1] > 1:
        return tab_df

    rewind()
    try:
        comma_df = pd.read_csv(data_path)
    except (pd.errors.ParserError, UnicodeDecodeError):
        return tab_df
    return comma_df if comma_df.shape[1] > 1 else tab_df
|
|
|
|
|
|
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Reads ``config.yaml`` from the directory, builds a ``GINet`` from its
    ``model`` section, moves it to ``device`` and loads the weights stored in
    ``model.pth`` through GINet's ``load_my_state_dict``.

    Parameters:
    - pretrained_model_dir (str): Path to the pre-trained MolE model directory.
      It must contain ``config.yaml`` and ``model.pth``.
    - device (str, optional): Device for computation (default is "cuda:0").
      Also used as ``map_location`` when loading the weights.

    Returns:
    - model: Loaded pre-trained model.

    Raises:
    - FileNotFoundError: If ``config.yaml`` or ``model.pth`` is missing.
    - KeyError: If the configuration has no ``model`` section.
    """
    # Read model configuration; the with-block closes the file handle
    # deterministically instead of leaving it to the garbage collector.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as config_file:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs; prefer yaml.safe_load if model directories can come from
        # untrusted sources.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(f"Loading model from: {model_pth_path}")

    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)

    # NOTE(review): eval() is not called here; confirm the caller (e.g.
    # batch_representation) switches the model to inference mode.
    return model
|
|
|
|
|
|
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Generate MolE molecular representations for every valid SMILES in a dataset.

    Three steps run in order: read and validate the SMILES (read_smiles),
    load the pre-trained model (load_pretrained_model), then embed the
    molecules with that model (batch_representation).

    Parameters:
    - dataset_path (str or pd.DataFrame): Dataset file path or DataFrame.
    - smile_column_str (str): Name of the column holding SMILES strings.
    - id_column_str (str): Name of the column holding molecule IDs.
    - pretrained_dir (str): Directory of the pre-trained model.
    - device (str): Computation device, e.g. "cpu" or "cuda:0".

    Returns:
    - udl_representation: Whatever batch_representation returns — presumably a
      pandas.DataFrame of molecular representations (confirm in that module).
    """
    valid_molecules = read_smiles(
        dataset_path,
        smile_col=smile_column_str,
        id_col=id_column_str,
    )
    encoder = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)
    return batch_representation(
        valid_molecules,
        encoder,
        smile_column_str,
        id_column_str,
        device=device,
    )
|
|
|