add mole predcit module
This commit is contained in:
179
models/dataset_representation.py
Normal file
179
models/dataset_representation.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import yaml
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data, Dataset, Batch
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from rdkit import RDLogger
|
||||
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
ATOM_LIST = list(range(1,119))
|
||||
CHIRALITY_LIST = [
|
||||
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
|
||||
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
|
||||
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
|
||||
Chem.rdchem.ChiralType.CHI_OTHER
|
||||
]
|
||||
BOND_LIST = [
|
||||
BT.SINGLE,
|
||||
BT.DOUBLE,
|
||||
BT.TRIPLE,
|
||||
BT.AROMATIC
|
||||
]
|
||||
BONDDIR_LIST = [
|
||||
Chem.rdchem.BondDir.NONE,
|
||||
Chem.rdchem.BondDir.ENDUPRIGHT,
|
||||
Chem.rdchem.BondDir.ENDDOWNRIGHT
|
||||
]
|
||||
|
||||
|
||||
class MoleculeDataset(Dataset):
|
||||
"""
|
||||
Dataset class for creating molecular graphs.
|
||||
|
||||
Attributes:
|
||||
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
|
||||
- smile_column (str): Name of the column containing SMILES strings.
|
||||
- id_column (str): Name of the column containing molecule IDs.
|
||||
"""
|
||||
|
||||
def __init__(self, smile_df, smile_column, id_column):
|
||||
super(Dataset, self).__init__()
|
||||
|
||||
# Gather the SMILES and the corresponding IDs
|
||||
self.smiles_data = smile_df[smile_column].tolist()
|
||||
self.id_data = smile_df[id_column].tolist()
|
||||
|
||||
def __getitem__(self, index):
|
||||
# Get the molecule
|
||||
mol = Chem.MolFromSmiles(self.smiles_data[index])
|
||||
mol = Chem.AddHs(mol)
|
||||
|
||||
#########################
|
||||
# Get the molecule info #
|
||||
#########################
|
||||
type_idx = []
|
||||
chirality_idx = []
|
||||
atomic_number = []
|
||||
|
||||
# Roberto: Might want to add more features later on. Such as atomic spin
|
||||
for atom in mol.GetAtoms():
|
||||
if atom.GetAtomicNum() == 0:
|
||||
print(self.id_data[index])
|
||||
|
||||
type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
|
||||
chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
|
||||
atomic_number.append(atom.GetAtomicNum())
|
||||
|
||||
x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
|
||||
x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
|
||||
x = torch.cat([x1, x2], dim=-1)
|
||||
|
||||
row, col, edge_feat = [], [], []
|
||||
for bond in mol.GetBonds():
|
||||
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
||||
row += [start, end]
|
||||
col += [end, start]
|
||||
edge_feat.append([
|
||||
BOND_LIST.index(bond.GetBondType()),
|
||||
BONDDIR_LIST.index(bond.GetBondDir())
|
||||
])
|
||||
edge_feat.append([
|
||||
BOND_LIST.index(bond.GetBondType()),
|
||||
BONDDIR_LIST.index(bond.GetBondDir())
|
||||
])
|
||||
|
||||
edge_index = torch.tensor([row, col], dtype=torch.long)
|
||||
edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)
|
||||
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
|
||||
chem_id=self.id_data[index])
|
||||
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.smiles_data)
|
||||
|
||||
def get(self, index):
|
||||
return self.__getitem__(index)
|
||||
|
||||
def len(self):
|
||||
return self.__len__()
|
||||
|
||||
|
||||
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
|
||||
"""
|
||||
Generate molecular representations using a Deep Learning model.
|
||||
|
||||
Parameters:
|
||||
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
|
||||
- dl_model: Deep Learning model for molecular representation.
|
||||
- column_str (str): Name of the column containing SMILES strings.
|
||||
- id_str (str): Name of the column containing molecule IDs.
|
||||
- batch_size (int, optional): Batch size for processing (default is 10,000).
|
||||
- id_is_str (bool, optional): Whether IDs are strings (default is True).
|
||||
- device (str, optional): Device for computation (default is "cuda:0").
|
||||
|
||||
Returns:
|
||||
- chem_representation (pandas.DataFrame): DataFrame containing molecular representations.
|
||||
"""
|
||||
|
||||
# First we create a list of graphs
|
||||
molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
|
||||
graph_list = [g for g in molecular_graph_dataset]
|
||||
|
||||
# Determine number of loops to do given the batch size
|
||||
n_batches = len(graph_list) // batch_size
|
||||
|
||||
# Are all molecules accounted for?
|
||||
remaining_molecules = len(graph_list) % batch_size
|
||||
|
||||
# Starting indices
|
||||
start, end = 0, batch_size
|
||||
|
||||
# Determine number of iterations
|
||||
if remaining_molecules == 0:
|
||||
n_iter = n_batches
|
||||
|
||||
elif remaining_molecules > 0:
|
||||
n_iter = n_batches + 1
|
||||
|
||||
# A list to store the batch dataframes
|
||||
batch_dataframes = []
|
||||
|
||||
# Iterate over the batches
|
||||
for i in range(n_iter):
|
||||
# Start batch object
|
||||
batch_obj = Batch()
|
||||
graph_batch = batch_obj.from_data_list(graph_list[start:end])
|
||||
graph_batch = graph_batch.to(device)
|
||||
|
||||
# Gather the representation
|
||||
with torch.no_grad():
|
||||
dl_model.eval()
|
||||
h_representation, _ = dl_model(graph_batch)
|
||||
chem_ids = graph_batch.chem_id
|
||||
|
||||
batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids)
|
||||
batch_dataframes.append(batch_df)
|
||||
|
||||
# Get the next batch
|
||||
## In the final iteration we want to get all the remaining molecules
|
||||
if i == n_iter - 2:
|
||||
start = end
|
||||
end = len(graph_list)
|
||||
else:
|
||||
start = end
|
||||
end = end + batch_size
|
||||
|
||||
# Concatenate the dataframes
|
||||
chem_representation = pd.concat(batch_dataframes)
|
||||
|
||||
return chem_representation
|
||||
|
||||
Reference in New Issue
Block a user