Files
labweb/models/train_model.py
2025-12-16 11:39:15 +08:00

328 lines
14 KiB
Python

#!/usr/bin/env python3
"""
ProtT5 Encoder-Decoder Model Training Script
Supports lazy loading of data, uses predefined vocab to avoid memory issues
"""
import argparse
import math
import os

import pandas as pd
import torch
from torch.cuda.amp import GradScaler

from datasets import DataProcessor, build_single_lazy_dataloader
from model import Model, LabelSmoothing, SimpleLossCompute
from trainer import train_epoch, valid_epoch
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Each call compares the newest validation loss with the best one seen so
    far. A decrease larger than ``delta`` resets the patience counter and
    writes ``model.state_dict()`` to ``save_path`` with ``torch.save``.
    Anything else increments the counter, and ``early_stop`` becomes True
    once the counter reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum decrease in loss that counts as an improvement.
        save_path: File the best ``state_dict`` is written to.
    """

    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        self.patience = patience
        self.delta = delta
        self.save_path = save_path
        self.counter = 0
        self.best_loss = None  # stays None until the first finite loss is recorded
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one epoch's ``val_loss`` and checkpoint ``model`` on improvement.

        A non-finite loss (NaN/inf) always counts as "no improvement". If a
        NaN were accepted as the first best value, every later comparison
        against it would be False, so no better epoch could ever be saved.
        """
        loss_is_finite = math.isfinite(float(val_loss))
        if not loss_is_finite:
            print(f"⚠️ Non-finite validation loss ({val_loss}); treating as no improvement.")
        improved = loss_is_finite and (
            self.best_loss is None or val_loss < self.best_loss - self.delta
        )
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)
            print(f"✅ Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
        else:
            self.counter += 1
            print(f"❌ No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
def check_gpu_compatibility():
    """Report whether the current CUDA device is suited to mixed precision.

    Prints the compute capability of ``torch.cuda``'s current device together
    with a verdict. A major capability version of 7 or newer (V100, RTX
    series, etc.) counts as supported.

    Returns:
        True if CUDA is available and the major capability version is >= 7;
        False otherwise, including when no CUDA device is present.
    """
    if not torch.cuda.is_available():
        print("⚠️ No CUDA device detected")
        return False

    capability = torch.cuda.get_device_capability()
    print(f"GPU Compute Capability: {capability}")
    supported = capability[0] >= 7
    if supported:
        print("✅ GPU supports mixed precision training")
    else:
        print("❌ GPU does not support mixed precision training, recommend using normal precision")
    return supported
def create_data_processors(config):
    """Build and load the training, validation and test data processors.

    All three splits share the KO count file, embedding HDF5 file, vocab
    pickle and length limits taken from ``config``. They differ only in their
    data CSV path and ``model_type``. Each processor has ``load_data`` called
    with the ``count_min``/``count_max`` filter before the next split is built.

    Returns:
        Tuple ``(dp_train, dp_valid, dp_test)``.
    """
    print("=== Creating Data Processors ===")
    split_specs = (
        ('train_data_path', 'train'),
        ('valid_data_path', 'valid'),
        ('test_data_path', 'test'),
    )
    processors = []
    for path_key, split_name in split_specs:
        processor = DataProcessor(
            ko_count_path=config['ko_count_path'],
            data_path=config[path_key],
            embedding_h5_path=config['embedding_h5_path'],
            vo_cab_pkl_path=config['vo_cab_pkl_path'],
            model_type=split_name,
            ko_max_len=config['ko_max_len'],
            compound_max_len=config['compound_max_len'],
        )
        processor.load_data(count_min=config['count_min'], count_max=config['count_max'])
        processors.append(processor)
    return tuple(processors)
def create_data_loaders(dp_train, dp_valid, dp_test, batch_size):
    """Wrap each data processor in a lazy data loader and report batch counts.

    Only the training loader shuffles; the validation and test loaders do not.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``.
    """
    print("=== Creating Data Loaders ===")
    train_loader = build_single_lazy_dataloader(dp_train, batch_size=batch_size, shuffle=True)
    valid_loader = build_single_lazy_dataloader(dp_valid, batch_size=batch_size, shuffle=False)
    test_loader = build_single_lazy_dataloader(dp_test, batch_size=batch_size, shuffle=False)

    for label, loader in (("Training", train_loader),
                          ("Validation", valid_loader),
                          ("Test", test_loader)):
        print(f"{label} set batch count: {len(loader)}")
    return train_loader, valid_loader, test_loader
def create_model(model_config, vo_cab, device):
    """Instantiate the encoder-decoder model and its label-smoothed loss.

    The target vocabulary size is ``len(vo_cab)`` and the padding index is
    ``vo_cab["blank"]``. Both the model and the ``LabelSmoothing`` criterion
    are moved to ``device``. A configuration summary, including the trainable
    parameter count, is printed.

    Returns:
        Tuple ``(model, loss_compute)``, where ``loss_compute`` is the
        criterion wrapped in ``SimpleLossCompute``.
    """
    print("=== Creating Model ===")
    cfg = model_config
    vocab_size = len(vo_cab)
    model = Model(
        heads=cfg['heads'],
        d_model=cfg['d_model'],
        tgt_vocab_size=vocab_size,
        num_encoder_layer=cfg['num_encoder_layer'],
        num_decoder_layer=cfg['num_decoder_layer'],
        dropout=cfg['dropout'],
        compound_cab=vo_cab,
        src_input_dim=cfg['src_input_dim'],
        first_linear_dim=cfg['first_linear_dim'],
    ).to(device)
    smoothing_loss = LabelSmoothing(
        tgt_size=vocab_size,
        padding_idx=vo_cab["blank"],
        smoothing=cfg['label_smoothing'],
    ).to(device)
    loss_compute = SimpleLossCompute(smoothing_loss, device)

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    src_dim = cfg['src_input_dim']
    hidden_dim = cfg['first_linear_dim']
    model_dim = cfg['d_model']
    summary_lines = (
        f"heads: {cfg['heads']}",
        f"d_model: {model_dim} (transformer internal dimension)",
        f"src_input_dim: {src_dim} (input dimension)",
        f"first_linear_dim: {hidden_dim} (first layer dimension)",
        f"vocab_size: {vocab_size}",
        f"encoder_layers: {cfg['num_encoder_layer']}",
        f"decoder_layers: {cfg['num_decoder_layer']}",
        f"dropout: {cfg['dropout']}",
        f"label_smoothing: {cfg['label_smoothing']}",
        f"dimension transformation: {src_dim} -> {hidden_dim} -> {model_dim}",
        f"model parameters: {n_trainable}",
    )
    print("Model parameters:")
    for line in summary_lines:
        print(f" - {line}")
    return model, loss_compute
def train_model(model, train_loader, valid_loader, loss_compute, train_config, device, vo_cab_size):
    """Train ``model`` with Adam, a plateau LR scheduler, CSV logging and early stopping.

    Each epoch runs these steps in order:

    1. One ``train_epoch`` pass and one ``valid_epoch`` pass.
    2. ``ReduceLROnPlateau.step(valid_loss)``.
    3. Append a row to the loss log. The whole CSV is rewritten every epoch,
       so progress survives an interrupted run.
    4. An ``EarlyStopping`` check, which checkpoints the best ``state_dict``
       to ``train_config['parameters_save_path']``.

    Args:
        model: Module to train (expected to already live on ``device``).
        train_loader: Training batches, forwarded to ``train_epoch``.
        valid_loader: Validation batches, forwarded to ``valid_epoch``.
        loss_compute: Loss wrapper forwarded to ``train_epoch``/``valid_epoch``.
        train_config: Dict with keys ``epochs``, ``learning_rate``,
            ``lr_patience``, ``lr_factor``, ``early_stopping_patience``,
            ``use_mixed_precision``, ``parameters_save_path`` and
            ``loss_save_path``.
        device: ``torch.device`` or device string used for training.
        vo_cab_size: Vocabulary size. Not used here; kept for API compatibility.

    Returns:
        ``pandas.DataFrame`` with one row per completed epoch (the same data
        written to the CSV). It is empty if no epoch ran.
    """
    print("=== Starting Training ===")
    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        patience=train_config['lr_patience'],
        factor=train_config['lr_factor'],
    )

    # Mixed precision: only meaningful on an available CUDA device. The
    # capability check queries torch.cuda's *current* device, so make the
    # training GPU current first (it may not be cuda:0).
    train_device = torch.device(device)
    use_mixed_precision = False
    if train_config['use_mixed_precision']:
        if train_device.type == 'cuda' and torch.cuda.is_available():
            with torch.cuda.device(train_device):
                use_mixed_precision = check_gpu_compatibility()
        else:
            print(f"⚠️ Mixed precision requested but {train_device} is not an available "
                  f"CUDA device; using full precision")
    # A disabled scaler is a pass-through (scale() returns the loss unchanged,
    # step() just calls optimizer.step()), so it is safe whether or not the
    # training loop uses it when AMP is off.
    scaler = GradScaler(enabled=use_mixed_precision)

    # Early stopping
    early_stopping = EarlyStopping(
        patience=train_config['early_stopping_patience'],
        save_path=train_config['parameters_save_path'],
    )

    print("Training configuration:")
    print(f" - epochs: {train_config['epochs']}")
    print(f" - learning_rate: {train_config['learning_rate']}")
    print(f" - use_mixed_precision: {use_mixed_precision}")
    print(f" - early_stopping_patience: {train_config['early_stopping_patience']}")
    print(f" - parameters_save_path: {train_config['parameters_save_path']}")
    print(f" - loss_save_path: {train_config['loss_save_path']}")

    # Training loop. Rows are accumulated in a list; this avoids repeatedly
    # concatenating onto an initially empty DataFrame, which is deprecated in
    # pandas and quadratic in the number of epochs.
    loss_log_path = train_config['loss_save_path']
    log_columns = ['epoch', 'train_loss', 'train_time', 'valid_loss', 'valid_time', 'lr']
    history_rows = []
    history = pd.DataFrame(history_rows, columns=log_columns)
    for epoch in range(1, train_config['epochs'] + 1):
        train_loss, t_time = train_epoch(
            model, train_loader, device, loss_compute, scaler, optimizer,
            use_amp=use_mixed_precision,
        )
        valid_loss, v_time = valid_epoch(
            model, valid_loader, device, loss_compute,
            use_amp=use_mixed_precision,
        )
        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"[Epoch {epoch:3d}] train_loss={train_loss:.4f} ({t_time:5.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:5.1f}s), lr={current_lr}")

        # Rewrite the full log every epoch so it is usable mid-run.
        history_rows.append([epoch, train_loss, t_time, valid_loss, v_time, current_lr])
        history = pd.DataFrame(history_rows, columns=log_columns)
        history.to_csv(loss_log_path, index=False)

        # Early stopping check
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("🛑 Early stopping triggered.")
            break

    print(f"🎯 Training finished. Best model saved at {train_config['parameters_save_path']}")
    return history
