Files
SIME/models/dataset_representation.py
2025-10-17 15:54:00 +08:00

180 lines
5.5 KiB
Python

import os
import yaml
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset, Batch
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
BT.SINGLE,
BT.DOUBLE,
BT.TRIPLE,
BT.AROMATIC
]
BONDDIR_LIST = [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT
]
class MoleculeDataset(Dataset):
"""
Dataset class for creating molecular graphs.
Attributes:
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
- smile_column (str): Name of the column containing SMILES strings.
- id_column (str): Name of the column containing molecule IDs.
"""
def __init__(self, smile_df, smile_column, id_column):
super(Dataset, self).__init__()
# Gather the SMILES and the corresponding IDs
self.smiles_data = smile_df[smile_column].tolist()
self.id_data = smile_df[id_column].tolist()
def __getitem__(self, index):
# Get the molecule
mol = Chem.MolFromSmiles(self.smiles_data[index])
mol = Chem.AddHs(mol)
#########################
# Get the molecule info #
#########################
type_idx = []
chirality_idx = []
atomic_number = []
# Roberto: Might want to add more features later on. Such as atomic spin
for atom in mol.GetAtoms():
if atom.GetAtomicNum() == 0:
print(self.id_data[index])
type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
atomic_number.append(atom.GetAtomicNum())
x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
x = torch.cat([x1, x2], dim=-1)
row, col, edge_feat = [], [], []
for bond in mol.GetBonds():
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
row += [start, end]
col += [end, start]
edge_feat.append([
BOND_LIST.index(bond.GetBondType()),
BONDDIR_LIST.index(bond.GetBondDir())
])
edge_feat.append([
BOND_LIST.index(bond.GetBondType()),
BONDDIR_LIST.index(bond.GetBondDir())
])
edge_index = torch.tensor([row, col], dtype=torch.long)
edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
chem_id=self.id_data[index])
return data
def __len__(self):
return len(self.smiles_data)
def get(self, index):
return self.__getitem__(index)
def len(self):
return self.__len__()
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
"""
Generate molecular representations using a Deep Learning model.
Parameters:
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
- dl_model: Deep Learning model for molecular representation.
- column_str (str): Name of the column containing SMILES strings.
- id_str (str): Name of the column containing molecule IDs.
- batch_size (int, optional): Batch size for processing (default is 10,000).
- id_is_str (bool, optional): Whether IDs are strings (default is True).
- device (str, optional): Device for computation (default is "cuda:0").
Returns:
- chem_representation (pandas.DataFrame): DataFrame containing molecular representations.
"""
# First we create a list of graphs
molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
graph_list = [g for g in molecular_graph_dataset]
# Determine number of loops to do given the batch size
n_batches = len(graph_list) // batch_size
# Are all molecules accounted for?
remaining_molecules = len(graph_list) % batch_size
# Starting indices
start, end = 0, batch_size
# Determine number of iterations
if remaining_molecules == 0:
n_iter = n_batches
elif remaining_molecules > 0:
n_iter = n_batches + 1
# A list to store the batch dataframes
batch_dataframes = []
# Iterate over the batches
for i in range(n_iter):
# Start batch object
batch_obj = Batch()
graph_batch = batch_obj.from_data_list(graph_list[start:end])
graph_batch = graph_batch.to(device)
# Gather the representation
with torch.no_grad():
dl_model.eval()
h_representation, _ = dl_model(graph_batch)
chem_ids = graph_batch.chem_id
batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids)
batch_dataframes.append(batch_df)
# Get the next batch
## In the final iteration we want to get all the remaining molecules
if i == n_iter - 2:
start = end
end = len(graph_list)
else:
start = end
end = end + batch_size
# Concatenate the dataframes
chem_representation = pd.concat(batch_dataframes)
return chem_representation