first commit
This commit is contained in:
62
models/trainer.py
Normal file
62
models/trainer.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
def train_epoch(model, dataloader, device, simpleLossCompute, scaler, optimizer, use_amp=False):
    """
    Run one training epoch over ``dataloader``.

    Args:
        model: Seq2seq model, invoked as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: Iterable yielding 5-tuples ``(src, tgt, src_mask, _, _)``
            (the last two elements are ignored here).
        device: Device each batch tensor is moved to (e.g. ``"cuda"`` / ``"cpu"``).
        simpleLossCompute: Callable ``(out, target) -> loss`` returning a scalar tensor.
        scaler: ``torch.cuda.amp.GradScaler``; only used when ``use_amp`` is True.
        optimizer: Optimizer stepped once per batch.
        use_amp: Whether to use mixed precision, default to False to avoid
            CUDA compatibility issues.

    Returns:
        Tuple ``(mean_loss, elapsed_seconds)`` for the epoch.
    """
    model.train()
    losses = []
    start = time.time()
    for src, tgt, src_mask, _, _ in dataloader:
        src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)

        # Clear stale gradients before the new forward/backward pass.
        optimizer.zero_grad()
        if use_amp:
            # Forward under autocast; backward is scaled by GradScaler to
            # avoid fp16 gradient underflow.
            with autocast():
                loss = _forward_loss(model, src, tgt, src_mask, simpleLossCompute)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Plain fp32 path: no autocast, no gradient scaling.
            loss = _forward_loss(model, src, tgt, src_mask, simpleLossCompute)
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
    return np.mean(losses), time.time() - start


def _forward_loss(model, src, tgt, src_mask, simpleLossCompute):
    """Teacher-forced forward pass: feed ``tgt[:, :-1]``, score against ``tgt[:, 1:]``."""
    out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
    return simpleLossCompute(out, tgt[:, 1:])
|
||||
def valid_epoch(model, dataloader, device, simpleLossCompute, use_amp=False):
    """
    Evaluate the model for one epoch (no gradient updates).

    Args:
        model: Seq2seq model, invoked as ``model(src=..., src_pad_mask=..., tgt=...)``.
        dataloader: Iterable yielding 5-tuples ``(src, tgt, src_mask, _, _)``
            (the last two elements are ignored here).
        device: Device each batch tensor is moved to (e.g. ``"cuda"`` / ``"cpu"``).
        simpleLossCompute: Callable ``(out, target) -> loss`` returning a scalar tensor.
        use_amp: Whether to use mixed precision, default to False to avoid
            CUDA compatibility issues.

    Returns:
        Tuple ``(mean_loss, elapsed_seconds)`` for the epoch.
    """
    model.eval()
    losses = []
    start = time.time()
    # No gradients are needed during validation.
    with torch.no_grad():
        for src, tgt, src_mask, _, _ in dataloader:
            src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)

            if use_amp:
                # Forward under autocast for mixed-precision inference.
                with autocast():
                    out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                    loss = simpleLossCompute(out, tgt[:, 1:])
            else:
                # Plain fp32 forward pass (teacher forcing: shift tgt by one).
                out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                loss = simpleLossCompute(out, tgt[:, 1:])

            losses.append(loss.item())
    return np.mean(losses), time.time() - start
|
||||
Reference in New Issue
Block a user