def main():
    """Parse command-line options, build the data pipeline and model, then train.

    The target vocabulary (``compound_cab``) is read from the pickle given by
    ``--compound_cab_pkl_path`` (default ``data.pkl``, the previously
    hard-coded path). It is loaded before the data processors are built, so
    a missing or malformed vocabulary file fails fast.
    """
    import pickle

    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder Model Training')
    # Data path parameters
    parser.add_argument('--ko_count_path', type=str, required=True, help='KO count CSV file path')
    parser.add_argument('--train_data_path', type=str, required=True, help='Training data CSV file path')
    parser.add_argument('--valid_data_path', type=str, required=True, help='Validation data CSV file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--embedding_h5_path', type=str, required=True, help='HDF5 embedding file path')
    parser.add_argument('--vo_cab_pkl_path', type=str, required=True, help='Vocab pickle file path')
    parser.add_argument('--compound_cab_pkl_path', type=str, default='data.pkl',
                        help="Pickle file holding a dict with a 'compound_cab' vocabulary used to build the model")
    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension (after src transformation)')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input dimension of source sequences')
    parser.add_argument('--first_linear_dim', type=int, default=512, help='First linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--label_smoothing', type=float, default=0.0, help='Label smoothing factor')
    # Training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--lr_patience', type=int, default=3, help='LR scheduler patience')
    parser.add_argument('--lr_factor', type=float, default=0.8, help='LR reduction factor')
    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--use_mixed_precision', action='store_true', help='Use mixed precision training')
    # Data parameters
    parser.add_argument('--count_min', type=int, default=800, help='Minimum count filter')
    parser.add_argument('--count_max', type=int, default=4500, help='Maximum count filter')
    parser.add_argument('--ko_max_len', type=int, default=4500, help='Maximum KO sequence length')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum compound sequence length')
    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--save_path', type=str, default='best_model.pt', help='Model save path')
    parser.add_argument('--loss_save_path', type=str, default='train_valid_info.csv', help='train valid loss save path')
    parser.add_argument('--memory_fraction', type=float, default=0.5, help='CUDA memory fraction')
    args = parser.parse_args()

    # Set device. The memory cap is CUDA-only: applying it for e.g.
    # `--device cpu` on a CUDA machine raises. A bare "cuda" has no index,
    # so fall back to the current CUDA device.
    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available():
        cuda_index = device.index if device.index is not None else torch.cuda.current_device()
        torch.cuda.set_per_process_memory_fraction(args.memory_fraction, device=cuda_index)
    print(f"Using device: {device}")

    # Configuration dictionaries
    data_config = {
        'ko_count_path': args.ko_count_path,
        'train_data_path': args.train_data_path,
        'valid_data_path': args.valid_data_path,
        'test_data_path': args.test_data_path,
        'embedding_h5_path': args.embedding_h5_path,
        'vo_cab_pkl_path': args.vo_cab_pkl_path,
        'count_min': args.count_min,
        'count_max': args.count_max,
        'ko_max_len': args.ko_max_len,
        'compound_max_len': args.compound_max_len,
    }
    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
    }
    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'lr_patience': args.lr_patience,
        'lr_factor': args.lr_factor,
        'early_stopping_patience': args.early_stopping_patience,
        'use_mixed_precision': args.use_mixed_precision,
        'parameters_save_path': args.save_path,
        'loss_save_path': args.loss_save_path,
    }

    # Load the target vocabulary first so a bad path fails before data loading.
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # option at vocabulary files from a trusted source.
    with open(args.compound_cab_pkl_path, 'rb') as f:  # binary read mode
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']

    # Create data processors
    dp_train, dp_valid, dp_test = create_data_processors(data_config)
    # Create data loaders (the test loader is built for its batch-count report)
    train_loader, valid_loader, test_loader = create_data_loaders(
        dp_train, dp_valid, dp_test, args.batch_size
    )
    # Create model
    model, loss_compute = create_model(model_config, compound_cab, device)
    # Train model
    train_model(model, train_loader, valid_loader, loss_compute, train_config, device, len(compound_cab))
    print("🎉 Training completed successfully!")


if __name__ == "__main__":
    main()