# Training / validation epoch helpers (extracted file: ~63 lines, 2.2 KiB, Python)
import time
from contextlib import nullcontext

import numpy as np
import torch
from torch.cuda.amp import autocast
|
|
|
|
def train_epoch(model, dataloader, device, simpleLossCompute, scaler, optimizer, use_amp=False):
    """
    Run one training epoch over the dataloader.

    Args:
        model: Model invoked as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: Yields ``(src, tgt, src_mask, _, _)`` batches; the last two
            elements are ignored here.
        device: Device the first three batch tensors are moved to.
        simpleLossCompute: Callable mapping ``(model output, shifted target)``
            to a scalar loss tensor.
        scaler: ``torch.cuda.amp.GradScaler``; only used when ``use_amp`` is True.
        optimizer: Optimizer stepped once per batch.
        use_amp: Whether to use mixed precision, default to False to avoid CUDA
            compatibility issues.

    Returns:
        Tuple of (mean loss over all batches, elapsed wall-clock seconds).
    """
    model.train()
    losses = []
    start = time.time()
    for src, tgt, src_mask, _, _ in dataloader:
        src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)

        # Single forward/loss path: enter autocast only when AMP is requested,
        # instead of duplicating the forward code in both branches.
        amp_ctx = autocast() if use_amp else nullcontext()
        with amp_ctx:
            # Teacher forcing: feed tgt[:, :-1], score against tgt[:, 1:].
            out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
            loss = simpleLossCompute(out, tgt[:, 1:])

        optimizer.zero_grad()
        if use_amp:
            # GradScaler scales the loss to avoid fp16 gradient underflow,
            # then unscales before stepping and updates its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
    return np.mean(losses), time.time() - start
|
|
|
|
def valid_epoch(model, dataloader, device, simpleLossCompute, use_amp=False):
    """
    Run one validation epoch (no gradient computation or parameter updates).

    Args:
        model: Model invoked as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: Yields ``(src, tgt, src_mask, _, _)`` batches; the last two
            elements are ignored here.
        device: Device the first three batch tensors are moved to.
        simpleLossCompute: Callable mapping ``(model output, shifted target)``
            to a scalar loss tensor.
        use_amp: Whether to use mixed precision, default to False to avoid CUDA
            compatibility issues.

    Returns:
        Tuple of (mean loss over all batches, elapsed wall-clock seconds).
    """
    model.eval()
    losses = []
    start = time.time()
    with torch.no_grad():
        for src, tgt, src_mask, _, _ in dataloader:
            src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)

            # Single forward/loss path: enter autocast only when AMP is
            # requested, instead of duplicating the forward code.
            amp_ctx = autocast() if use_amp else nullcontext()
            with amp_ctx:
                # Teacher forcing: feed tgt[:, :-1], score against tgt[:, 1:].
                out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                loss = simpleLossCompute(out, tgt[:, 1:])

            losses.append(loss.item())
    return np.mean(losses), time.time() - start
|