import torch
import torch.nn as nn

from utils import Embeddings, PositionalEncoding, Generator


class Model(nn.Module):
    def __init__(self, heads, d_model, tgt_vocab_size, num_encoder_layer,
                 num_decoder_layer, dropout, compound_cab,
                 src_input_dim=1024, first_linear_dim=256):
        super().__init__()
        # Source projection: src_input_dim (1024) -> first_linear_dim (256) -> d_model
        self.src_linear1 = nn.Linear(src_input_dim, first_linear_dim)
        self.src_linear2 = nn.Linear(first_linear_dim, d_model)
        self.src_dropout = nn.Dropout(dropout)
        self.src_activation = nn.ReLU()

        # Target-side token embedding and positional encoding (from utils)
        self.tgt_embedding = Embeddings(d_model, tgt_vocab_size)
        self.tgt_pos = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            batch_first=True,
            nhead=heads,
            num_encoder_layers=num_encoder_layer,
            num_decoder_layers=num_decoder_layer,
            dim_feedforward=4 * d_model,
        )
        self.generator = Generator(d_model, tgt_vocab_size)
        self.compound_cab = compound_cab

    def generate_masks(self, tgt):
        # Padding mask: True wherever the target holds the 'blank' (pad) token.
        tgt_pad_mask = (tgt == self.compound_cab['blank']).to(tgt.device)
        # Causal mask: each position may attend only to itself and earlier positions.
        tgt_len = tgt.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(tgt.device)
        return tgt_pad_mask, tgt_mask
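
    # For tgt_len = 3, generate_square_subsequent_mask returns an additive
    # float mask (0. = attend, -inf = blocked):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]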

    def forward(self, src, src_pad_mask, tgt):
        # Transform source sequence dimensions:
        # [batch_size, seq_len, src_input_dim] -> [batch_size, seq_len, d_model]
        src_transformed = self.src_linear1(src)
        src_transformed = self.src_activation(src_transformed)
        src_transformed = self.src_dropout(src_transformed)
        src_transformed = self.src_linear2(src_transformed)

        tgt_pad_mask, tgt_mask = self.generate_masks(tgt)
        tgt_emb = self.tgt_pos(self.tgt_embedding(tgt))
        y_pred = self.transformer(
            src=src_transformed,  # use the transformed src
            tgt=tgt_emb,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=tgt_mask,
        )
        return self.generator(y_pred)
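
# Expected interface of Model.forward (inferred from the code above; the
# semantics of src_pad_mask follow nn.Transformer's src_key_padding_mask):
#   src:          [batch, src_len, src_input_dim] float source features
#   src_pad_mask: [batch, src_len] bool, True at padded source positions
#   tgt:          [batch, tgt_len] long token indices ('blank' marks padding)
#   returns:      [batch, tgt_len, tgt_vocab_size] Generator scores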


class LabelSmoothing(nn.Module):
    """KL-divergence label smoothing; `pred` must contain log-probabilities."""

    def __init__(self, tgt_size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = tgt_size

    def forward(self, pred, target):
        # Spread the smoothing mass over the (size - 2) classes that are
        # neither the gold token nor padding, then place the remaining
        # confidence on the gold token.
        true_dist = pred.new_full(pred.size(), self.smoothing / (self.size - 2))
        true_dist.scatter_(2, target.unsqueeze(-1), self.confidence)
        true_dist[:, :, self.padding_idx] = 0.0
        # Padding positions contribute no loss: zero their entire rows.
        mask = target == self.padding_idx
        true_dist[mask] = 0.0
        return self.criterion(pred, true_dist.detach())
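
# Worked example (illustrative values): with tgt_size=5, padding_idx=0 and
# smoothing=0.1, a gold index of 1 yields the smoothed row
#   [0.0, 0.9, 0.0333, 0.0333, 0.0333]
# i.e. confidence 0.9 on the gold token and 0.1 spread evenly over the
# (size - 2) = 3 classes that are neither gold nor padding.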


class SimpleLossCompute:
    """Applies the criterion and returns the loss on the configured device."""

    def __init__(self, criterion, device):
        self.criterion = criterion
        self.device = device

    def __call__(self, x, y):
        # x: [batch, tgt_len, vocab] log-probabilities; y: [batch, tgt_len] gold indices.
        loss = self.criterion(x, y)
        return loss.to(self.device)
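

if __name__ == "__main__":
    # Minimal smoke test. Everything below is an illustrative sketch: the
    # vocabulary, shapes, and hyperparameters are assumptions, not values from
    # the original project, and Generator is assumed to end in log_softmax
    # (KLDivLoss expects log-probabilities).
    compound_cab = {"blank": 0, "C": 1, "N": 2, "O": 3, "(": 4, ")": 5}
    vocab_size = len(compound_cab)
    device = torch.device("cpu")

    model = Model(heads=4, d_model=256, tgt_vocab_size=vocab_size,
                  num_encoder_layer=2, num_decoder_layer=2, dropout=0.1,
                  compound_cab=compound_cab).to(device)

    batch, src_len, tgt_len = 2, 10, 7
    src = torch.randn(batch, src_len, 1024)                       # precomputed source features
    src_pad_mask = torch.zeros(batch, src_len, dtype=torch.bool)  # no source padding
    tgt = torch.randint(1, vocab_size, (batch, tgt_len))          # random non-pad tokens

    # In real training one would feed tgt[:, :-1] and score against tgt[:, 1:];
    # for a shape check, scoring tgt against itself is enough.
    log_probs = model(src, src_pad_mask, tgt)                     # [batch, tgt_len, vocab_size]

    criterion = LabelSmoothing(tgt_size=vocab_size,
                               padding_idx=compound_cab["blank"], smoothing=0.1)
    loss_fn = SimpleLossCompute(criterion, device)
    print("loss:", loss_fn(log_probs, tgt).item())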