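"""Training entry point for the encoder-decoder model.

Builds lazy-loading dataloaders over precomputed ProtT5 embeddings, then trains
with Adam, ReduceLROnPlateau learning-rate scheduling, optional mixed-precision
(AMP), and early stopping on validation loss.
"""
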
import argparse

import torch
from torch.cuda.amp import GradScaler

from model import Model, LabelSmoothing, SimpleLossCompute
from trainer import train_epoch, valid_epoch
from datasets import build_lazy_dataloaders, DataProcessor


class EarlyStopping:
    """Stop training once validation loss stops improving.

    Tracks the best validation loss seen so far; if no improvement larger than
    `delta` occurs for `patience` consecutive epochs, sets `early_stop` to True.
    The best-performing model weights are checkpointed to `save_path`.
    """

    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        self.patience = patience
        self.delta = delta
        self.save_path = save_path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: reset the counter and checkpoint the weights.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)
            print(f"Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
        else:
            self.counter += 1
            print(f"No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--d_model", type=int, default=64)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--enc_layers", type=int, default=3)
    parser.add_argument("--dec_layers", type=int, default=3)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--es_patience", type=int, default=10)
    parser.add_argument("--lr_patience", type=int, default=3)
    parser.add_argument("--factor", type=float, default=0.8)
    parser.add_argument("--save_path", type=str, default="best_model.pt")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create the DataProcessor instance
    dp = DataProcessor(
        ko_count_path="/home/zzhang/gzy/Uncultured/generative_ML/data/gene_ko_protein/KO_count.csv",
        data_path="/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/protT5_embedding/new_ko_all.csv",
        embedding_h5_path="/home/zzhang/gzy/Uncultured/ko_pre_train/ProtT5_Encoder/test_genome_ProtT5_embeddings.h5",
    )

    # Load metadata only; embedding vectors are read lazily on demand.
    dp.load_data()

    # Build data loaders via the lazy-loading path.
    train_loader, train_loader_no_shuffle, valid_loader, test_loader = build_lazy_dataloaders(
        dp, batch_size=args.batch_size, split_ratio=8, shuffle=True
    )

    # Persist the data loaders for reuse.
    torch.save(train_loader, 'train_loader.pth')
    torch.save(train_loader_no_shuffle, 'train_loader_no_shuffle.pth')
    torch.save(test_loader, 'test_loader.pth')
    torch.save(valid_loader, 'valid_loader.pth')
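    # Caveat (assumption about DataProcessor internals): torch.save pickles the
    # entire DataLoader, including its dataset. If the lazy dataset keeps an open
    # HDF5 file handle (e.g. via h5py), pickling will fail; saving just the split
    # indices and rebuilding the loaders at load time is a more robust alternative.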

    model = Model(args.heads, args.d_model, len(dp.vo_cab), args.enc_layers,
                  args.dec_layers, args.dropout, dp.vo_cab).to(device)
    criterion = LabelSmoothing(tgt_size=len(dp.vo_cab), padding_idx=dp.vo_cab["blank"],
                               smoothing=0.0).to(device)
    loss_compute = SimpleLossCompute(criterion, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args.lr_patience, factor=args.factor
    )
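    # ReduceLROnPlateau multiplies the learning rate by `factor` (default 0.8)
    # once validation loss has not improved for `lr_patience` (default 3) epochs.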

    scaler = GradScaler()
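    # Note: newer PyTorch releases expose this as torch.amp.GradScaler("cuda");
    # constructing it as GradScaler(enabled=use_mixed_precision) would also make
    # scale/step/update transparent no-ops whenever AMP is disabled below.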

    # Early stopping on validation loss
    early_stopping = EarlyStopping(patience=args.es_patience, save_path=args.save_path)

    # Mixed-precision training option (off by default to avoid CUDA compatibility issues)
    use_mixed_precision = False
    print(f"Using mixed precision training: {use_mixed_precision}")

    for epoch in range(1, args.epochs + 1):
        train_loss, t_time = train_epoch(model, train_loader, device, loss_compute,
                                         scaler, optimizer, use_amp=use_mixed_precision)
        valid_loss, v_time = valid_epoch(model, valid_loader, device, loss_compute,
                                         use_amp=use_mixed_precision)

        # Anneal the learning rate on validation-loss plateaus.
        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} ({t_time:.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:.1f}s), lr={current_lr:.6f}")

        # Early-stopping check (also checkpoints the model on improvement).
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    print(f"Training finished. Best model saved at {args.save_path}")


if __name__ == "__main__":
    main()
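# Example invocation (the script name `train.py` is illustrative; adjust the
# hard-coded DataProcessor paths before running):
#   python train.py --device cuda:0 --batch_size 4 --lr 1e-4 --epochs 100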