#!/usr/bin/env python3
"""
|
|
Simplified training script runner
|
|
Contains preset configuration parameters for quick training startup
|
|
"""
import os
import subprocess
import sys
def run_training():
    """Build the command line for train_model.py from preset values and run it.

    Prints a summary of the configuration, launches the training script as a
    subprocess, and exits this process with status 1 if training fails or is
    interrupted by the user. Returns None on success.
    """

    # Preset configuration parameters
    config = {
        # Data paths (please modify according to your actual paths)
        'ko_count_path': '/home/zzhang/gzy/Uncultured/generative_ML/data/gene_ko_protein/KO_count.csv',
        'train_data_path': '/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/ProtT5_encoder_fengzhuang/combine_medium/train_data.csv',
        'valid_data_path': '/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/ProtT5_encoder_fengzhuang/combine_medium/val_data.csv',
        'test_data_path': '/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/ProtT5_encoder_fengzhuang/combine_medium/test_data.csv',
        'embedding_h5_path': '/home/zzhang/gzy/Uncultured/ko_pre_train/ProtT5_Encoder/all_genome_remove_nan_gene_ProtT5_embeddings.h5',
        'vo_cab_pkl_path': '/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/ProtT5_encoder_fengzhuang/combine_medium/data.pkl',

        # Model parameters
        'heads': 8,
        'd_model': 256,  # transformer internal dimension
        'src_input_dim': 1024,  # input dimension
        'num_encoder_layer': 3,
        'num_decoder_layer': 3,
        'dropout': 0.1,
        'label_smoothing': 0.0,
        'first_linear_dim': 512,

        # Training parameters
        'epochs': 200,
        'batch_size': 1,
        'learning_rate': 1e-5,  # smaller learning rate
        'lr_patience': 3,
        'lr_factor': 0.8,
        'early_stopping_patience': 10,
        'use_mixed_precision': False,  # default off to avoid compatibility issues

        # Data parameters
        'count_min': 800,
        'count_max': 4500,
        'ko_max_len': 4500,
        'compound_max_len': 113,

        # Other parameters
        'device': 'cuda:1',
        'save_path': 'best_model.pt',
        'loss_save_path': 'train_valid_info.csv',
        'memory_fraction': 0.5
    }

    # Build command line parameters.
    # Use sys.executable rather than a bare 'python' so the training script
    # runs under the same interpreter (and virtual environment) as this runner.
    cmd = [sys.executable, 'train_model.py']

    for key, value in config.items():
        if isinstance(value, bool):
            # Boolean options are store_true-style flags: emit the bare flag
            # only when True, nothing when False.
            if value:
                cmd.append(f'--{key}')
        else:
            cmd.extend([f'--{key}', str(value)])

    print("=== Training Configuration ===")
    print("Model parameters:")
    print(f"  - heads: {config['heads']}")
    print(f"  - d_model: {config['d_model']} (transformer internal dimension)")
    print(f"  - src_input_dim: {config['src_input_dim']} (input dimension)")
    print(f"  - first_linear_dim: {config['first_linear_dim']} (first layer dimension)")
    print(f"  - encoder_layers: {config['num_encoder_layer']}")
    print(f"  - decoder_layers: {config['num_decoder_layer']}")
    print(f"  - dropout: {config['dropout']}")
    print(f"  - dimension transformation: {config['src_input_dim']} -> {config['first_linear_dim']} -> {config['d_model']}")

    print("\nTraining parameters:")
    print(f"  - epochs: {config['epochs']}")
    print(f"  - batch_size: {config['batch_size']}")
    print(f"  - learning_rate: {config['learning_rate']}")
    print(f"  - use_mixed_precision: {config['use_mixed_precision']}")

    print(f"\nSave path: {config['save_path']}")
    print(f"Device: {config['device']}")

    print("\n=== Starting Training ===")
    print("Executing command:", ' '.join(cmd))
    print()

    # Execute training. check=True raises CalledProcessError on a non-zero
    # exit status from train_model.py.
    try:
        subprocess.run(cmd, check=True)
        print("\n🎉 Training completed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"\n❌ Training failed: {e}")
        sys.exit(1)
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user")
        sys.exit(1)
if __name__ == "__main__":
|
|
# Check if necessary files exist
|
|
required_files = ['train_model.py', 'datasets.py', 'model.py', 'trainer.py', 'utils.py']
|
|
missing_files = [f for f in required_files if not os.path.exists(f)]
|
|
|
|
if missing_files:
|
|
print(f"❌ Missing necessary files: {missing_files}")
|
|
sys.exit(1)
|
|
|
|
print("✅ All necessary files exist")
|
|
run_training() |