import contextlib
import time

import numpy as np
import torch
from torch.cuda.amp import autocast


def _forward_pass(model, src, tgt, src_mask, simpleLossCompute, use_amp):
    """Run one forward pass and compute the loss for a batch.

    Teacher forcing: the decoder is fed ``tgt[:, :-1]`` and the loss is
    computed against the shifted targets ``tgt[:, 1:]``.

    Args:
        model: seq2seq model called as ``model(src=..., src_pad_mask=..., tgt=...)``.
        src: source batch tensor (already on the target device).
        tgt: target batch tensor (already on the target device).
        src_mask: source padding mask (already on the target device).
        simpleLossCompute: callable ``(out, tgt) -> scalar loss tensor``.
        use_amp: run the forward pass under CUDA autocast when True.

    Returns:
        The scalar loss tensor (still attached to the graph).
    """
    # One code path for both precisions: autocast when requested, a no-op
    # context otherwise. This keeps the four previously-duplicated
    # forward/loss sites in a single place.
    ctx = autocast() if use_amp else contextlib.nullcontext()
    with ctx:
        out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
        return simpleLossCompute(out, tgt[:, 1:])


def train_epoch(model, dataloader, device, simpleLossCompute, scaler, optimizer, use_amp=False):
    """Train the model for one epoch.

    Args:
        model: seq2seq model called as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: iterable yielding ``(src, tgt, src_mask, _, _)`` batch tuples.
        device: device the batch tensors are moved to.
        simpleLossCompute: callable ``(out, tgt) -> scalar loss tensor``.
        scaler: ``torch.cuda.amp.GradScaler``; only used when ``use_amp`` is True.
        optimizer: optimizer stepping the model's parameters.
        use_amp: whether to use mixed precision; defaults to False to avoid
            CUDA compatibility issues.

    Returns:
        Tuple of (mean loss over the epoch, elapsed wall-clock seconds).
    """
    model.train()
    losses = []
    start = time.time()
    for src, tgt, src_mask, _, _ in dataloader:
        src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)
        loss = _forward_pass(model, src, tgt, src_mask, simpleLossCompute, use_amp)
        optimizer.zero_grad()
        if use_amp:
            # Scale the loss before backward so small fp16 gradients do not
            # underflow; scaler.step unscales before applying the update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        losses.append(loss.item())
    return np.mean(losses), time.time() - start


def valid_epoch(model, dataloader, device, simpleLossCompute, use_amp=False):
    """Validate the model for one epoch (no gradient updates).

    Args:
        model: seq2seq model called as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: iterable yielding ``(src, tgt, src_mask, _, _)`` batch tuples.
        device: device the batch tensors are moved to.
        simpleLossCompute: callable ``(out, tgt) -> scalar loss tensor``.
        use_amp: whether to use mixed precision; defaults to False to avoid
            CUDA compatibility issues.

    Returns:
        Tuple of (mean loss over the epoch, elapsed wall-clock seconds).
    """
    model.eval()
    losses = []
    start = time.time()
    with torch.no_grad():
        for src, tgt, src_mask, _, _ in dataloader:
            src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)
            loss = _forward_pass(model, src, tgt, src_mask, simpleLossCompute, use_amp)
            losses.append(loss.item())
    return np.mean(losses), time.time() - start