"""
MolE Representation Module

This module provides functions to generate MolE molecular representations.
"""

import os

import yaml
import torch
import pandas as pd
from rdkit import Chem
from rdkit import RDLogger

from .dataset_representation import batch_representation
from .ginet_concat import GINet

# Silence RDKit's logging (all 'rdApp.*' channels) for the whole process.
RDLogger.DisableLog('rdApp.*')


def _has_column(frame, name):
    """Return True if `frame` has a column matching `name` case-insensitively."""
    return name.lower() in {str(col).lower() for col in frame.columns}


def _read_table(data_path, smile_col):
    """
    Read a delimited text file, trying tab first and comma second.

    Parsing a comma-separated file with sep='\\t' usually does not raise: it
    yields a single fused column. The comma separator is therefore also tried
    whenever the tab-separated read does not expose the SMILES column.

    Parameters:
    - data_path (str): Path to the delimited file.
    - smile_col (str): Name of the SMILES column used to judge which
      separator produced a usable table.

    Returns:
    - pandas.DataFrame: The parsed table. When neither separator exposes the
      SMILES column, the tab-separated result is returned so the caller can
      report the available columns.
    """
    try:
        tab_frame = pd.read_csv(data_path, sep="\t")
    except pd.errors.ParserError:
        # Malformed as TSV; fall back to the default (comma) separator.
        return pd.read_csv(data_path)

    if _has_column(tab_frame, smile_col):
        return tab_frame

    try:
        comma_frame = pd.read_csv(data_path)
    except (pd.errors.ParserError, pd.errors.EmptyDataError):
        # The retry is best-effort; keep the tab-separated result.
        return tab_frame

    return comma_frame if _has_column(comma_frame, smile_col) else tab_frame


def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to the file or a DataFrame
      containing SMILES data. Files may be tab- or comma-separated.
    - smile_col (str, optional): Name of the column containing SMILES strings.
      Matched case-insensitively against the available columns.
    - id_col (str, optional): Name of the column containing molecule IDs.
      Matched case-insensitively; if absent, IDs "mol1", "mol2", ... are
      generated.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with exactly the columns
      `smile_col` and `id_col`, restricted to rows whose SMILES string can be
      parsed by RDKit.

    Raises:
    - ValueError: If the SMILES column cannot be found in the data.
    """
    # Read the data
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        smile_df = _read_table(data_path, smile_col)

    # Resolve the requested column names case-insensitively. str() keeps this
    # from failing on non-string column labels (e.g. integer headers).
    columns_lower = {str(col).lower(): col for col in smile_df.columns}
    smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
    id_col_actual = columns_lower.get(id_col.lower(), id_col)

    if smile_col_actual not in smile_df.columns:
        raise ValueError(
            f"Column '{smile_col}' not found in data. "
            f"Available columns: {list(smile_df.columns)}"
        )

    # Select columns. The explicit copy() makes the result an independent
    # frame, so the renames/assignments below never act on a view.
    if id_col_actual in smile_df.columns:
        smile_df = smile_df[[smile_col_actual, id_col_actual]].copy()
        smile_df.columns = [smile_col, id_col]
    else:
        # Create ID column if it does not exist
        smile_df = smile_df[[smile_col_actual]].copy()
        smile_df.columns = [smile_col]
        smile_df[id_col] = [f"mol{i}" for i in range(1, len(smile_df) + 1)]

    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)

    # Remove NaN (this runs after the ID column has been cast to str)
    smile_df = smile_df.dropna()

    # Remove invalid SMILES. The astype(bool) keeps the mask boolean even for
    # an empty frame, where apply() would otherwise return an object-dtype
    # Series that pandas treats as a column selector.
    is_valid = smile_df[smile_col].apply(
        lambda smi: Chem.MolFromSmiles(smi) is not None
    ).astype(bool)
    smile_df = smile_df[is_valid]

    return smile_df


def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Reads "config.yaml" and "model.pth" from `pretrained_model_dir`, builds a
    GINet from the "model" section of the config and loads the stored weights.

    Parameters:
    - pretrained_model_dir (str): Path to the pre-trained MolE model directory.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: Loaded pre-trained model.
    """
    # Read model configuration; the context manager closes the file handle.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as config_file:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # config files; consider yaml.safe_load if configs can come from an
        # untrusted source.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(f"Loading model from: {model_pth_path}")

    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)

    return model


def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Process the dataset to generate molecular representations.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame.
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Path to the pre-trained model directory.
    - device (str): Device to use for computation. Can be "cpu", "cuda:0", etc.

    Returns:
    - udl_representation (pandas.DataFrame): DataFrame containing molecular
      representations, as produced by `batch_representation`.
    """
    # First we read the SMILES dataframe
    smiles_df = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)

    # Load the pre-trained model
    pmodel = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)

    # Gather pre-trained representation
    udl_representation = batch_representation(
        smiles_df, pmodel, smile_column_str, id_column_str, device=device
    )

    return udl_representation