import torch
import torch.nn as nn

from utils import Embeddings, PositionalEncoding, Generator


class Model(nn.Module):
    """Encoder-decoder Transformer over pre-extracted source features.

    The source side is a sequence of feature vectors (last dim
    ``src_input_dim``, default 1024) that is projected down to ``d_model``
    by a small two-layer MLP before entering the Transformer encoder.
    The target side is a token sequence embedded and position-encoded
    in the usual way.

    Args:
        heads: number of attention heads.
        d_model: Transformer model dimension.
        tgt_vocab_size: size of the target token vocabulary.
        num_encoder_layer: number of encoder layers.
        num_decoder_layer: number of decoder layers.
        dropout: dropout probability (source MLP and positional encoding).
        compound_cab: vocabulary mapping; must contain a 'blank' entry
            used as the target padding token id.
        src_input_dim: dimensionality of the raw source features.
        first_linear_dim: hidden width of the source projection MLP.
    """

    def __init__(self, heads, d_model, tgt_vocab_size, num_encoder_layer,
                 num_decoder_layer, dropout, compound_cab,
                 src_input_dim=1024, first_linear_dim=256):
        super().__init__()
        # Source projection: src_input_dim -> first_linear_dim -> d_model
        # (e.g. 1024 -> 256 -> d_model), with ReLU + dropout in between.
        self.src_linear1 = nn.Linear(src_input_dim, first_linear_dim)
        self.src_linear2 = nn.Linear(first_linear_dim, d_model)
        self.src_dropout = nn.Dropout(dropout)
        self.src_activation = nn.ReLU()
        self.tgt_embedding = Embeddings(d_model, tgt_vocab_size)
        self.tgt_pos = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            batch_first=True,
            nhead=heads,
            num_encoder_layers=num_encoder_layer,
            num_decoder_layers=num_decoder_layer,
            dim_feedforward=4 * d_model,
        )
        self.generator = Generator(d_model, tgt_vocab_size)
        self.compound_cab = compound_cab

    def generate_masks(self, tgt):
        """Build the decoder masks for a target batch.

        Args:
            tgt: target token ids, shape (batch, tgt_len).

        Returns:
            tgt_pad_mask: bool tensor (batch, tgt_len), True at padding
                ('blank') positions.
            tgt_mask: (tgt_len, tgt_len) causal mask preventing attention
                to future positions.
        """
        tgt_pad_mask = (tgt == self.compound_cab['blank']).to(tgt.device)
        tgt_len = tgt.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(tgt.device)
        return tgt_pad_mask, tgt_mask

    def forward(self, src, src_pad_mask, tgt):
        """Run the full encoder-decoder pass.

        Args:
            src: source features, shape (batch, src_len, src_input_dim).
            src_pad_mask: bool tensor (batch, src_len), True at padded
                source positions.
            tgt: target token ids, shape (batch, tgt_len).

        Returns:
            Generator output over the target vocabulary,
            shape (batch, tgt_len, tgt_vocab_size).
        """
        # Project source features: (batch, src_len, src_input_dim)
        # -> (batch, src_len, d_model).
        src_transformed = self.src_linear1(src)
        src_transformed = self.src_activation(src_transformed)
        src_transformed = self.src_dropout(src_transformed)
        src_transformed = self.src_linear2(src_transformed)

        tgt_pad_mask, tgt_mask = self.generate_masks(tgt)
        tgt_emb = self.tgt_pos(self.tgt_embedding(tgt))

        y_pred = self.transformer(
            src=src_transformed,
            tgt=tgt_emb,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=tgt_mask,
        )
        return self.generator(y_pred)


class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    Args:
        tgt_size: target vocabulary size.
        padding_idx: token id of the padding symbol; padded positions
            contribute zero loss and the padding class receives no mass.
        smoothing: total probability mass spread over non-target classes.
    """

    def __init__(self, tgt_size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = tgt_size

    def forward(self, pred, target):
        """Compute the smoothed KL loss.

        Args:
            pred: log-probabilities, shape (batch, seq_len, tgt_size)
                (KLDivLoss expects log-probs as input).
            target: token ids, shape (batch, seq_len).

        Returns:
            Scalar loss (batchmean reduction).
        """
        # BUG FIX: the original initialized true_dist to all zeros and only
        # scattered `confidence` at the target index, so `self.smoothing`
        # was never applied and the distribution summed to 1 - smoothing.
        # Standard label smoothing spreads the smoothing mass uniformly over
        # the (size - 2) classes that are neither the target nor padding.
        true_dist = torch.full_like(pred, self.smoothing / (self.size - 2))
        true_dist.scatter_(2, target.unsqueeze(-1), self.confidence)
        # The padding class never receives probability mass.
        true_dist[:, :, self.padding_idx] = 0.0
        # Padded target positions contribute an all-zero row (zero loss).
        mask = target == self.padding_idx
        true_dist[mask] = 0.0
        return self.criterion(pred, true_dist.detach())


class SimpleLossCompute:
    """Thin wrapper that applies a criterion and moves the loss to a device."""

    def __init__(self, criterion, device):
        self.criterion = criterion
        self.device = device

    def __call__(self, x, y):
        loss = self.criterion(x, y)
        return loss.to(self.device)