{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Training\n", "### Author: Roberto Olayo-Alarcon\n", " \n", "Here, we perform a random search over XGBoost hyperparameters to train models that predict antimicrobial activity from different chemical representations." ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "import numpy as np\n", "import pandas as pd\n", "\n", "\n", "# Classifier\n", "from xgboost import XGBClassifier\n", "\n", "# Metrics and model selection\n", "from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc\n", "from sklearn.model_selection import ParameterGrid\n", "\n", "from sklearn.preprocessing import OneHotEncoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare directories" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "INPUT_DIR = \"../data/01.prepare_training_data/\"\n", "\n", "OUTPUT_DIR = \"../data/02.model_training\"\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare data" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "representation_dict = {\"MolE\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_mole_representation.tsv.gz\"), index_col=0, sep='\\t'),\n", "                       \"ecfp4\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_ecfp4_representation.tsv.gz\"), index_col=0, sep='\\t'),\n", "                       \"chemDesc\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_chemdesc_representation.tsv.gz\"), index_col=0, sep='\\t')}\n", "\n", "split_df = pd.read_csv(os.path.join(INPUT_DIR, \"maier_scaffold_split.tsv.gz\"), index_col=\"prestwick_ID\", sep='\\t')\n", "screen_df = pd.read_csv(os.path.join(INPUT_DIR, \"maier_screening_results.tsv.gz\"), index_col=\"prestwick_ID\", sep='\\t')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model parameters for random search" ] }, { 
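"cell_type": "markdown", "metadata": {}, "source": [ "Each random-search draw samples one value from every list in the grid below; scalar entries such as `nthread` and `objective` are passed through unchanged. A minimal sketch of a single draw (equivalent to what `select_params`, defined under Helper Functions, does for a flat grid):\n", "\n", "```python\n", "import random\n", "\n", "draw = {key: random.choice(value) if isinstance(value, list) else value\n", "        for key, value in XGB_PARAMS.items()}\n", "```" ] }, { 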
"cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "XGB_PARAMS = {\"nthread\": 20,\n", " \"n_estimators\": [30, 100, 300, 500, 1000],\n", " \"max_depth\": [5, 10, 50, 100],\n", " \"eta\": [0.3, 0.1, 0.05, 1],\n", " \"subsample\": [0.3, 0.5, 0.8, 1.0],\n", " \"objective\": \"binary:logistic\"}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper Functions" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "# Select parameters randomly\n", "def select_params(original_config):\n", " \"\"\"\n", " Randomly select parameters from the provided configuration for modeling.\n", "\n", " This function traverses the provided configuration dictionary and randomly selects\n", " values from lists or nested dictionaries.\n", "\n", " Parameters:\n", " - original_config (dict): The original configuration dictionary containing parameters.\n", "\n", " Returns:\n", " - model_config (dict): Configuration with randomly selected parameters.\n", " \"\"\"\n", "\n", " model_config = original_config.copy()\n", "\n", " for key, value in model_config.items():\n", "\n", " if isinstance(value, list):\n", " model_config[key] = random.choice(value)\n", " elif isinstance(value, dict):\n", " model_config[key] = select_params(value)\n", " \n", " return model_config\n", "\n", "# Prepare strain one-hot-encoding\n", "def prep_ohe(categories):\n", "\n", " \"\"\"\n", " Prepare one-hot encoding for strain variables.\n", "\n", " This function creates a one-hot encoding representation of the provided categorical variables.\n", " It fits a OneHotEncoder to the categories and transforms them into a pandas DataFrame.\n", "\n", " Parameters:\n", " - categories (array-like): Array-like object containing categorical variables.\n", "\n", " Returns:\n", " - cat_ohe (pandas.DataFrame): DataFrame representing the one-hot encoded categorical variables.\n", " \"\"\"\n", "\n", " # NOTE: scikit-learn >= 1.2 renames this argument to sparse_output\n", " ohe = OneHotEncoder(sparse=False)\n", " 
ohe.fit(pd.DataFrame(categories))\n", " cat_ohe = pd.DataFrame(ohe.transform(pd.DataFrame(categories)),\n", " index=categories, columns=ohe.categories_[0])  # categories_ is a list of arrays; take the single feature's categories\n", " \n", " return cat_ohe\n", "\n", "def get_split(data_df, y_df, splitter_df, split_strat=\"split\"):\n", "\n", " \"\"\"\n", " Prepare data splits for training, validation, and testing.\n", "\n", " This function prepares the data splits for training, validation, and testing based on the given split strategy.\n", " It joins molecular features with taxonomic One-Hot Encoded (OHE) labels and separates the data into respective splits.\n", "\n", " Parameters:\n", " - data_df (pandas.DataFrame): DataFrame containing molecular features.\n", " - y_df (pandas.DataFrame): DataFrame containing labels.\n", " - splitter_df (pandas.DataFrame): DataFrame containing chemical IDs and split information.\n", " - split_strat (str, optional): Split strategy to use (\"split\" by default).\n", "\n", " Returns:\n", " - X_train (pandas.DataFrame): DataFrame containing features for training.\n", " - X_valid (pandas.DataFrame): DataFrame containing features for validation.\n", " - X_test (pandas.DataFrame): DataFrame containing features for testing.\n", " - y_train (numpy.ndarray): Array containing labels for training.\n", " - y_valid (numpy.ndarray): Array containing labels for validation.\n", " - y_test (numpy.ndarray): Array containing labels for testing.\n", " \"\"\"\n", "\n", " # Get the chemicals in each split of data\n", " train_chems = splitter_df.loc[splitter_df[split_strat] == \"train\"].index\n", " validation_chems = splitter_df.loc[splitter_df[split_strat] == \"valid\"].index\n", " test_chems = splitter_df.loc[splitter_df[split_strat] == \"test\"].index\n", "\n", " # Prepare taxonomic OHE\n", " taxa_ohe = prep_ohe(y_df.columns)\n", "\n", " # Pivot screen results to long format\n", " screen_melt = y_df.unstack().reset_index().rename(columns={0: \"label\",\n", " \"level_0\": \"taxa_name\"})\n", " \n", " # Join molecular features and 
then join taxa OHE\n", " data_df.columns = [str(c) for c in data_df.columns]\n", " data_df = data_df.fillna(0)\n", "\n", " screen_feat = screen_melt.join(data_df, on=\"prestwick_ID\")\n", " screen_feat = screen_feat.join(taxa_ohe, on=\"taxa_name\")\n", "\n", " assert screen_feat.shape[0] == screen_melt.shape[0]\n", "\n", "\n", " # Gather train\n", " X_train = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(train_chems)].drop(columns=[\"prestwick_ID\", \n", " \"label\", \n", " \"taxa_name\"])\n", " y_train = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(train_chems), [\"label\"]].values\n", "\n", " # Gather valid\n", " X_valid = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(validation_chems)].drop(columns=[\"prestwick_ID\", \n", " \"label\", \n", " \"taxa_name\"])\n", " y_valid = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(validation_chems), [\"label\"]].values\n", "\n", " # Gather test\n", " X_test = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(test_chems)].drop(columns=[\"prestwick_ID\", \n", " \"label\", \n", " \"taxa_name\"])\n", " y_test = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(test_chems), [\"label\"]].values\n", "\n", " \n", " return X_train, X_valid, X_test, y_train, y_valid, y_test\n", " \n", "def get_performance_metrics(y_true, y_pred, y_score, split_name):\n", "\n", " \"\"\"\n", " Compute performance metrics for a given data split.\n", "\n", " This function calculates various performance metrics including AUROC, AUPRC, and F1 score\n", " based on the true labels, predicted labels, and predicted scores.\n", "\n", " Parameters:\n", " - y_true (array-like): True labels.\n", " - y_pred (array-like): Predicted labels.\n", " - y_score (array-like): Predicted scores.\n", " - split_name (str): Name of the data split.\n", "\n", " Returns:\n", " - out_dict (dict): Dictionary containing computed performance metrics.\n", " Keys:\n", " - '{split_name}_auroc': Area Under the Receiver Operating Characteristic curve (AUROC) 
score.\n", " - '{split_name}_prauc': Area Under the Precision-Recall curve (AUPRC) score.\n", " - '{split_name}_f1': F1 score.\n", " \"\"\"\n", "\n", " pr, rec, _ = precision_recall_curve(y_true, y_score[:, 1])\n", "\n", " out_dict = {f\"{split_name}_auroc\": roc_auc_score(y_true=y_true, y_score=y_score[:, 1]),\n", " f\"{split_name}_prauc\": auc(rec, pr),\n", " f\"{split_name}_f1\": f1_score(y_true=y_true, y_pred=y_pred)}\n", "\n", " return out_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main training function" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "def eval_models(dataset_representation, \n", " n_train = 1, \n", " n_models = 1,\n", " feature_options = representation_dict, \n", " XGB_params_dict = XGB_PARAMS,\n", " split_df = split_df,\n", " screen_df = screen_df):\n", " \n", " \"\"\"\n", " Evaluate multiple XGBoost models with different configurations on a given dataset representation.\n", "\n", " This function trains and evaluates multiple XGBoost models with different configurations\n", " on a specified dataset representation. 
It computes performance metrics for each model\n", " on validation and test sets and returns the results as a DataFrame.\n", "\n", " Parameters:\n", " - dataset_representation (str): Name of the dataset representation to use.\n", " - n_train (int, optional): Number of training iterations for each model (default is 1).\n", " - n_models (int, optional): Number of models to evaluate (default is 1).\n", " - feature_options (dict, optional): Dictionary containing different dataset representations (default is representation_dict).\n", " - XGB_params_dict (dict, optional): Dictionary containing XGBoost parameters for model configuration (default is XGB_PARAMS).\n", " - split_df (pandas.DataFrame, optional): DataFrame containing split information (default is split_df).\n", " - screen_df (pandas.DataFrame, optional): DataFrame containing screening data (default is screen_df).\n", "\n", " Returns:\n", " - results_df (pandas.DataFrame): DataFrame containing performance metrics for all evaluated models.\n", " \"\"\"\n", "\n", " # Dictionary containing all candidate values for the classifier's parameters\n", " classifier_params_copy = XGB_params_dict.copy()\n", "\n", " # Get the corresponding features and screen\n", " features_df = feature_options[dataset_representation].copy()\n", "\n", " # Since the splits are already made, we just have to separate the data\n", " X_train, X_valid, X_test, y_train, y_valid, y_test = get_split(features_df, screen_df, split_df)\n", "\n", "\n", " # Iterate over models\n", " results_list = []\n", " for m in range(n_models):\n", " \n", " # Gather model configuration\n", " model_config = select_params(classifier_params_copy)\n", " model_config_str = str(model_config)\n", "\n", " # Iterate over training\n", " for t in range(n_train):\n", "\n", " # Create base estimator\n", " model_config[\"seed\"] = np.random.randint(1_000_000, size=1)[0]\n", " base_estimator = XGBClassifier(**model_config)\n", "\n", " # Train model\n", " 
base_estimator.fit(X=X_train, y=y_train)\n", "\n", " # Validation\n", " print(\"At Validation\")\n", " validation_proba = base_estimator.predict_proba(X=X_valid)\n", " validation_preds = base_estimator.predict(X=X_valid) \n", "\n", " # Testing\n", " print(\"At Testing\")\n", " test_proba = base_estimator.predict_proba(X=X_test)\n", " test_preds = base_estimator.predict(X=X_test)\n", "\n", " # Performance Metrics\n", " print(\"Gathering Results\")\n", " validation_performance = get_performance_metrics(y_true=y_valid, y_pred=validation_preds, y_score=validation_proba, split_name=\"validation\")\n", " test_performance = get_performance_metrics(y_true=y_test, y_pred=test_preds, y_score=test_proba, split_name=\"test\")\n", "\n", " performance_dict = {**validation_performance, **test_performance}\n", "\n", " # Add information to the metrics\n", " performance_dict[\"model\"] = f\"model_{m}\"\n", " performance_dict[\"train\"] = f\"train_{t}\"\n", " performance_dict[\"model_type\"] = \"XGB\"\n", " performance_dict[\"model_params\"] = model_config_str\n", " performance_dict[\"representation\"] = dataset_representation\n", "\n", " train_df = pd.DataFrame(performance_dict, index=[0])\n", " results_list.append(train_df)\n", " \n", " return pd.concat(results_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Search" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting MolE representation\n", "At Validation\n", "[[9.9796784e-01 2.0321773e-03]\n", " [9.9822879e-01 1.7712050e-03]\n", " [9.9245560e-01 7.5444025e-03]\n", " ...\n", " [9.9983788e-01 1.6215106e-04]\n", " [9.9981576e-01 1.8425252e-04]\n", " [9.9978560e-01 2.1438736e-04]]\n", "At Testing\n", "Gathering Results\n", "Starting ECFP4 representation\n", "At Validation\n", "[[0.65721035 0.34278968]\n", " [0.85142165 0.14857836]\n", " [0.47510612 0.5248939 ]\n", " ...\n", " [0.86241525 0.13758476]\n", " 
[0.851101 0.148899 ]\n", " [0.86241525 0.13758476]]\n", "At Testing\n", "Gathering Results\n", "Starting ChemDesc representation\n", "At Validation\n", "[[9.99996781e-01 3.24534813e-06]\n", " [9.99079943e-01 9.20061138e-04]\n", " [9.99998808e-01 1.21489290e-06]\n", " ...\n", " [9.99993920e-01 6.05880268e-06]\n", " [9.99893427e-01 1.06556006e-04]\n", " [9.99999881e-01 1.19222065e-07]]\n", "At Testing\n", "Gathering Results\n" ] } ], "source": [ "i = 0 \n", "# Output file name\n", "filename = \"strain_performance.tsv.gz\"\n", "\n", "# Iterate over the representations\n", "for representation in representation_dict.keys():\n", " print(f\"Starting {representation} representation\")\n", "\n", " # Random search\n", " results = eval_models(dataset_representation=representation)\n", "\n", " # Append results\n", " if i == 0:\n", " results.to_csv(os.path.join(OUTPUT_DIR, filename), sep='\\t', index=False)\n", " else:\n", " results.to_csv(os.path.join(OUTPUT_DIR, filename), sep='\\t', index=False, header=False, mode=\"a\")\n", " \n", " i += 1\n" ] } ], "metadata": { "kernelspec": { "display_name": "mole_test", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.20" } }, "nbformat": 4, "nbformat_minor": 2 }