import os
import yaml
import argparse
import torch
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
from rdkit import Chem
from rdkit import RDLogger

from workflow.dataset.dataset_representation import batch_representation
from workflow.models.ginet_concat import GINet

# NOTE(review): OneHotEncoder and XGBClassifier are not referenced in the code
# visible here; confirm nothing else relies on these imports before removing.

# Silence RDKit's log output; invalid SMILES are filtered explicitly in
# read_smiles() instead of being reported one by one.
RDLogger.DisableLog('rdApp.*')


# Function to read command line arguments
def parse_arguments():
    """
    Parse the command line arguments of this script.

    Returns:
    - args (argparse.Namespace): Parsed arguments. If ``--device`` is "auto",
      ``args.device`` is replaced by "cuda:0" when torch reports an available
      CUDA device, otherwise by "cpu".
    """
    # Instantiate parser
    parser = argparse.ArgumentParser(prog="Represent molecular structures as using MolE.",
                                     description="This program recieves a file with SMILES and represents them using the MolE representation.",
                                     usage="python mole_representation.py smiles_filepath output_filepath [options]",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Input SMILES
    parser.add_argument("smiles_filepath",
                        help="Complete path to the smiles filepath. Expects a TSV file with a column containing SMILES strings.")

    # Output filepath
    parser.add_argument("output_filepath",
                        help="Complete path for the output.")

    # Column name for smiles
    parser.add_argument("-c", "--smiles_colname",
                        help="Column name in smiles_filepath that contains the SMILES.",
                        default="smiles")

    # Column name for id
    parser.add_argument("-i", "--chemid_colname",
                        help="Column name in smiles_filepath that contains the ID string of each chemical.",
                        default="chem_id")

    # MolE model
    parser.add_argument("-m", "--mole_model",
                        help="Path to the directory containing the config.yaml and model.pth files of the pre-trained MolE chemical representation.",
                        default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")

    # Device
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) then cuda:0 device is selected if a GPU is detected.",
                        default="auto")

    # Parse arguments
    args = parser.parse_args()

    # Determine device for MolE model
    if args.device == "auto":
        args.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    print(f"Using {args.device}")

    return args


# A FUNCTION TO READ SMILES from file
def read_smiles(data_path, smile_col="rdkit_no_salt", id_col="prestwick_ID"):
    """
    Read SMILES data from a file or DataFrame and remove missing or invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to a tab-separated file, or a DataFrame
      containing SMILES data. A DataFrame argument is never modified in place.
    - smile_col (str, optional): Name of the column containing SMILES strings.
    - id_col (str, optional): Name of the column containing molecule IDs.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with only ``smile_col`` and ``id_col``.
      Rows with a missing SMILES or ID are dropped, IDs are cast to ``str``, and
      rows whose SMILES RDKit cannot parse are removed. The original index is kept.
    """
    # Read the data
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path
    else:
        smile_df = pd.read_csv(data_path, sep='\t')

    # Keep only the relevant columns. The explicit copy makes the column
    # assignment below safe: it cannot touch the caller's DataFrame and does
    # not trigger pandas' SettingWithCopyWarning.
    smile_df = smile_df[[smile_col, id_col]].copy()

    # Remove NaN *before* casting IDs: astype(str) would turn a missing ID into
    # the literal string "nan", which dropna() could no longer detect.
    smile_df = smile_df.dropna()

    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)

    # Remove invalid smiles. Non-string cells would make RDKit raise, so they
    # count as invalid too. The cast to bool keeps the mask boolean even when
    # it is empty (an empty object-dtype Series would otherwise be interpreted
    # by pandas as a list of column labels).
    is_valid = smile_df[smile_col].apply(
        lambda smi: isinstance(smi, str) and Chem.MolFromSmiles(smi) is not None
    ).astype(bool)

    return smile_df.loc[is_valid]


# Function to load a pre-trained model
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Parameters:
    - pretrained_model_dir (str): Path to the directory containing the model's
      ``config.yaml`` and ``model.pth`` files.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: ``GINet`` instance built from the ``model`` section of
      ``config.yaml``, moved to ``device``, with the weights from ``model.pth``
      applied via ``load_my_state_dict``.
    """
    # Read model configuration; the context manager closes the file handle.
    # NOTE(review): yaml.load with FullLoader is kept for compatibility with
    # existing config files; prefer yaml.safe_load if configs could ever come
    # from an untrusted source.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(model_pth_path)
    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)

    return model


def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Compute MolE representations for the valid SMILES in a dataset.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to a tab-separated SMILES file, or a
      DataFrame, as accepted by ``read_smiles``.
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Directory containing ``config.yaml`` and ``model.pth``
      of the pre-trained model.
    - device (str): Device used for computation, e.g. "cuda:0" or "cpu".

    Returns:
    - udl_representation: The result of ``batch_representation`` for the
      filtered SMILES (written as a TSV by ``main``; presumably a
      pandas.DataFrame — confirm against ``batch_representation``).
    """
    # First we read the SMILES dataframe
    smiles_df = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)

    # Load the pre-trained model
    pmodel = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)

    # Gather pre-trained representation
    udl_representation = batch_representation(smiles_df, pmodel, smile_column_str, id_column_str, device=device)

    return udl_representation


def main():
    """Parse CLI arguments, compute MolE representations and write them as TSV."""
    # Parse arguments
    args = parse_arguments()

    # Obtain MolE pre-trained representation
    mole_representation = process_representation(dataset_path=args.smiles_filepath,
                                                  smile_column_str=args.smiles_colname,
                                                  id_column_str=args.chemid_colname,
                                                  pretrained_dir=args.mole_model,
                                                  device=args.device)

    # Write MolE representation to output
    mole_representation.to_csv(args.output_filepath, sep='\t')


if __name__ == "__main__":
    main()