import argparse
import os

import torch
from torch.cuda.amp import GradScaler

from model import Model, LabelSmoothing, SimpleLossCompute
from trainer import train_epoch, valid_epoch
from datasets import build_dataloaders, build_lazy_dataloaders, DataProcessor


class EarlyStopping:
    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        self.patience = patience
        self.delta = delta
        self.save_path = save_path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        # Save a checkpoint whenever validation loss improves by more than delta
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)
            print(f"Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
        else:
            self.counter += 1
            print(f"No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--d_model", type=int, default=64)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--enc_layers", type=int, default=3)
    parser.add_argument("--dec_layers", type=int, default=3)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--es_patience", type=int, default=10)
    parser.add_argument("--lr_patience", type=int, default=3)
    parser.add_argument("--factor", type=float, default=0.8)
    parser.add_argument("--save_path", type=str, default="best_model.pt")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create DataProcessor instance
    dp = DataProcessor(
        ko_count_path="/home/zzhang/gzy/Uncultured/generative_ML/data/gene_ko_protein/KO_count.csv",
        data_path="/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/protT5_embedding/new_ko_all.csv",
        embedding_h5_path="/home/zzhang/gzy/Uncultured/ko_pre_train/ProtT5_Encoder/test_genome_ProtT5_embeddings.h5",
    )
    # Only load metadata, not all embedding data
    dp.load_data()

    # Use the lazy-loading method to build data loaders
    train_loader, train_loader_no_shuffle, valid_loader, test_loader = build_lazy_dataloaders(
        dp, batch_size=args.batch_size, split_ratio=8, shuffle=True
    )

    # Save data loaders
    torch.save(train_loader, "train_loader.pth")
    torch.save(train_loader_no_shuffle, "train_loader_no_shuffle.pth")
    torch.save(test_loader, "test_loader.pth")
    torch.save(valid_loader, "valid_loader.pth")

    model = Model(args.heads, args.d_model, len(dp.vo_cab), args.enc_layers,
                  args.dec_layers, args.dropout, dp.vo_cab).to(device)

    criterion = LabelSmoothing(tgt_size=len(dp.vo_cab), padding_idx=dp.vo_cab["blank"],
                               smoothing=0.0).to(device)
    loss_compute = SimpleLossCompute(criterion, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", patience=args.lr_patience, factor=args.factor
    )
    scaler = GradScaler()

    # Early stopping
    early_stopping = EarlyStopping(patience=args.es_patience, save_path=args.save_path)

    # Mixed precision training option (default off to avoid CUDA compatibility issues)
    use_mixed_precision = False
    print(f"Using mixed precision training: {use_mixed_precision}")

    for epoch in range(1, args.epochs + 1):
        train_loss, t_time = train_epoch(model, train_loader, device, loss_compute,
                                         scaler, optimizer, use_amp=use_mixed_precision)
        valid_loss, v_time = valid_epoch(model, valid_loader, device, loss_compute,
                                         use_amp=use_mixed_precision)
        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} ({t_time:.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:.1f}s), lr={current_lr:.6f}")

        # Early stopping check
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    print(f"Training finished. Best model saved to {args.save_path}")


if __name__ == "__main__":
    main()
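
# Example invocation (a sketch: the filename "train.py" is an assumption, and the
# dataset paths hard-coded in DataProcessor above must exist on your machine).
# Every flag below maps to an argparse option defined in main():
#
#   python train.py --device cuda:0 --batch_size 4 --lr 1e-4 --epochs 100 \
#       --es_patience 10 --lr_patience 3 --save_path best_model.pt
#
# Since EarlyStopping checkpoints only the state_dict, the best model can later
# be restored onto a freshly constructed Model with:
#
#   model.load_state_dict(torch.load(args.save_path, map_location=device))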