#!/usr/bin/env python3 """ ProtT5 Encoder-Decoder Model Training Script Supports lazy loading of data, uses predefined vocab to avoid memory issues """ import argparse import torch import os from datasets import DataProcessor, build_single_lazy_dataloader from model import Model, LabelSmoothing, SimpleLossCompute from trainer import train_epoch, valid_epoch from torch.cuda.amp import GradScaler import pandas as pd class EarlyStopping: def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"): self.patience = patience self.delta = delta self.save_path = save_path self.counter = 0 self.best_loss = None self.early_stop = False def __call__(self, val_loss, model): if self.best_loss is None or val_loss < self.best_loss - self.delta: self.best_loss = val_loss self.counter = 0 torch.save(model.state_dict(), self.save_path) print(f"✅ Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}") else: self.counter += 1 print(f"❌ No improvement. EarlyStopping counter: {self.counter}/{self.patience}") if self.counter >= self.patience: self.early_stop = True def check_gpu_compatibility(): """Check if GPU supports mixed precision training""" if torch.cuda.is_available(): device_cap = torch.cuda.get_device_capability() print(f"GPU Compute Capability: {device_cap}") if device_cap[0] >= 7: # V100, RTX series, etc. print("✅ GPU supports mixed precision training") return True else: print("❌ GPU does not support mixed precision training, recommend using normal precision") return False else: print("⚠️ No CUDA device detected") return False def create_data_processors(config): """Create training, validation, and test data processors""" print("=== Creating Data Processors ===") # Training set dp_train = DataProcessor( ko_count_path=config['ko_count_path'], data_path=config['train_data_path'], embedding_h5_path=config['embedding_h5_path'], vo_cab_pkl_path=config['vo_cab_pkl_path'], model_type='train', ko_max_len=config['ko_max_len'], compound_max_len=config['compound_max_len'] ) dp_train.load_data(count_min=config['count_min'], count_max=config['count_max']) # Validation set dp_valid = DataProcessor( ko_count_path=config['ko_count_path'], data_path=config['valid_data_path'], embedding_h5_path=config['embedding_h5_path'], vo_cab_pkl_path=config['vo_cab_pkl_path'], model_type='valid', ko_max_len=config['ko_max_len'], compound_max_len=config['compound_max_len'] ) dp_valid.load_data(count_min=config['count_min'], count_max=config['count_max']) # Test set dp_test = DataProcessor( ko_count_path=config['ko_count_path'], data_path=config['test_data_path'], embedding_h5_path=config['embedding_h5_path'], vo_cab_pkl_path=config['vo_cab_pkl_path'], model_type='test', ko_max_len=config['ko_max_len'], compound_max_len=config['compound_max_len'] ) dp_test.load_data(count_min=config['count_min'], count_max=config['count_max']) return dp_train, dp_valid, dp_test def create_data_loaders(dp_train, dp_valid, dp_test, batch_size): """Create data loaders""" print("=== Creating Data Loaders ===") train_loader = build_single_lazy_dataloader(dp_train, batch_size=batch_size, shuffle=True) valid_loader = build_single_lazy_dataloader(dp_valid, batch_size=batch_size, shuffle=False) test_loader = build_single_lazy_dataloader(dp_test, batch_size=batch_size, shuffle=False) print(f"Training set batch count: {len(train_loader)}") print(f"Validation set batch count: {len(valid_loader)}") print(f"Test set batch count: {len(test_loader)}") return train_loader, valid_loader, test_loader def create_model(model_config, 
vo_cab, device): """Create model""" print("=== Creating Model ===") model = Model( heads=model_config['heads'], d_model=model_config['d_model'], tgt_vocab_size=len(vo_cab), num_encoder_layer=model_config['num_encoder_layer'], num_decoder_layer=model_config['num_decoder_layer'], dropout=model_config['dropout'], compound_cab=vo_cab, src_input_dim=model_config['src_input_dim'], first_linear_dim=model_config['first_linear_dim'] ).to(device) criterion = LabelSmoothing( tgt_size=len(vo_cab), padding_idx=vo_cab["blank"], smoothing=model_config['label_smoothing'] ).to(device) loss_compute = SimpleLossCompute(criterion, device) print(f"Model parameters:") print(f" - heads: {model_config['heads']}") print(f" - d_model: {model_config['d_model']} (transformer internal dimension)") print(f" - src_input_dim: {model_config['src_input_dim']} (input dimension)") print(f" - first_linear_dim: {model_config['first_linear_dim']} (first layer dimension)") print(f" - vocab_size: {len(vo_cab)}") print(f" - encoder_layers: {model_config['num_encoder_layer']}") print(f" - decoder_layers: {model_config['num_decoder_layer']}") print(f" - dropout: {model_config['dropout']}") print(f" - label_smoothing: {model_config['label_smoothing']}") print(f" - dimension transformation: {model_config['src_input_dim']} -> {model_config['first_linear_dim']} -> {model_config['d_model']}") print(f" - model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") return model, loss_compute def train_model(model, train_loader, valid_loader, loss_compute, train_config, device, vo_cab_size): """Train model""" print("=== Starting Training ===") # Optimizer and scheduler optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate']) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, 'min', patience=train_config['lr_patience'], factor=train_config['lr_factor'] ) # Mixed precision training scaler = GradScaler() use_mixed_precision = train_config['use_mixed_precision'] and check_gpu_compatibility() # Early stopping early_stopping = EarlyStopping( patience=train_config['early_stopping_patience'], save_path=train_config['parameters_save_path'] ) print(f"Training configuration:") print(f" - epochs: {train_config['epochs']}") print(f" - learning_rate: {train_config['learning_rate']}") print(f" - use_mixed_precision: {use_mixed_precision}") print(f" - early_stopping_patience: {train_config['early_stopping_patience']}") print(f" - parameters_save_path: {train_config['parameters_save_path']}") print(f" - loss_save_path: {train_config['loss_save_path']}") # Training loop loss_log_path = train_config['loss_save_path'] all_train_valid_info = [] all_train_valid_info = pd.DataFrame(all_train_valid_info, columns = ['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']) for epoch in range(1, train_config['epochs'] + 1): train_loss, t_time = train_epoch( model, train_loader, device, loss_compute, scaler, optimizer, use_amp=use_mixed_precision ) valid_loss, v_time = valid_epoch( model, valid_loader, device, loss_compute, use_amp=use_mixed_precision ) scheduler.step(valid_loss) current_lr = optimizer.param_groups[0]['lr'] print(f"[Epoch {epoch:3d}] train_loss={train_loss:.4f} ({t_time:5.1f}s), " f"valid_loss={valid_loss:.4f} ({v_time:5.1f}s), lr={current_lr}") # Save file new_row = pd.DataFrame([[epoch, train_loss, t_time, valid_loss, v_time, current_lr]], columns=['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']) all_train_valid_info = 
def train_model(model, train_loader, valid_loader, loss_compute, train_config, device, vo_cab_size):
    """Train the model with LR scheduling, mixed precision, and early stopping."""
    print("=== Starting Training ===")

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        patience=train_config['lr_patience'],
        factor=train_config['lr_factor']
    )

    # Mixed precision training
    scaler = GradScaler()
    use_mixed_precision = train_config['use_mixed_precision'] and check_gpu_compatibility()

    # Early stopping
    early_stopping = EarlyStopping(
        patience=train_config['early_stopping_patience'],
        save_path=train_config['parameters_save_path']
    )

    print("Training configuration:")
    print(f"  - epochs: {train_config['epochs']}")
    print(f"  - learning_rate: {train_config['learning_rate']}")
    print(f"  - use_mixed_precision: {use_mixed_precision}")
    print(f"  - early_stopping_patience: {train_config['early_stopping_patience']}")
    print(f"  - parameters_save_path: {train_config['parameters_save_path']}")
    print(f"  - loss_save_path: {train_config['loss_save_path']}")

    # Training loop
    loss_log_path = train_config['loss_save_path']
    log_columns = ['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']
    all_train_valid_info = pd.DataFrame(columns=log_columns)

    for epoch in range(1, train_config['epochs'] + 1):
        train_loss, t_time = train_epoch(
            model, train_loader, device, loss_compute, scaler, optimizer,
            use_amp=use_mixed_precision
        )
        valid_loss, v_time = valid_epoch(
            model, valid_loader, device, loss_compute,
            use_amp=use_mixed_precision
        )

        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch:3d}] train_loss={train_loss:.4f} ({t_time:5.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:5.1f}s), lr={current_lr}")

        # Append this epoch's metrics and rewrite the log file
        new_row = pd.DataFrame([[epoch, train_loss, t_time, valid_loss, v_time, current_lr]],
                               columns=log_columns)
        all_train_valid_info = pd.concat([all_train_valid_info, new_row], ignore_index=True)
        all_train_valid_info.to_csv(loss_log_path, index=False)

        # Early stopping check
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("🛑 Early stopping triggered.")
            break

    print(f"🎯 Training finished. Best model saved at {train_config['parameters_save_path']}")
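
# Illustration only: train_epoch/valid_epoch live in trainer.py. With
# use_amp=True, a training step is assumed to follow the standard
# torch.cuda.amp recipe sketched below. The batch layout (src, tgt) and the
# loss_compute call signature are hypothetical assumptions.
from torch.cuda.amp import autocast


def _amp_step_sketch(model, batch, loss_compute, scaler, optimizer, use_amp):
    src, tgt = batch                       # hypothetical batch layout
    optimizer.zero_grad()
    with autocast(enabled=use_amp):        # run the forward pass in float16 where safe
        out = model(src, tgt)              # hypothetical model signature
        loss = loss_compute(out, tgt)
    scaler.scale(loss).backward()          # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                 # unscales gradients, then calls optimizer.step()
    scaler.update()                        # adapt the scale factor for the next step
    return loss.item()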
def main():
    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder Model Training')

    # Data path parameters
    parser.add_argument('--ko_count_path', type=str, required=True, help='KO count CSV file path')
    parser.add_argument('--train_data_path', type=str, required=True, help='Training data CSV file path')
    parser.add_argument('--valid_data_path', type=str, required=True, help='Validation data CSV file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--embedding_h5_path', type=str, required=True, help='HDF5 embedding file path')
    parser.add_argument('--vo_cab_pkl_path', type=str, required=True, help='Vocab pickle file path')

    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension (after src transformation)')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input dimension of source sequences')
    parser.add_argument('--first_linear_dim', type=int, default=512, help='First linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--label_smoothing', type=float, default=0.0, help='Label smoothing factor')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--lr_patience', type=int, default=3, help='LR scheduler patience')
    parser.add_argument('--lr_factor', type=float, default=0.8, help='LR reduction factor')
    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--use_mixed_precision', action='store_true', help='Use mixed precision training')

    # Data parameters
    parser.add_argument('--count_min', type=int, default=800, help='Minimum count filter')
    parser.add_argument('--count_max', type=int, default=4500, help='Maximum count filter')
    parser.add_argument('--ko_max_len', type=int, default=4500, help='Maximum KO sequence length')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum compound sequence length')

    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--save_path', type=str, default='best_model.pt', help='Model save path')
    parser.add_argument('--loss_save_path', type=str, default='train_valid_info.csv', help='Train/valid loss save path')
    parser.add_argument('--memory_fraction', type=float, default=0.5, help='CUDA memory fraction')

    args = parser.parse_args()

    # Set device
    device = torch.device(args.device)
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(args.memory_fraction, device=device)
    print(f"Using device: {device}")

    # Configuration dictionaries
    data_config = {
        'ko_count_path': args.ko_count_path,
        'train_data_path': args.train_data_path,
        'valid_data_path': args.valid_data_path,
        'test_data_path': args.test_data_path,
        'embedding_h5_path': args.embedding_h5_path,
        'vo_cab_pkl_path': args.vo_cab_pkl_path,
        'count_min': args.count_min,
        'count_max': args.count_max,
        'ko_max_len': args.ko_max_len,
        'compound_max_len': args.compound_max_len
    }
    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing
    }
    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'lr_patience': args.lr_patience,
        'lr_factor': args.lr_factor,
        'early_stopping_patience': args.early_stopping_patience,
        'use_mixed_precision': args.use_mixed_precision,
        'parameters_save_path': args.save_path,
        'loss_save_path': args.loss_save_path
    }

    # Create data processors
    dp_train, dp_valid, dp_test = create_data_processors(data_config)

    # Create data loaders (the test loader is built here but not used during training)
    train_loader, valid_loader, test_loader = create_data_loaders(
        dp_train, dp_valid, dp_test, args.batch_size
    )

    # Load the compound vocabulary.
    # NOTE: this path is hard-coded; data.pkl must exist in the working
    # directory and contain a 'compound_cab' entry.
    with open('data.pkl', 'rb') as f:
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']

    # Create model
    model, loss_compute = create_model(model_config, compound_cab, device)

    # Train model
    train_model(model, train_loader, valid_loader, loss_compute, train_config,
                device, len(compound_cab))

    print("🎉 Training completed successfully!")


if __name__ == "__main__":
    main()
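
# Example invocation, assuming this file is saved as train.py; every path
# below is a placeholder for your own files:
#
#   python train.py \
#       --ko_count_path ko_counts.csv \
#       --train_data_path train.csv \
#       --valid_data_path valid.csv \
#       --test_data_path test.csv \
#       --embedding_h5_path embeddings.h5 \
#       --vo_cab_pkl_path vocab.pkl \
#       --batch_size 4 --use_mixed_precision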