Files
mole_broad_spectrum_parallel/workflow/02.model_training.ipynb
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

450 lines
18 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training \n",
"### Author: Roberto Olayo-Alarcon\n",
" \n",
"Here, we perform a random search over XGBoost parameters to train models to predict antimicrobial activity using different chemical representations"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
" \n",
"# Classifier\n",
"from xgboost import XGBClassifier\n",
"\n",
"# \n",
"from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc\n",
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"from sklearn.preprocessing import OneHotEncoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare directories"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"INPUT_DIR = \"../data/01.prepare_training_data/\"\n",
"\n",
"OUTPUT_DIR = \"../data/02.model_training\"\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"representation_dict = {\"MolE\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_mole_representation.tsv.gz\"), index_col=0, sep='\\t'),\n",
"\"ecfp4\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_ecfp4_representation.tsv.gz\"), index_col=0, sep='\\t'),\n",
"\"chemDesc\": pd.read_csv(os.path.join(INPUT_DIR, \"maier_chemdesc_representation.tsv.gz\"), index_col=0, sep='\\t')}\n",
"\n",
"split_df = pd.read_csv(os.path.join(INPUT_DIR, \"maier_scaffold_split.tsv.gz\"), index_col=\"prestwick_ID\", sep='\\t')\n",
"screen_df = pd.read_csv(os.path.join(INPUT_DIR, \"maier_screening_results.tsv.gz\"), index_col=\"prestwick_ID\", sep='\\t')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model parameters for random search"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"XGB_PARAMS = {\"nthread\": 20,\n",
" \"n_estimators\":[30, 100, 300, 500, 1000],\n",
" \"max_depth\": [5, 10, 50, 100],\n",
" \"eta\":[0.3, 0.1, 0.05, 1],\n",
" \"subsample\": [0.3, 0.5, 0.8, 1.0],\n",
" \"objective\": \"binary:logistic\"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Select parameters randomly\n",
"def select_params(original_config):\n",
" \"\"\"\n",
" Randomly select parameters from the provided configuration for modeling.\n",
"\n",
" This function traverses through the provided configuration dictionary and randomly selects\n",
" values from lists or dictionaries.\n",
"\n",
" Parameters:\n",
" - original_config (dict): The original configuration dictionary containing parameters.\n",
"\n",
" Returns:\n",
" - model_config (dict): Configuration with randomly selected parameters.\n",
" \"\"\"\n",
"\n",
" model_config = original_config.copy()\n",
"\n",
" for key, value in model_config.items():\n",
"\n",
" if type(value) == list:\n",
" model_config[key] = random.choice(value)\n",
" elif type(value) == dict:\n",
" model_config[key] = select_params(value)\n",
" \n",
" return model_config\n",
"\n",
"# Prepare strain one-hot-encoding\n",
"def prep_ohe(categories):\n",
"\n",
"    \"\"\"\n",
"    Prepare one-hot encoding for strain variables.\n",
"\n",
"    Fits a OneHotEncoder on the provided labels and returns the encoded\n",
"    matrix as a DataFrame with one row per input label (indexed by the\n",
"    label itself) and one 0/1 column per unique category.\n",
"\n",
"    Parameters:\n",
"    - categories (array-like): Array-like object containing categorical variables.\n",
"\n",
"    Returns:\n",
"    - cat_ohe (pandas.DataFrame): DataFrame representing the one-hot encoded categorical variables.\n",
"    \"\"\"\n",
"\n",
"    # scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed the\n",
"    # old keyword in 1.4; try the new name first so this cell runs on both\n",
"    # old and current versions.\n",
"    try:\n",
"        ohe = OneHotEncoder(sparse_output=False)\n",
"    except TypeError:\n",
"        ohe = OneHotEncoder(sparse=False)\n",
"\n",
"    cat_frame = pd.DataFrame(categories)\n",
"    cat_encoded = ohe.fit_transform(cat_frame)\n",
"\n",
"    # ohe.categories_ is a *list* of arrays (one per input column); index into\n",
"    # it so the columns are plain labels instead of a one-level MultiIndex.\n",
"    cat_ohe = pd.DataFrame(cat_encoded,\n",
"                           index=categories, columns=ohe.categories_[0])\n",
"    \n",
"    return cat_ohe\n",
"\n",
"def get_split(data_df, y_df, splitter_df, split_strat = \"split\"):\n",
"\n",
" \"\"\"\n",
" Prepare data splits for training, validation, and testing.\n",
"\n",
" This function prepares the data splits for training, validation, and testing based on the given split strategy.\n",
" It joins molecular features with taxonomic One-Hot Encoded (OHE) labels and separates the data into respective splits.\n",
"\n",
" Parameters:\n",
" - data_df (pandas.DataFrame): DataFrame containing molecular features.\n",
" - y_df (pandas.DataFrame): DataFrame containing labels.\n",
" - splitter_df (pandas.DataFrame): DataFrame containing chemical IDs and split information.\n",
" - split_strat (str, optional): Split strategy to use (\"split\" by default).\n",
"\n",
" Returns:\n",
" - X_train (pandas.DataFrame): DataFrame containing features for training.\n",
" - X_valid (pandas.DataFrame): DataFrame containing features for validation.\n",
" - X_test (pandas.DataFrame): DataFrame containing features for testing.\n",
" - y_train (numpy.ndarray): Array containing labels for training.\n",
" - y_valid (numpy.ndarray): Array containing labels for validation.\n",
" - y_test (numpy.ndarray): Array containing labels for testing.\n",
" \"\"\"\n",
"\n",
" # Get the chemicals in each split of data\n",
" train_chems = splitter_df.loc[splitter_df[split_strat] == \"train\"].index\n",
" validation_chems = splitter_df.loc[splitter_df[split_strat] == \"valid\"].index\n",
" test_chems = splitter_df.loc[splitter_df[split_strat] == \"test\"].index\n",
"\n",
" # Prepare taxonomic OHE\n",
" taxa_ohe = prep_ohe(y_df.columns) \n",
"\n",
" # Pivot longer screen results\n",
" screen_melt = y_df.unstack().reset_index().rename(columns={0: \"label\",\n",
" \"level_0\": \"taxa_name\"})\n",
" \n",
" # Join molecular features and then join taxa OHE\n",
" data_df.columns = [str(c) for c in data_df.columns]\n",
" data_df = data_df.fillna(0)\n",
"\n",
" screen_feat = screen_melt.join(data_df, on=\"prestwick_ID\")\n",
" screen_feat = screen_feat.join(taxa_ohe, on=\"taxa_name\")\n",
"\n",
" assert screen_feat.shape[0] == screen_melt.shape[0]\n",
"\n",
"\n",
" # Gather train\n",
" X_train = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(train_chems)].drop(columns=[\"prestwick_ID\", \n",
" \"label\", \n",
" \"taxa_name\"])\n",
" y_train = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(train_chems), [\"label\"]].values\n",
"\n",
" # Gather valid\n",
" X_valid = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(validation_chems)].drop(columns=[\"prestwick_ID\", \n",
" \"label\", \n",
" \"taxa_name\"])\n",
" y_valid = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(validation_chems), [\"label\"]].values\n",
"\n",
" # Gather test\n",
" X_test = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(test_chems)].drop(columns=[\"prestwick_ID\", \n",
" \"label\", \n",
" \"taxa_name\"])\n",
" y_test = screen_feat.loc[screen_feat[\"prestwick_ID\"].isin(test_chems), [\"label\"]].values\n",
"\n",
" \n",
" return X_train, X_valid, X_test, y_train, y_valid, y_test\n",
" \n",
"def get_performance_metrics(y_true, y_pred, y_score, split_name):\n",
"\n",
"    \"\"\"\n",
"    Summarise classification quality on one data split.\n",
"\n",
"    Uses the positive-class scores (second column of y_score) to compute the\n",
"    areas under the ROC and precision-recall curves, and the hard\n",
"    predictions to compute the F1 score.\n",
"\n",
"    Parameters:\n",
"    - y_true (array-like): True labels.\n",
"    - y_pred (array-like): Predicted labels.\n",
"    - y_score (array-like): Predicted class probabilities, shape (n, 2).\n",
"    - split_name (str): Name of the data split, used to prefix each key.\n",
"\n",
"    Returns:\n",
"    - dict: Keys '{split_name}_auroc', '{split_name}_prauc' and\n",
"      '{split_name}_f1' mapping to the corresponding scores.\n",
"    \"\"\"\n",
"\n",
"    positive_scores = y_score[:, 1]\n",
"\n",
"    precision, recall, _ = precision_recall_curve(y_true, positive_scores)\n",
"\n",
"    metrics = {}\n",
"    metrics[f\"{split_name}_auroc\"] = roc_auc_score(y_true=y_true, y_score=positive_scores)\n",
"    metrics[f\"{split_name}_prauc\"] = auc(recall, precision)\n",
"    metrics[f\"{split_name}_f1\"] = f1_score(y_true=y_true, y_pred=y_pred)\n",
"\n",
"    return metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Main training function"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"def eval_models(dataset_representation, \n",
" n_train = 1, \n",
" n_models = 1,\n",
" feature_options = representation_dict, \n",
" XGB_params_dict = XGB_PARAMS,\n",
" split_df = split_df,\n",
" screen_df = screen_df):\n",
" \n",
" \"\"\"\n",
" Evaluate multiple XGBoost models with different configurations on a given dataset representation.\n",
"\n",
" This function trains and evaluates multiple XGBoost models with different configurations\n",
" on a specified dataset representation. It computes performance metrics for each model\n",
" on validation and test sets and returns the results as a DataFrame.\n",
"\n",
" Parameters:\n",
" - dataset_representation (str): Name of the dataset representation to use.\n",
" - n_train (int, optional): Number of training iterations for each model (default is 1).\n",
" - n_models (int, optional): Number of models to evaluate (default is 1).\n",
" - feature_options (dict, optional): Dictionary containing different dataset representations (default is representation_dict).\n",
" - XGB_params_dict (dict, optional): Dictionary containing XGBoost parameters for model configuration (default is XGB_PARAMS).\n",
" - split_df (pandas.DataFrame, optional): DataFrame containing split information (default is split_df).\n",
" - screen_df (pandas.DataFrame, optional): DataFrame containing screening data (default is screen_df).\n",
"\n",
" Returns:\n",
" - results_df (pandas.DataFrame): DataFrame containing performance metrics for all evaluated models.\n",
" \"\"\"\n",
"\n",
" # This should be a dictionary containing all possible values for the classifier in question params\n",
" classifier_params_copy = XGB_params_dict.copy()\n",
"\n",
" # Get the corresponding features and screen\n",
" features_df = feature_options[dataset_representation].copy()\n",
"\n",
" # Since the splits are already made, we just have to separate the data\n",
" X_train, X_valid, X_test, y_train, y_valid, y_test = get_split(features_df, screen_df, split_df)\n",
"\n",
"\n",
" # Iterate over models\n",
" results_list=[]\n",
" for m in range(n_models):\n",
" \n",
" # Gather model configuration\n",
" model_config = select_params(classifier_params_copy)\n",
" model_config_str = str(model_config)\n",
"\n",
" # Iterate over training\n",
" for t in range(n_train):\n",
"\n",
" # Create base estimator\n",
" model_config[\"seed\"] = np.random.randint(1_000_000, size=1)[0]\n",
" base_estimator = XGBClassifier(**model_config)\n",
"\n",
" # Train model\n",
" base_estimator.fit(X=X_train, y=y_train)\n",
"\n",
" # Validation\n",
" print(\"At Validation\")\n",
" validation_proba = base_estimator.predict_proba(X=X_valid)\n",
" validation_preds = base_estimator.predict(X=X_valid) \n",
"\n",
" # Testing\n",
" print(\"At Testing\")\n",
" test_proba = base_estimator.predict_proba(X=X_test)\n",
" test_preds = base_estimator.predict(X=X_test)\n",
"\n",
" # Performance Metrics\n",
" print(\"Gathering Results\")\n",
" validation_performance = get_performance_metrics(y_true=y_valid, y_pred=validation_preds, y_score=validation_proba, split_name=\"validation\")\n",
" test_performance = get_performance_metrics(y_true=y_test, y_pred=test_preds, y_score=test_proba, split_name=\"test\")\n",
"\n",
" performance_dict = {**validation_performance, **test_performance}\n",
"\n",
" # Add information to the metrics\n",
" performance_dict[\"model\"] = f\"model_{m}\"\n",
" performance_dict[\"train\"] = f\"train_{t}\"\n",
" performance_dict[\"model_type\"] = \"XGB\"\n",
" performance_dict[\"model_params\"] = model_config_str\n",
" performance_dict[\"representation\"] = dataset_representation\n",
"\n",
" train_df = pd.DataFrame(performance_dict, index=[0])\n",
" results_list.append(train_df)\n",
" \n",
" return pd.concat(results_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Random Search"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting MolE representation\n",
"At Validation\n",
"[[9.9796784e-01 2.0321773e-03]\n",
" [9.9822879e-01 1.7712050e-03]\n",
" [9.9245560e-01 7.5444025e-03]\n",
" ...\n",
" [9.9983788e-01 1.6215106e-04]\n",
" [9.9981576e-01 1.8425252e-04]\n",
" [9.9978560e-01 2.1438736e-04]]\n",
"At Testing\n",
"Gathering Results\n",
"Starting ECFP4 representation\n",
"At Validation\n",
"[[0.65721035 0.34278968]\n",
" [0.85142165 0.14857836]\n",
" [0.47510612 0.5248939 ]\n",
" ...\n",
" [0.86241525 0.13758476]\n",
" [0.851101 0.148899 ]\n",
" [0.86241525 0.13758476]]\n",
"At Testing\n",
"Gathering Results\n",
"Starting ChemDesc representation\n",
"At Validation\n",
"[[9.99996781e-01 3.24534813e-06]\n",
" [9.99079943e-01 9.20061138e-04]\n",
" [9.99998808e-01 1.21489290e-06]\n",
" ...\n",
" [9.99993920e-01 6.05880268e-06]\n",
" [9.99893427e-01 1.06556006e-04]\n",
" [9.99999881e-01 1.19222065e-07]]\n",
"At Testing\n",
"Gathering Results\n"
]
}
],
"source": [
"# Output file name\n",
"filename = \"strain_performance.tsv.gz\"\n",
"output_path = os.path.join(OUTPUT_DIR, filename)\n",
"\n",
"# Iterate over the representations; the first one creates the output file\n",
"# (with a header row), every later one appends rows without a header.\n",
"for position, representation in enumerate(representation_dict):\n",
"    print(f\"Starting {representation} representation\")\n",
"\n",
"    # Random search\n",
"    results = eval_models(dataset_representation=representation)\n",
"\n",
"    # Append results\n",
"    first_write = position == 0\n",
"    results.to_csv(output_path, sep='\\t', index=False,\n",
"                   header=first_write, mode=\"w\" if first_write else \"a\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mole_test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}