137 lines
4.7 KiB
Python
137 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
# coding: utf-8
|
|
|
|
"""
|
|
ProtT5 Encoder-Decoder decoding
|
|
Support beam search and top-k/top-p sampling decoding methods
|
|
"""
|
|
|
|
import subprocess
|
|
import sys
|
|
import os
|
|
|
|
def run_decoding():
    """Run ``dmodel_256_decode_model.py`` in a child process.

    The hard-coded configuration below (paths, model hyper-parameters and
    decoding settings) is converted into command-line arguments for the
    decoding script. The child's output is captured and echoed after it
    finishes.

    Returns:
        bool: ``True`` if the child process exited with status 0; ``False``
        if the script is missing, the interpreter could not be started, or
        the child exited with a non-zero status.
    """
    script_name = 'dmodel_256_decode_model.py'

    # Configuration parameters - please modify these paths to match your setup
    config = {
        'data_pkl_path': 'data.pkl',              # data.pkl file path
        'model_path': 'best_model.pt',            # trained model weight file
        'test_data_path': 'test_data.csv',        # test data CSV file path
        'output_path': './dmodel_256_decoding',   # output result file path
        'test_genome_pkl_path': 'test_genome_dict.pkl',

        # Model parameters - need to be consistent with training
        'heads': 8,
        'd_model': 256,
        'src_input_dim': 1024,
        'first_linear_dim': 512,
        'num_encoder_layer': 3,
        'num_decoder_layer': 3,
        'dropout': 0.1,
        'compound_max_len': 113,

        # Decoder parameters
        'top_k': [8, 9, 10],                                # Top-K values
        'top_p': [0.9],                                     # Top-P values
        'beam_size': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],       # beam sizes
        'decode_methods': ['beam_search'],                  # decoding methods

        # Other parameters
        'device': 'cuda:1',   # device
        'batch_size': 1,      # batch size
    }

    # Options that take a single value, in the order they are passed.
    leading_scalars = (
        'data_pkl_path', 'model_path', 'test_data_path',
        'test_genome_pkl_path', 'output_path',
        'heads', 'd_model', 'src_input_dim', 'first_linear_dim',
        'num_encoder_layer', 'num_decoder_layer', 'dropout',
        'compound_max_len',
    )
    # Options that take one or more values after a single flag.
    multi_valued = ('top_k', 'top_p', 'beam_size', 'decode_methods')
    trailing_scalars = ('device', 'batch_size')

    # Build the command. sys.executable guarantees the child runs under the
    # same interpreter (and therefore the same environment) as this script,
    # instead of whatever "python" happens to be first on PATH.
    cmd = [sys.executable, script_name]
    for key in leading_scalars:
        cmd += [f'--{key}', str(config[key])]
    for key in multi_valued:
        cmd.append(f'--{key}')
        cmd.extend(str(value) for value in config[key])
    for key in trailing_scalars:
        cmd += [f'--{key}', str(config[key])]

    print("=== Start decoding ===")
    print(f"Data files: {config['data_pkl_path']}")
    print(f"Model file: {config['model_path']}")
    print(f"Test data: {config['test_data_path']}")
    print(f"Output File: {config['output_path']}")
    print(f"device: {config['device']}")
    print(f"Decoding method: {config['decode_methods']}")
    print(f"Top-K value: {config['top_k']}")
    print(f"Top-P value: {config['top_p']}")
    print(f"Beam size value: {config['beam_size']}")

    # A missing script does not raise FileNotFoundError from subprocess.run
    # (the interpreter starts and then exits non-zero), so test for it
    # explicitly to give an accurate message.
    if not os.path.isfile(script_name):
        print(f"The {script_name} file cannot be found. "
              "Please make sure the file is in the current directory.")
        return False

    try:
        # Run the decoding script
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Decoding failed: {e}")
        print(f"Error Output: {e.stderr}")
        return False
    except FileNotFoundError:
        # Raised only when the interpreter executable itself cannot be launched.
        print(f"Could not launch the Python interpreter: {sys.executable}")
        return False

    # Surface whatever the decoder reported instead of discarding it.
    if result.stdout:
        print(result.stdout)
    print("Decoding completed successfully!")
    print(f"Results saved to: {config['output_path']}")
    return True
|
|
|
|
def check_files():
    """Check that the files needed for decoding exist in the current directory.

    Every missing path is printed under a short heading.

    Returns:
        bool: ``True`` when all required files exist, ``False`` otherwise.
    """
    required_files = [
        'dmodel_256_decode_model.py',
        'data.pkl',
        'test_data.csv',
        # These two are also handed to the decoding script by run_decoding();
        # checking them here reports the problem before the decoder starts.
        'best_model.pt',
        'test_genome_dict.pkl',
    ]

    missing_files = [path for path in required_files if not os.path.exists(path)]
    if not missing_files:
        return True

    print("Missing necessary documents:")
    for path in missing_files:
        print(f" - {path}")
    return False
|
|
|
|
def main():
    """Check prerequisites, run the decoder and report the outcome.

    Returns:
        int: ``0`` when decoding succeeded, ``1`` when required files are
        missing or the decoding run failed. The value is intended to be
        used as the process exit status.
    """
    print("ProtT5 Encoder-Decoder Model decoding tool")
    print("=" * 50)

    # Check files
    if not check_files():
        print("\nPlease make sure all necessary files are present before running the decoder.")
        return 1

    # Run decoding
    if not run_decoding():
        print("\nDecoding task failed, please check the error message.")
        return 1

    print("\nDecoding task completed!")
    print("You can view the decoded results in the output file.")
    return 0


if __name__ == "__main__":
    # Propagate failure to the shell so calling scripts can detect it.
    sys.exit(main())
|