# labweb/models/trainer.py

import time
import numpy as np
import torch
from torch.cuda.amp import autocast


def train_epoch(model, dataloader, device, simpleLossCompute, scaler, optimizer, use_amp=False):
    """
    Train the model for one epoch.

    Args:
        model: Model to train; called as model(src=..., src_pad_mask=..., tgt=...).
        dataloader: Iterable yielding (src, tgt, src_mask, _, _) batches.
        device: Device the batch tensors are moved to.
        simpleLossCompute: Callable computing the loss from the model output
            and the right-shifted targets.
        scaler: torch.cuda.amp.GradScaler; only used when use_amp is True.
        optimizer: Optimizer to step after each batch.
        use_amp: Whether to use mixed precision; defaults to False to avoid
            CUDA compatibility issues.

    Returns:
        Mean loss over the epoch and the elapsed time in seconds.
    """
    model.train()
    losses = []
    start = time.time()
    for src, tgt, src_mask, _, _ in dataloader:
        src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)
        if use_amp:
            # Run the forward pass and loss under autocast.
            # Teacher forcing: feed tgt[:, :-1], predict tgt[:, 1:].
            with autocast():
                out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                loss = simpleLossCompute(out, tgt[:, 1:])
            optimizer.zero_grad()
            # Scale the loss to avoid gradient underflow in float16;
            # scaler.step() unscales before applying the update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Full-precision forward and backward pass.
            out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
            loss = simpleLossCompute(out, tgt[:, 1:])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses.append(loss.item())
    return np.mean(losses), time.time() - start


def valid_epoch(model, dataloader, device, simpleLossCompute, use_amp=False):
    """
    Validate the model for one epoch.

    Args:
        use_amp: Whether to use mixed precision; defaults to False to avoid
            CUDA compatibility issues.

    Returns:
        Mean loss over the validation set and the elapsed time in seconds.
    """
    model.eval()
    losses = []
    start = time.time()
    with torch.no_grad():
        for src, tgt, src_mask, _, _ in dataloader:
            src, tgt, src_mask = src.to(device), tgt.to(device), src_mask.to(device)
            if use_amp:
                # Autocast is safe under no_grad; no scaler is needed
                # because there is no backward pass.
                with autocast():
                    out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                    loss = simpleLossCompute(out, tgt[:, 1:])
            else:
                out = model(src=src, src_pad_mask=src_mask, tgt=tgt[:, :-1])
                loss = simpleLossCompute(out, tgt[:, 1:])
            losses.append(loss.item())
    return np.mean(losses), time.time() - start
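

if __name__ == "__main__":
    # A minimal smoke-test sketch of how train_epoch/valid_epoch might be
    # wired together. Everything below (ToyModel, the synthetic batches, the
    # hyperparameters) is a hypothetical stand-in, not part of labweb.
    import torch.nn as nn
    from torch.cuda.amp import GradScaler

    VOCAB = 32

    class ToyModel(nn.Module):
        """Stand-in with the call signature train_epoch expects."""

        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(VOCAB, 16)
            self.proj = nn.Linear(16, VOCAB)

        def forward(self, src, src_pad_mask, tgt):
            # Ignores src and src_pad_mask; maps decoder input straight to logits.
            return self.proj(self.emb(tgt))

    def loss_compute(out, tgt):
        return nn.functional.cross_entropy(out.reshape(-1, VOCAB), tgt.reshape(-1))

    def batches(n=4, bsz=8, seq=12):
        # Synthetic batches shaped like the (src, tgt, src_mask, _, _) tuples above.
        for _ in range(n):
            src = torch.randint(0, VOCAB, (bsz, seq))
            tgt = torch.randint(0, VOCAB, (bsz, seq))
            mask = torch.zeros(bsz, seq, dtype=torch.bool)
            yield src, tgt, mask, None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"  # autocast as imported above is CUDA-only

    model = ToyModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler(enabled=use_amp)  # a no-op when AMP is disabled

    train_loss, secs = train_epoch(
        model, batches(), device, loss_compute, scaler, optimizer, use_amp=use_amp
    )
    val_loss, _ = valid_epoch(model, batches(), device, loss_compute, use_amp=use_amp)
    print(f"train loss {train_loss:.3f} in {secs:.2f}s, val loss {val_loss:.3f}")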