180 lines
5.5 KiB
Python
180 lines
5.5 KiB
Python
import os
|
|
import yaml
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
import torch
|
|
from torch_geometric.data import Data, Dataset, Batch
|
|
|
|
from rdkit import Chem
|
|
from rdkit.Chem.rdchem import BondType as BT
|
|
from rdkit import RDLogger
|
|
|
|
# Turn off RDKit log output for every channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')


# Lookup tables used to encode atoms and bonds as integer features: a value is
# encoded as its position (``list.index``) in the corresponding list, so the
# order of the entries below must not change.

# Atomic numbers 1..118.
ATOM_LIST = [*range(1, 119)]

# Atom chirality tags.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

# Bond orders.
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]

# Bond stereo directions.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]
|
|
|
|
|
|
class MoleculeDataset(Dataset):
    """
    Dataset class for creating molecular graphs.

    Each item is built on demand from a SMILES string: the molecule is parsed
    with RDKit, explicit hydrogens are added, and atoms and bonds are encoded
    as integer positions in ATOM_LIST, CHIRALITY_LIST, BOND_LIST and
    BONDDIR_LIST.

    Attributes:
    - smiles_data (list): SMILES strings taken from ``smile_df[smile_column]``.
    - id_data (list): Molecule IDs taken from ``smile_df[id_column]``.
    """

    def __init__(self, smile_df, smile_column, id_column):
        """
        Parameters:
        - smile_df (pandas.DataFrame): DataFrame containing SMILES data.
        - smile_column (str): Name of the column containing SMILES strings.
        - id_column (str): Name of the column containing molecule IDs.
        """
        # NOTE(review): ``super(Dataset, self)`` starts the lookup *after*
        # ``Dataset``, so ``Dataset.__init__`` itself is not run. Presumably
        # intentional; confirm before switching to ``super().__init__()``.
        super(Dataset, self).__init__()

        # Gather the SMILES and the corresponding IDs
        self.smiles_data = smile_df[smile_column].tolist()
        self.id_data = smile_df[id_column].tolist()

    def __getitem__(self, index):
        """
        Build the graph for the molecule at position ``index``.

        Parameters:
        - index (int): Position of the molecule in the dataset.

        Returns:
        - data (torch_geometric.data.Data): Graph with
          ``x`` (n_atoms, 2) = [atom type index, chirality index],
          ``edge_index`` (2, 2 * n_bonds), ``edge_attr`` (2 * n_bonds, 2) =
          [bond type index, bond direction index], and ``chem_id``.

        Raises:
        - IndexError: If ``index`` is out of range.
        - ValueError: If the SMILES cannot be parsed, or an atom/bond value is
          not present in the lookup lists.
        """
        smiles = self.smiles_data[index]

        # MolFromSmiles returns None on a parse failure; fail here with the
        # molecule ID instead of with an obscure error inside AddHs.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(
                f"Could not parse SMILES {smiles!r} "
                f"(id: {self.id_data[index]!r})"
            )
        mol = Chem.AddHs(mol)

        #########################
        # Get the molecule info #
        #########################
        type_idx = []
        chirality_idx = []

        # Roberto: Might want to add more features later on. Such as atomic spin
        for atom in mol.GetAtoms():
            atomic_num = atom.GetAtomicNum()
            if atomic_num == 0:
                # Atomic number 0 is not in ATOM_LIST, so the lookup below
                # raises; print the ID first so the molecule can be found.
                print(self.id_data[index])

            type_idx.append(ATOM_LIST.index(atomic_num))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Every bond is stored in both directions, with the same features.
            row += [start, end]
            col += [end, start]
            bond_feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir()),
            ]
            edge_feat += [bond_feat, bond_feat]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # reshape keeps the two feature columns even when the molecule has no
        # bonds, giving shape (0, 2) rather than (0,).
        edge_attr = torch.tensor(edge_feat, dtype=torch.long).reshape(-1, 2)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    chem_id=self.id_data[index])

        return data

    def __len__(self):
        """Return the number of molecules in the dataset."""
        return len(self.smiles_data)

    def get(self, index):
        """Alias of ``__getitem__``."""
        return self.__getitem__(index)

    def len(self):
        """Alias of ``__len__``."""
        return self.__len__()
|
|
|
|
|
|
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
    """
    Generate molecular representations using a Deep Learning model.

    All molecules are first converted to graphs, then passed through the
    model in consecutive batches of at most ``batch_size`` graphs with
    gradients disabled and the model in eval mode.

    Parameters:
    - smile_df (pandas.DataFrame): DataFrame containing SMILES data.
    - dl_model: Deep Learning model for molecular representation. It is called
      with a batch of graphs and must return a 2-tuple whose first element is
      the representation tensor.
    - column_str (str): Name of the column containing SMILES strings.
    - id_str (str): Name of the column containing molecule IDs.
    - batch_size (int, optional): Batch size for processing (default is 10,000).
    - id_is_str (bool, optional): Whether IDs are strings (default is True).
      Not consulted by this function; kept so existing calls keep working.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - chem_representation (pandas.DataFrame): DataFrame containing molecular
      representations, one row per molecule, indexed by molecule ID. Empty if
      ``smile_df`` has no rows.

    Raises:
    - ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # First we create a list of graphs
    molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
    graph_list = [
        molecular_graph_dataset[i] for i in range(len(molecular_graph_dataset))
    ]

    # Nothing to embed: pd.concat would raise on an empty list of frames.
    if not graph_list:
        return pd.DataFrame()

    # Inference only; switching to eval mode once is enough.
    dl_model.eval()

    # A list to store the batch dataframes
    batch_dataframes = []

    with torch.no_grad():
        # Step through the graphs in chunks; slicing past the end of the list
        # makes the last chunk pick up the remaining molecules automatically.
        for start in range(0, len(graph_list), batch_size):
            graph_batch = Batch.from_data_list(graph_list[start:start + batch_size])
            graph_batch = graph_batch.to(device)

            # Gather the representation
            h_representation, _ = dl_model(graph_batch)

            chem_ids = graph_batch.chem_id
            # Numeric IDs may be collated into a tensor (which then lives on
            # `device`); turn it into a plain list so it can be used as an
            # index. String IDs already arrive as a list.
            if torch.is_tensor(chem_ids):
                chem_ids = chem_ids.cpu().tolist()

            batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids)
            batch_dataframes.append(batch_df)

    # Concatenate the dataframes
    chem_representation = pd.concat(batch_dataframes)

    return chem_representation
|
|
|