Files
labweb/models/model.py
2025-12-16 11:39:15 +08:00

74 lines
3.1 KiB
Python

import torch
import torch.nn as nn
from utils import Embeddings, PositionalEncoding, Generator
class Model(nn.Module):
    """Encoder-decoder Transformer mapping pre-embedded source feature
    sequences to a target token vocabulary.

    The source is not token ids but dense features of width ``src_input_dim``
    (default 1024); a two-stage linear projection brings them down to
    ``d_model`` before they enter the encoder. Target tokens go through a
    learned embedding plus positional encoding before the decoder.

    Args:
        heads: number of attention heads.
        d_model: transformer model width.
        tgt_vocab_size: size of the target token vocabulary.
        num_encoder_layer / num_decoder_layer: stack depths.
        dropout: dropout probability, applied to the source projection, the
            positional encoding, and the transformer layers.
        compound_cab: target vocabulary mapping; must contain a 'blank'
            entry giving the padding token id.
        src_input_dim: width of the raw source features.
        first_linear_dim: intermediate width of the source projection.
    """

    def __init__(self, heads, d_model, tgt_vocab_size, num_encoder_layer,
                 num_decoder_layer, dropout, compound_cab,
                 src_input_dim=1024, first_linear_dim=256):
        super().__init__()
        # Source feature projection: src_input_dim -> first_linear_dim -> d_model
        self.src_linear1 = nn.Linear(src_input_dim, first_linear_dim)
        self.src_linear2 = nn.Linear(first_linear_dim, d_model)
        self.src_dropout = nn.Dropout(dropout)
        self.src_activation = nn.ReLU()
        self.tgt_embedding = Embeddings(d_model, tgt_vocab_size)
        self.tgt_pos = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            batch_first=True,
            nhead=heads,
            num_encoder_layers=num_encoder_layer,
            num_decoder_layers=num_decoder_layer,
            dim_feedforward=4 * d_model,
            # BUG FIX: dropout was not forwarded, so nn.Transformer silently
            # used its 0.1 default and the constructor argument had no effect
            # on the encoder/decoder layers.
            dropout=dropout,
        )
        self.generator = Generator(d_model, tgt_vocab_size)
        self.compound_cab = compound_cab

    def generate_masks(self, tgt):
        """Build the target-side masks.

        Args:
            tgt: (batch, tgt_len) target token ids.

        Returns:
            tgt_pad_mask: (batch, tgt_len) bool, True where tgt is padding
                ('blank' token).
            tgt_mask: (tgt_len, tgt_len) causal mask blocking attention to
                future positions.
        """
        tgt_pad_mask = (tgt == self.compound_cab['blank']).to(tgt.device)
        tgt_len = tgt.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(tgt.device)
        return tgt_pad_mask, tgt_mask

    def forward(self, src, src_pad_mask, tgt):
        """Run the full encoder-decoder pass.

        Args:
            src: (batch, src_len, src_input_dim) source features.
            src_pad_mask: (batch, src_len) padding mask for the source.
            tgt: (batch, tgt_len) target token ids (teacher forcing).

        Returns:
            Generator output over the target vocabulary for every position.
        """
        # Project source features: [batch, src_len, src_input_dim] -> [batch, src_len, d_model]
        src_transformed = self.src_linear1(src)
        src_transformed = self.src_activation(src_transformed)
        src_transformed = self.src_dropout(src_transformed)
        src_transformed = self.src_linear2(src_transformed)
        tgt_pad_mask, tgt_mask = self.generate_masks(tgt)
        tgt_emb = self.tgt_pos(self.tgt_embedding(tgt))
        y_pred = self.transformer(
            src=src_transformed, tgt=tgt_emb,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=tgt_mask,
        )
        return self.generator(y_pred)
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The smoothed target puts ``1 - smoothing`` probability on the gold token
    and spreads ``smoothing`` uniformly over the remaining ``size - 2``
    vocabulary slots (the gold column and the padding column are excluded).
    Positions whose gold label IS the padding index are zeroed out entirely
    and contribute no loss.

    Args:
        tgt_size: target vocabulary size (last dimension of ``pred``).
        padding_idx: token id used for padding.
        smoothing: total probability mass moved off the gold token.
    """

    def __init__(self, tgt_size, padding_idx, smoothing=0.0):
        super().__init__()
        # KLDivLoss expects `pred` to be LOG-probabilities (e.g. log_softmax
        # output); 'batchmean' divides the summed loss by pred.size(0).
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = tgt_size

    def forward(self, pred, target):
        """pred: (batch, seq, vocab) log-probs; target: (batch, seq) ids."""
        # BUG FIX: the smoothing mass was never distributed — true_dist
        # started as all zeros, so `smoothing` had no effect and the target
        # distribution did not sum to 1. Fill every slot with
        # smoothing / (size - 2) first (gold and padding columns are then
        # overwritten), matching the standard label-smoothing construction.
        true_dist = pred.data.new_full(pred.size(), self.smoothing / (self.size - 2))
        true_dist.scatter_(2, target.unsqueeze(-1), self.confidence)
        # Never assign probability to the padding token itself.
        true_dist[:, :, self.padding_idx] = 0.0
        # Zero out positions whose gold label is padding (no loss there).
        mask = target == self.padding_idx
        true_dist[mask] = 0.0
        return self.criterion(pred, true_dist.detach())
class SimpleLossCompute:
    """Thin callable wrapper: evaluate a criterion and place the resulting
    loss tensor on a configured device.

    Args:
        criterion: any callable taking (prediction, target) and returning a
            tensor-like object with a ``.to(device)`` method.
        device: destination device for the returned loss.
    """

    def __init__(self, criterion, device):
        self.criterion = criterion
        self.device = device

    def __call__(self, x, y):
        # Compute the loss and move it to the target device in one step.
        return self.criterion(x, y).to(self.device)