#!/usr/bin/env python3
"""
ProtT5 Encoder-Decoder Model Training Script

Supports lazy loading of data and uses a predefined vocab to avoid memory issues.
"""

import argparse
import pickle

import pandas as pd
import torch
from torch.cuda.amp import GradScaler

from datasets import DataProcessor, build_single_lazy_dataloader
from model import Model, LabelSmoothing, SimpleLossCompute
from trainer import train_epoch, valid_epoch
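
# Example invocation (the script name and all paths are illustrative placeholders):
#   python train.py \
#       --ko_count_path ko_counts.csv \
#       --train_data_path train.csv --valid_data_path valid.csv --test_data_path test.csv \
#       --embedding_h5_path embeddings.h5 --vo_cab_pkl_path vocab.pkl \
#       --batch_size 4 --use_mixed_precision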


class EarlyStopping:
    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        self.patience = patience
        self.delta = delta
        self.save_path = save_path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)
            print(f"✅ Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
        else:
            self.counter += 1
            print(f"❌ No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
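
# Illustrative walkthrough of the early-stopping contract above (loss values
# are hypothetical):
#   es = EarlyStopping(patience=2, save_path="tmp.pt")
#   es(1.00, model)  # first call: best_loss=1.00, checkpoint written
#   es(0.90, model)  # improvement: counter resets to 0, checkpoint overwritten
#   es(0.95, model)  # no improvement: counter=1
#   es(0.97, model)  # no improvement: counter=2 -> es.early_stop becomes True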


def check_gpu_compatibility():
    """Check whether the GPU supports mixed precision training."""
    if torch.cuda.is_available():
        device_cap = torch.cuda.get_device_capability()
        print(f"GPU Compute Capability: {device_cap}")
        if device_cap[0] >= 7:  # V100, RTX series, etc.
            print("✅ GPU supports mixed precision training")
            return True
        else:
            print("❌ GPU does not support mixed precision training; normal precision is recommended")
            return False
    else:
        print("⚠️ No CUDA device detected")
        return False
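
# Note: compute capability 7.0+ corresponds to Volta/Turing/Ampere GPUs (V100,
# RTX 20xx/30xx, A100), which carry fp16 tensor cores; on older hardware AMP
# typically yields little or no speedup.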


def create_data_processors(config):
    """Create training, validation, and test data processors."""
    print("=== Creating Data Processors ===")

    def build(model_type, data_path):
        # The three splits share every setting except data_path and model_type.
        dp = DataProcessor(
            ko_count_path=config['ko_count_path'],
            data_path=data_path,
            embedding_h5_path=config['embedding_h5_path'],
            vo_cab_pkl_path=config['vo_cab_pkl_path'],
            model_type=model_type,
            ko_max_len=config['ko_max_len'],
            compound_max_len=config['compound_max_len']
        )
        dp.load_data(count_min=config['count_min'], count_max=config['count_max'])
        return dp

    dp_train = build('train', config['train_data_path'])
    dp_valid = build('valid', config['valid_data_path'])
    dp_test = build('test', config['test_data_path'])

    return dp_train, dp_valid, dp_test


def create_data_loaders(dp_train, dp_valid, dp_test, batch_size):
    """Create data loaders."""
    print("=== Creating Data Loaders ===")

    train_loader = build_single_lazy_dataloader(dp_train, batch_size=batch_size, shuffle=True)
    valid_loader = build_single_lazy_dataloader(dp_valid, batch_size=batch_size, shuffle=False)
    test_loader = build_single_lazy_dataloader(dp_test, batch_size=batch_size, shuffle=False)

    print(f"Training set batch count: {len(train_loader)}")
    print(f"Validation set batch count: {len(valid_loader)}")
    print(f"Test set batch count: {len(test_loader)}")

    return train_loader, valid_loader, test_loader


def create_model(model_config, vo_cab, device):
    """Create the model, label-smoothing criterion, and loss computer."""
    print("=== Creating Model ===")

    model = Model(
        heads=model_config['heads'],
        d_model=model_config['d_model'],
        tgt_vocab_size=len(vo_cab),
        num_encoder_layer=model_config['num_encoder_layer'],
        num_decoder_layer=model_config['num_decoder_layer'],
        dropout=model_config['dropout'],
        compound_cab=vo_cab,
        src_input_dim=model_config['src_input_dim'],
        first_linear_dim=model_config['first_linear_dim']
    ).to(device)

    criterion = LabelSmoothing(
        tgt_size=len(vo_cab),
        padding_idx=vo_cab["blank"],
        smoothing=model_config['label_smoothing']
    ).to(device)

    loss_compute = SimpleLossCompute(criterion, device)

    print("Model parameters:")
    print(f" - heads: {model_config['heads']}")
    print(f" - d_model: {model_config['d_model']} (transformer internal dimension)")
    print(f" - src_input_dim: {model_config['src_input_dim']} (input dimension)")
    print(f" - first_linear_dim: {model_config['first_linear_dim']} (first layer dimension)")
    print(f" - vocab_size: {len(vo_cab)}")
    print(f" - encoder_layers: {model_config['num_encoder_layer']}")
    print(f" - decoder_layers: {model_config['num_decoder_layer']}")
    print(f" - dropout: {model_config['dropout']}")
    print(f" - label_smoothing: {model_config['label_smoothing']}")
    print(f" - dimension transformation: {model_config['src_input_dim']} -> {model_config['first_linear_dim']} -> {model_config['d_model']}")
    print(f" - trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    return model, loss_compute
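
# For a rough per-module parameter breakdown when tuning d_model or the layer
# counts, a sketch using the standard torch API (not part of the training path):
#   from collections import Counter
#   sizes = Counter()
#   for name, p in model.named_parameters():
#       sizes[name.split('.')[0]] += p.numel()
#   print(sizes)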


def train_model(model, train_loader, valid_loader, loss_compute, train_config, device):
    """Train the model with LR scheduling, mixed precision, and early stopping."""
    print("=== Starting Training ===")

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        patience=train_config['lr_patience'],
        factor=train_config['lr_factor']
    )

    # Mixed precision training: only enable the scaler when the config requests
    # it and the GPU actually supports it.
    use_mixed_precision = train_config['use_mixed_precision'] and check_gpu_compatibility()
    scaler = GradScaler(enabled=use_mixed_precision)

    # Early stopping
    early_stopping = EarlyStopping(
        patience=train_config['early_stopping_patience'],
        save_path=train_config['parameters_save_path']
    )

    print("Training configuration:")
    print(f" - epochs: {train_config['epochs']}")
    print(f" - learning_rate: {train_config['learning_rate']}")
    print(f" - use_mixed_precision: {use_mixed_precision}")
    print(f" - early_stopping_patience: {train_config['early_stopping_patience']}")
    print(f" - parameters_save_path: {train_config['parameters_save_path']}")
    print(f" - loss_save_path: {train_config['loss_save_path']}")
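
    # The scaler and use_amp flag are consumed inside train_epoch; a sketch of
    # the canonical AMP step it presumably performs (standard torch recipe):
    #   with torch.cuda.amp.autocast(enabled=use_amp):
    #       loss = ...  # forward pass + loss_compute
    #   scaler.scale(loss).backward()
    #   scaler.step(optimizer)
    #   scaler.update()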

    # Training loop
    loss_log_path = train_config['loss_save_path']
    all_train_valid_info = pd.DataFrame(
        columns=['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']
    )
    for epoch in range(1, train_config['epochs'] + 1):
        train_loss, t_time = train_epoch(
            model, train_loader, device, loss_compute, scaler, optimizer,
            use_amp=use_mixed_precision
        )
        valid_loss, v_time = valid_epoch(
            model, valid_loader, device, loss_compute,
            use_amp=use_mixed_precision
        )

        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch:3d}] train_loss={train_loss:.4f} ({t_time:5.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:5.1f}s), lr={current_lr}")

        # Append this epoch's stats and rewrite the CSV log so progress
        # survives an interruption.
        new_row = pd.DataFrame([[epoch, train_loss, t_time, valid_loss, v_time, current_lr]],
                               columns=['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr'])
        all_train_valid_info = pd.concat([all_train_valid_info, new_row], ignore_index=True)
        all_train_valid_info.to_csv(loss_log_path, index=False)

        # Early stopping check
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("🛑 Early stopping triggered.")
            break

    print(f"🎯 Training finished. Best model saved at {train_config['parameters_save_path']}")
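
# To evaluate the best checkpoint afterwards (sketch; EarlyStopping saves a
# plain state_dict, matching the torch.save call above):
#   model.load_state_dict(torch.load(train_config['parameters_save_path']))
#   model.eval()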


def main():
    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder Model Training')

    # Data path parameters
    parser.add_argument('--ko_count_path', type=str, required=True, help='KO count CSV file path')
    parser.add_argument('--train_data_path', type=str, required=True, help='Training data CSV file path')
    parser.add_argument('--valid_data_path', type=str, required=True, help='Validation data CSV file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--embedding_h5_path', type=str, required=True, help='HDF5 embedding file path')
    parser.add_argument('--vo_cab_pkl_path', type=str, required=True, help='Vocab pickle file path')

    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension (after src transformation)')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input dimension of source sequences')
    parser.add_argument('--first_linear_dim', type=int, default=512, help='First linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--label_smoothing', type=float, default=0.0, help='Label smoothing factor')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--lr_patience', type=int, default=3, help='LR scheduler patience')
    parser.add_argument('--lr_factor', type=float, default=0.8, help='LR reduction factor')
    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--use_mixed_precision', action='store_true', help='Use mixed precision training')

    # Data parameters
    parser.add_argument('--count_min', type=int, default=800, help='Minimum count filter')
    parser.add_argument('--count_max', type=int, default=4500, help='Maximum count filter')
    parser.add_argument('--ko_max_len', type=int, default=4500, help='Maximum KO sequence length')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum compound sequence length')

    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--save_path', type=str, default='best_model.pt', help='Model save path')
    parser.add_argument('--loss_save_path', type=str, default='train_valid_info.csv', help='Training/validation loss CSV save path')
    parser.add_argument('--memory_fraction', type=float, default=0.5, help='CUDA memory fraction')

    args = parser.parse_args()

    # Set device; only cap CUDA memory when a CUDA device is actually in use.
    device = torch.device(args.device)
    if torch.cuda.is_available() and device.type == 'cuda':
        torch.cuda.set_per_process_memory_fraction(args.memory_fraction, device=device)
    print(f"Using device: {device}")

    # Configuration dictionaries
    data_config = {
        'ko_count_path': args.ko_count_path,
        'train_data_path': args.train_data_path,
        'valid_data_path': args.valid_data_path,
        'test_data_path': args.test_data_path,
        'embedding_h5_path': args.embedding_h5_path,
        'vo_cab_pkl_path': args.vo_cab_pkl_path,
        'count_min': args.count_min,
        'count_max': args.count_max,
        'ko_max_len': args.ko_max_len,
        'compound_max_len': args.compound_max_len
    }

    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing
    }

    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'lr_patience': args.lr_patience,
        'lr_factor': args.lr_factor,
        'early_stopping_patience': args.early_stopping_patience,
        'use_mixed_precision': args.use_mixed_precision,
        'parameters_save_path': args.save_path,
        'loss_save_path': args.loss_save_path
    }

    # Create data processors
    dp_train, dp_valid, dp_test = create_data_processors(data_config)

    # Create data loaders
    train_loader, valid_loader, test_loader = create_data_loaders(
        dp_train, dp_valid, dp_test, args.batch_size
    )

    # Load the compound vocab from data.pkl (expected in the working
    # directory): a dict-like token -> index map that includes a "blank"
    # padding entry, used as padding_idx in create_model.
    with open('data.pkl', 'rb') as f:
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']

    # Create model
    model, loss_compute = create_model(model_config, compound_cab, device)

    # Train model
    train_model(model, train_loader, valid_loader, loss_compute, train_config, device)

    print("🎉 Training completed successfully!")


if __name__ == "__main__":
    main()