first commit
This commit is contained in:
328
models/train_model.py
Normal file
328
models/train_model.py
Normal file
@@ -0,0 +1,328 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ProtT5 Encoder-Decoder Model Training Script
|
||||
Supports lazy loading of data, uses predefined vocab to avoid memory issues
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import os
|
||||
from datasets import DataProcessor, build_single_lazy_dataloader
|
||||
from model import Model, LabelSmoothing, SimpleLossCompute
|
||||
from trainer import train_epoch, valid_epoch
|
||||
from torch.cuda.amp import GradScaler
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class EarlyStopping:
    """Stop training once validation loss stops improving.

    Each call compares the new validation loss against the best seen so
    far; on improvement (by more than ``delta``) the model's state_dict
    is checkpointed to ``save_path``, otherwise a counter is incremented
    and ``early_stop`` is set after ``patience`` consecutive misses.
    """

    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        self.patience = patience      # epochs without improvement before stopping
        self.delta = delta            # minimum decrease that counts as improvement
        self.save_path = save_path    # checkpoint destination for the best model
        self.counter = 0              # consecutive epochs without improvement
        self.best_loss = None         # best validation loss observed so far
        self.early_stop = False       # flag polled by the training loop

    def __call__(self, val_loss, model):
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            # Checkpoint immediately so the best weights survive a later stop.
            torch.save(model.state_dict(), self.save_path)
            print(f"✅ Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
        else:
            self.counter += 1
            print(f"❌ No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
|
||||
|
||||
|
||||
def check_gpu_compatibility():
    """Check if GPU supports mixed precision training"""
    # No CUDA at all -> mixed precision is off the table.
    if not torch.cuda.is_available():
        print("⚠️ No CUDA device detected")
        return False

    device_cap = torch.cuda.get_device_capability()
    print(f"GPU Compute Capability: {device_cap}")

    # Compute capability major version 7+ is treated as AMP-capable here.
    if device_cap[0] >= 7:  # V100, RTX series, etc.
        print("✅ GPU supports mixed precision training")
        return True

    print("❌ GPU does not support mixed precision training, recommend using normal precision")
    return False
|
||||
|
||||
|
||||
def create_data_processors(config):
    """Create training, validation, and test data processors.

    The three splits are configured identically except for their CSV
    path and split name, so construction is factored into one helper
    instead of three copy-pasted blocks.

    Args:
        config: dict holding data paths (``ko_count_path``,
            ``*_data_path``, ``embedding_h5_path``, ``vo_cab_pkl_path``),
            length limits (``ko_max_len``, ``compound_max_len``) and
            count filters (``count_min``, ``count_max``).

    Returns:
        Tuple ``(dp_train, dp_valid, dp_test)`` of DataProcessor
        instances, each with its data already loaded.
    """
    print("=== Creating Data Processors ===")

    def _build(model_type, data_path_key):
        # One-split constructor: only the data path and split name vary.
        dp = DataProcessor(
            ko_count_path=config['ko_count_path'],
            data_path=config[data_path_key],
            embedding_h5_path=config['embedding_h5_path'],
            vo_cab_pkl_path=config['vo_cab_pkl_path'],
            model_type=model_type,
            ko_max_len=config['ko_max_len'],
            compound_max_len=config['compound_max_len']
        )
        dp.load_data(count_min=config['count_min'], count_max=config['count_max'])
        return dp

    dp_train = _build('train', 'train_data_path')
    dp_valid = _build('valid', 'valid_data_path')
    dp_test = _build('test', 'test_data_path')

    return dp_train, dp_valid, dp_test
|
||||
|
||||
|
||||
def create_data_loaders(dp_train, dp_valid, dp_test, batch_size):
    """Create data loaders"""
    print("=== Creating Data Loaders ===")

    # Only the training split is shuffled; valid/test keep source order.
    train_loader, valid_loader, test_loader = (
        build_single_lazy_dataloader(dp, batch_size=batch_size, shuffle=shuffle)
        for dp, shuffle in ((dp_train, True), (dp_valid, False), (dp_test, False))
    )

    print(f"Training set batch count: {len(train_loader)}")
    print(f"Validation set batch count: {len(valid_loader)}")
    print(f"Test set batch count: {len(test_loader)}")

    return train_loader, valid_loader, test_loader
|
||||
|
||||
|
||||
def create_model(model_config, vo_cab, device):
    """Create model"""
    print("=== Creating Model ===")

    vocab_size = len(vo_cab)

    # Encoder-decoder transformer; decoder vocabulary comes from vo_cab.
    model = Model(
        heads=model_config['heads'],
        d_model=model_config['d_model'],
        tgt_vocab_size=vocab_size,
        num_encoder_layer=model_config['num_encoder_layer'],
        num_decoder_layer=model_config['num_decoder_layer'],
        dropout=model_config['dropout'],
        compound_cab=vo_cab,
        src_input_dim=model_config['src_input_dim'],
        first_linear_dim=model_config['first_linear_dim']
    ).to(device)

    # Label-smoothed loss; the "blank" token index is treated as padding.
    criterion = LabelSmoothing(
        tgt_size=vocab_size,
        padding_idx=vo_cab["blank"],
        smoothing=model_config['label_smoothing']
    ).to(device)

    loss_compute = SimpleLossCompute(criterion, device)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Model parameters:")
    print(f" - heads: {model_config['heads']}")
    print(f" - d_model: {model_config['d_model']} (transformer internal dimension)")
    print(f" - src_input_dim: {model_config['src_input_dim']} (input dimension)")
    print(f" - first_linear_dim: {model_config['first_linear_dim']} (first layer dimension)")
    print(f" - vocab_size: {vocab_size}")
    print(f" - encoder_layers: {model_config['num_encoder_layer']}")
    print(f" - decoder_layers: {model_config['num_decoder_layer']}")
    print(f" - dropout: {model_config['dropout']}")
    print(f" - label_smoothing: {model_config['label_smoothing']}")
    print(f" - dimension transformation: {model_config['src_input_dim']} -> {model_config['first_linear_dim']} -> {model_config['d_model']}")
    print(f" - model parameters: {trainable_params}")

    return model, loss_compute
|
||||
|
||||
|
||||
def train_model(model, train_loader, valid_loader, loss_compute, train_config, device, vo_cab_size):
    """Train the model with LR scheduling, early stopping, and loss logging.

    Runs up to ``train_config['epochs']`` epochs; after each epoch the
    LR scheduler steps on validation loss, the per-epoch metrics are
    appended to a CSV (rewritten every epoch so progress survives an
    interruption), and early stopping checkpoints/halts as needed.

    Args:
        model: model instance to optimize (updated in place).
        train_loader / valid_loader: dataloaders for the two splits.
        loss_compute: loss wrapper passed through to train/valid epochs.
        train_config: dict of training hyper-parameters (see main()).
        device: torch device used by train_epoch/valid_epoch.
        vo_cab_size: target vocabulary size; kept for interface
            compatibility — not used inside this function.
    """
    print("=== Starting Training ===")

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        patience=train_config['lr_patience'],
        factor=train_config['lr_factor']
    )

    # Mixed precision training — only enabled when both requested and supported.
    scaler = GradScaler()
    use_mixed_precision = train_config['use_mixed_precision'] and check_gpu_compatibility()

    # Early stopping
    early_stopping = EarlyStopping(
        patience=train_config['early_stopping_patience'],
        save_path=train_config['parameters_save_path']
    )

    print(f"Training configuration:")
    print(f" - epochs: {train_config['epochs']}")
    print(f" - learning_rate: {train_config['learning_rate']}")
    print(f" - use_mixed_precision: {use_mixed_precision}")
    print(f" - early_stopping_patience: {train_config['early_stopping_patience']}")
    print(f" - parameters_save_path: {train_config['parameters_save_path']}")
    print(f" - loss_save_path: {train_config['loss_save_path']}")

    # Training loop. Rows are accumulated in a plain list and rebuilt into
    # a DataFrame at each write: this replaces the original per-epoch
    # pd.concat on a growing DataFrame (quadratic) and the redundant
    # list-then-DataFrame double assignment. CSV output is identical
    # (same columns, index=False).
    loss_log_path = train_config['loss_save_path']
    log_columns = ['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']
    log_rows = []
    for epoch in range(1, train_config['epochs'] + 1):
        train_loss, t_time = train_epoch(
            model, train_loader, device, loss_compute, scaler, optimizer,
            use_amp=use_mixed_precision
        )
        valid_loss, v_time = valid_epoch(
            model, valid_loader, device, loss_compute,
            use_amp=use_mixed_precision
        )

        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch:3d}] train_loss={train_loss:.4f} ({t_time:5.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:5.1f}s), lr={current_lr}")

        # Save file
        log_rows.append([epoch, train_loss, t_time, valid_loss, v_time, current_lr])
        pd.DataFrame(log_rows, columns=log_columns).to_csv(loss_log_path, index=False)

        # Early stopping check
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("🛑 Early stopping triggered.")
            break

    print(f"🎯 Training finished. Best model saved at {train_config['parameters_save_path']}")
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, build configs, and run the full training pipeline.

    Steps: set up the device, assemble data/model/training config dicts,
    create data processors and loaders, load the compound vocabulary,
    build the model, then train.
    """
    import pickle  # hoisted from mid-function: imports belong at the top

    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder Model Training')

    # Data path parameters
    parser.add_argument('--ko_count_path', type=str, required=True, help='KO count CSV file path')
    parser.add_argument('--train_data_path', type=str, required=True, help='Training data CSV file path')
    parser.add_argument('--valid_data_path', type=str, required=True, help='Validation data CSV file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--embedding_h5_path', type=str, required=True, help='HDF5 embedding file path')
    parser.add_argument('--vo_cab_pkl_path', type=str, required=True, help='Vocab pickle file path')

    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension (after src transformation)')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input dimension of source sequences')
    parser.add_argument('--first_linear_dim', type=int, default=512, help='First linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--label_smoothing', type=float, default=0.0, help='Label smoothing factor')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--lr_patience', type=int, default=3, help='LR scheduler patience')
    parser.add_argument('--lr_factor', type=float, default=0.8, help='LR reduction factor')
    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--use_mixed_precision', action='store_true', help='Use mixed precision training')

    # Data parameters
    parser.add_argument('--count_min', type=int, default=800, help='Minimum count filter')
    parser.add_argument('--count_max', type=int, default=4500, help='Maximum count filter')
    parser.add_argument('--ko_max_len', type=int, default=4500, help='Maximum KO sequence length')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum compound sequence length')

    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--save_path', type=str, default='best_model.pt', help='Model save path')
    parser.add_argument('--loss_save_path', type=str, default='train_valid_info.csv', help='train valid loss save path')
    parser.add_argument('--memory_fraction', type=float, default=0.5, help='CUDA memory fraction')

    args = parser.parse_args()

    # Set device; cap this process's GPU memory share when CUDA is present.
    device = torch.device(args.device)
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(args.memory_fraction, device=device)
    print(f"Using device: {device}")

    # Configuration dictionaries grouped by consumer.
    data_config = {
        'ko_count_path': args.ko_count_path,
        'train_data_path': args.train_data_path,
        'valid_data_path': args.valid_data_path,
        'test_data_path': args.test_data_path,
        'embedding_h5_path': args.embedding_h5_path,
        'vo_cab_pkl_path': args.vo_cab_pkl_path,
        'count_min': args.count_min,
        'count_max': args.count_max,
        'ko_max_len': args.ko_max_len,
        'compound_max_len': args.compound_max_len
    }

    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing
    }

    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'lr_patience': args.lr_patience,
        'lr_factor': args.lr_factor,
        'early_stopping_patience': args.early_stopping_patience,
        'use_mixed_precision': args.use_mixed_precision,
        'parameters_save_path': args.save_path,
        'loss_save_path': args.loss_save_path
    }

    # Create data processors
    dp_train, dp_valid, dp_test = create_data_processors(data_config)

    # Create data loaders
    train_loader, valid_loader, test_loader = create_data_loaders(
        dp_train, dp_valid, dp_test, args.batch_size
    )

    # Load the compound vocabulary used for the decoder.
    # NOTE(review): the path 'data.pkl' is hard-coded even though
    # --vo_cab_pkl_path is a required argument — confirm whether these
    # should point at the same file. pickle.load on an untrusted file
    # executes arbitrary code; only load trusted data here.
    with open('data.pkl', 'rb') as f:  # 'rb' means read in binary mode
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']

    # Create model
    model, loss_compute = create_model(model_config, compound_cab, device)

    # Train model
    train_model(model, train_loader, valid_loader, loss_compute, train_config, device, len(compound_cab))

    print("🎉 Training completed successfully!")
|
||||
|
||||
|
||||
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user