import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

num_atom_type = 119        # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 5          # including aromatic and self-loop edges
num_bond_direction = 3


class GINEConv(MessagePassing):
    """GIN convolution extended with edge (bond type / bond direction) embeddings."""

    def __init__(self, emb_dim):
        super(GINEConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.BatchNorm1d(2 * emb_dim),
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU()
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]

        # add features corresponding to the self-loop edges
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edges
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


class GINet(nn.Module):
    """
    GIN encoder from MolE.

    Args:
        num_layer (int): Number of GNN layers.
        emb_dim (int): Dimensionality of the embeddings at each graph layer.
        feat_dim (int): Dimensionality of the output embedding vector.
        drop_ratio (float): Dropout rate.
        pool (str): Pooling method for neighbor aggregation ('mean', 'max', or 'add').

    Output:
        h_global_embedding: graph-level representation (concatenation of the pooled per-layer embeddings)
        out: final embedding vector produced by the projection head
    """

    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
        super(GINet, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        self.concat_dim = num_layer * emb_dim

        if self.concat_dim != self.feat_dim:
            print(f"Representation dimension ({self.concat_dim}) differs from embedding dimension ({self.feat_dim})")

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # list of GNN layers
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # list of batch norms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        else:
            raise ValueError(f"Unknown pooling method: {pool}")

        self.feat_lin = nn.Linear(self.concat_dim, self.feat_dim)

        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim),  # hidden size is kept at feat_dim, not halved
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim)
        )

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        # initial node embeddings from atom type and chirality tag
        h_init = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        # perform the convolutions, keeping every layer's node embeddings
        h_dict = {}
        h = h_init
        for layer in range(self.num_layer):
            h = self.gnns[layer](h, edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # no ReLU on the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_dict[f"h_{layer}"] = h

        # graph representation: pool each layer's node embeddings and concatenate
        h_list_pooled = [self.pool(h_dict[f"h_{layer}"], data.batch) for layer in range(self.num_layer)]
        h_global_embedding = torch.cat(h_list_pooled, dim=1)
        assert h_global_embedding.shape[1] == self.concat_dim

        # projection head
        h_expansion = self.feat_lin(h_global_embedding)
        out = self.out_lin(h_expansion)

        return h_global_embedding, out

    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            print(name)  # log which pretrained parameters are loaded
            own_state[name].copy_(param)
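

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). It shows how GINet consumes a torch_geometric Data object whose
# node features are [atom_type, chirality_tag] indices and whose edge
# features are [bond_type, bond_direction] indices, matching the embedding
# tables defined above. The two-atom graph and its index values are
# synthetic and chosen only to exercise the forward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch_geometric.data import Data

    # toy graph: two nodes connected by one undirected bond (two directed edges)
    x = torch.tensor([[6, 0], [8, 0]], dtype=torch.long)           # [atom_type, chirality_tag]
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)  # directed edge list
    edge_attr = torch.tensor([[0, 0], [0, 0]], dtype=torch.long)   # [bond_type, bond_direction]
    batch = torch.zeros(x.size(0), dtype=torch.long)               # single graph in the batch

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)

    model = GINet(num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean')
    model.eval()  # eval mode so BatchNorm works with a single pooled graph
    with torch.no_grad():
        h_global, out = model(data)

    print(h_global.shape)  # expected: [1, num_layer * emb_dim] = [1, 1500]
    print(out.shape)       # expected: [1, feat_dim] = [1, 256]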