Files
SIME/models/mole_representation.py
2025-10-17 15:54:00 +08:00

129 lines
4.2 KiB
Python

"""
MolE Representation Module
This module provides functions to generate MolE molecular representations.
"""
import os
import yaml
import torch
import pandas as pd
from rdkit import Chem
from rdkit import RDLogger
from .dataset_representation import batch_representation
from .ginet_concat import GINet
RDLogger.DisableLog('rdApp.*')
def _load_smiles_table(data_path, smile_col):
    """Read a delimited text file, preferring tab and falling back to comma.

    A comma-separated file does not make ``pd.read_csv(..., sep="\t")`` fail;
    it parses as a single fused column. The fallback is therefore triggered
    when the SMILES column is missing from the tab parse, not only when the
    tab parse raises.
    """
    wanted = smile_col.lower()

    def has_smiles(frame):
        return any(str(col).lower() == wanted for col in frame.columns)

    try:
        tab_df = pd.read_csv(data_path, sep="\t")
    except ValueError:
        # pandas ParserError / EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses; I/O errors (missing file, ...) propagate.
        tab_df = None
    if tab_df is not None and has_smiles(tab_df):
        return tab_df

    try:
        comma_df = pd.read_csv(data_path)
    except ValueError:
        if tab_df is None:
            raise
        # Keep the tab parse so the caller reports its columns in the error.
        return tab_df
    if tab_df is None or has_smiles(comma_df):
        return comma_df
    return tab_df


def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to a tab- or comma-separated file,
      or a DataFrame containing SMILES data (the DataFrame is copied, never
      modified in place).
    - smile_col (str, optional): Name of the column containing SMILES strings.
      Matched case-insensitively against the available columns.
    - id_col (str, optional): Name of the column containing molecule IDs.
      Matched case-insensitively; if absent, IDs "mol1", "mol2", ... are
      generated.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with exactly the columns
      [smile_col, id_col], restricted to rows whose SMILES is a string that
      RDKit can parse. IDs are converted to str.

    Raises:
    - ValueError: If the SMILES column cannot be found.
    """
    # Read the data
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        smile_df = _load_smiles_table(data_path, smile_col)

    # Case-insensitive column lookup; str() guards against non-string labels.
    columns_lower = {str(col).lower(): col for col in smile_df.columns}
    smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
    id_col_actual = columns_lower.get(id_col.lower(), id_col)

    if smile_col_actual not in smile_df.columns:
        raise ValueError(f"Column '{smile_col}' not found in data. Available columns: {list(smile_df.columns)}")

    # Select columns; .copy() so the assignments below act on an independent
    # frame instead of a slice of the input.
    if id_col_actual in smile_df.columns:
        smile_df = smile_df[[smile_col_actual, id_col_actual]].copy()
        smile_df.columns = [smile_col, id_col]
    else:
        # Create ID column if not exists
        smile_df = smile_df[[smile_col_actual]].copy()
        smile_df.columns = [smile_col]
        smile_df[id_col] = [f"mol{i + 1}" for i in range(len(smile_df))]

    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)

    # Remove rows with missing values
    smile_df = smile_df.dropna()

    # Remove invalid smiles. The mask is built with an explicit bool dtype so
    # that an empty frame is still filtered row-wise (an empty non-boolean
    # Series would be interpreted as a column selection). Non-string entries
    # are treated as invalid rather than handed to RDKit.
    is_valid = pd.Series(
        [
            isinstance(smi, str) and Chem.MolFromSmiles(smi) is not None
            for smi in smile_df[smile_col]
        ],
        index=smile_df.index,
        dtype=bool,
    )
    return smile_df[is_valid]
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Reads ``config.yaml`` and ``model.pth`` from ``pretrained_model_dir``,
    builds a ``GINet`` from the ``"model"`` section of the config, moves it to
    ``device`` and loads the stored weights into it.

    Parameters:
    - pretrained_model_dir (str): Path to the pre-trained MolE model directory.
      Must contain ``config.yaml`` (with a top-level ``model`` mapping) and
      ``model.pth``.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: Loaded pre-trained model.
    """
    # Read model configuration. The context manager guarantees the file handle
    # is closed even if parsing fails.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as config_file:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs; only point this at trusted model directories.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights; map_location places the tensors on `device`
    # regardless of where the checkpoint was saved.
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(f"Loading model from: {model_pth_path}")
    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)
    return model
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Generate molecular representations for a dataset with a pre-trained model.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame.
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Path to the pre-trained model directory.
    - device (str): Device to use for computation. Can be "cpu", "cuda:0", etc.

    Returns:
    - The value returned by ``batch_representation`` for the cleaned SMILES
      table (documented upstream as a pandas.DataFrame of representations).
    """
    # Step 1: load and clean the SMILES table.
    molecules = read_smiles(dataset_path, smile_column_str, id_column_str)

    # Step 2: restore the pre-trained network on the requested device.
    encoder = load_pretrained_model(pretrained_dir, device)

    # Step 3: compute and hand back the representations.
    return batch_representation(
        molecules,
        encoder,
        smile_column_str,
        id_column_str,
        device=device,
    )