import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Embeddings(nn.Module):
    """Token embedding lookup, scaled by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # Scale embeddings so their magnitude is comparable to the
        # positional encodings added afterwards.
        return self.lut(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the embeddings."""

    def __init__(self, d_model, dropout, max_len=4000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the encodings once for every position up to max_len.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # Register as a buffer: moves with the module (device/dtype) but is
        # not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions; detach() keeps
        # gradients from flowing into the fixed table.
        x = x + self.pe[:, :x.size(1)].detach()
        return self.dropout(x)


class Generator(nn.Module):
    """Final linear projection followed by log-softmax over the vocabulary."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)
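

# A minimal smoke test sketching how the three modules compose. The
# hyperparameters (d_model=512, vocab=1000, dropout=0.1) and the random
# token batch are illustrative assumptions, not values from the source.
if __name__ == "__main__":
    d_model, vocab = 512, 1000
    embed = Embeddings(d_model, vocab)
    pos_enc = PositionalEncoding(d_model, dropout=0.1)
    generator = Generator(d_model, vocab)

    tokens = torch.randint(0, vocab, (2, 10))  # (batch, seq_len) of token ids
    x = pos_enc(embed(tokens))                 # (2, 10, d_model)
    log_probs = generator(x)                   # (2, 10, vocab)

    assert log_probs.shape == (2, 10, vocab)
    # Exponentiated log-probabilities should sum to ~1 along the vocab axis.
    print(log_probs.exp().sum(-1))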