first commit
This commit is contained in:
136
models/dmodel_256_run_decoding.py
Normal file
136
models/dmodel_256_run_decoding.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
ProtT5 Encoder-Decoder decoding
|
||||
Support beam search and top-k/top-p sampling decoding methods
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
def run_decoding():
|
||||
"""Run the decoding script"""
|
||||
|
||||
# Configuration parameters - Please modify these paths according to your actual situation
|
||||
config = {
|
||||
'data_pkl_path': 'data.pkl', # data.pkl file path
|
||||
'model_path': 'best_model.pt', # The path to the trained model weight file
|
||||
'test_data_path': 'test_data.csv', # Test data CSV file path
|
||||
'output_path': './dmodel_256_decoding', # Output result file path
|
||||
'test_genome_pkl_path':'test_genome_dict.pkl',
|
||||
|
||||
# Model parameters - need to be consistent with training
|
||||
'heads': 8,
|
||||
'd_model': 256,
|
||||
'src_input_dim': 1024,
|
||||
'first_linear_dim': 512,
|
||||
'num_encoder_layer': 3,
|
||||
'num_decoder_layer': 3,
|
||||
'dropout': 0.1,
|
||||
'compound_max_len': 113,
|
||||
|
||||
# Decoder parameters
|
||||
'top_k': [8, 9, 10], # Top-K鍊?
|
||||
'top_p': [0.9], # Top-P鍊?
|
||||
'beam_size': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # Beam size鍊?
|
||||
'decode_methods': ['beam_search'], # 瑙g爜鏂规硶
|
||||
|
||||
# Other parameters
|
||||
'device': 'cuda:1', # 璁惧
|
||||
'batch_size': 1 # 鎵规澶у皬
|
||||
}
|
||||
|
||||
# Build Commands
|
||||
cmd = [
|
||||
'python', 'dmodel_256_decode_model.py',
|
||||
'--data_pkl_path', config['data_pkl_path'],
|
||||
'--model_path', config['model_path'],
|
||||
'--test_data_path', config['test_data_path'],
|
||||
'--test_genome_pkl_path', config['test_genome_pkl_path'],
|
||||
'--output_path', config['output_path'],
|
||||
'--heads', str(config['heads']),
|
||||
'--d_model', str(config['d_model']),
|
||||
'--src_input_dim', str(config['src_input_dim']),
|
||||
'--first_linear_dim', str(config['first_linear_dim']),
|
||||
'--num_encoder_layer', str(config['num_encoder_layer']),
|
||||
'--num_decoder_layer', str(config['num_decoder_layer']),
|
||||
'--dropout', str(config['dropout']),
|
||||
'--compound_max_len', str(config['compound_max_len']),
|
||||
'--top_k'] + [str(k) for k in config['top_k']] + [
|
||||
'--top_p'] + [str(p) for p in config['top_p']] + [
|
||||
'--beam_size'] + [str(b) for b in config['beam_size']] + [
|
||||
'--decode_methods'] + config['decode_methods'] + [
|
||||
'--device', config['device'],
|
||||
'--batch_size', str(config['batch_size'])
|
||||
]
|
||||
|
||||
print("=== Start decoding ===")
|
||||
print(f"Data files: {config['data_pkl_path']}")
|
||||
print(f"Model file: {config['model_path']}")
|
||||
print(f"Test data: {config['test_data_path']}")
|
||||
print(f"Output File: {config['output_path']}")
|
||||
print(f"device: {config['device']}")
|
||||
print(f"Decoding method: {config['decode_methods']}")
|
||||
print(f"Top-K value: {config['top_k']}")
|
||||
print(f"Top-P value: {config['top_p']}")
|
||||
print(f"Beam size value: {config['beam_size']}")
|
||||
|
||||
try:
|
||||
# Run the decoding script
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
print("Decoding completed successfully!")
|
||||
print(f"Results saved to: {config['output_path']}")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Decoding failed: {e}")
|
||||
print(f"Error Output: {e.stderr}")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print("The decode_model.py file cannot be found. Please make sure the file is in the current directory.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_files():
|
||||
"""Check if necessary files exist"""
|
||||
required_files = [
|
||||
'dmodel_256_decode_model.py',
|
||||
'data.pkl',
|
||||
'test_data.csv'
|
||||
]
|
||||
|
||||
missing_files = []
|
||||
for file in required_files:
|
||||
if not os.path.exists(file):
|
||||
missing_files.append(file)
|
||||
|
||||
if missing_files:
|
||||
print("Missing necessary documents:")
|
||||
for file in missing_files:
|
||||
print(f" - {file}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
print("ProtT5 Encoder-Decoder Model decoding tool")
|
||||
print("=" * 50)
|
||||
|
||||
# 妫€鏌ユ枃浠?
|
||||
if not check_files():
|
||||
print("\nPlease make sure all necessary files are present before running the decoder.")
|
||||
return
|
||||
|
||||
# 杩愯瑙g爜
|
||||
success = run_decoding()
|
||||
|
||||
if success:
|
||||
print("\nDecoding task completed!")
|
||||
print("You can view the decoded results in the output file.")
|
||||
else:
|
||||
print("\nDecoding task failed, please check the error message.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user