Files
labweb/models/dmodel_256_run_decoding.py
2025-12-16 11:39:15 +08:00

137 lines
4.7 KiB
Python

#!/usr/bin/env python3
# coding: utf-8
"""
ProtT5 Encoder-Decoder decoding
Support beam search and top-k/top-p sampling decoding methods
"""
import subprocess
import sys
import os
def run_decoding():
    """Run the decoding script as a subprocess.

    Builds the command line for ``dmodel_256_decode_model.py`` from the
    configuration below, launches it with the interpreter that is running
    this script, and reports the outcome on stdout.

    Returns:
        bool: True if the decoding script exited with status 0; False if the
        script is missing from the current directory, the interpreter could
        not be launched, or the script exited with a non-zero status.
    """
    script = 'dmodel_256_decode_model.py'

    # Configuration parameters - Please modify these paths according to your actual situation
    config = {
        'data_pkl_path': 'data.pkl',  # data.pkl file path
        'model_path': 'best_model.pt',  # The path to the trained model weight file
        'test_data_path': 'test_data.csv',  # Test data CSV file path
        'output_path': './dmodel_256_decoding',  # Output result file path
        'test_genome_pkl_path': 'test_genome_dict.pkl',
        # Model parameters - need to be consistent with training
        'heads': 8,
        'd_model': 256,
        'src_input_dim': 1024,
        'first_linear_dim': 512,
        'num_encoder_layer': 3,
        'num_decoder_layer': 3,
        'dropout': 0.1,
        'compound_max_len': 113,
        # Decoder parameters
        'top_k': [8, 9, 10],  # Top-K values
        'top_p': [0.9],  # Top-P values
        'beam_size': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # Beam size values
        'decode_methods': ['beam_search'],  # Decoding methods
        # Other parameters
        'device': 'cuda:1',  # Device
        'batch_size': 1,  # Batch size
    }

    # Use the interpreter running this script so the child process sees the
    # same environment and installed packages; a bare 'python' may resolve to
    # a different interpreter or not exist at all.
    interpreter = sys.executable or 'python'

    # Build the command; list-valued options are expanded into one
    # argument per value.
    cmd = [
        interpreter, script,
        '--data_pkl_path', config['data_pkl_path'],
        '--model_path', config['model_path'],
        '--test_data_path', config['test_data_path'],
        '--test_genome_pkl_path', config['test_genome_pkl_path'],
        '--output_path', config['output_path'],
        '--heads', str(config['heads']),
        '--d_model', str(config['d_model']),
        '--src_input_dim', str(config['src_input_dim']),
        '--first_linear_dim', str(config['first_linear_dim']),
        '--num_encoder_layer', str(config['num_encoder_layer']),
        '--num_decoder_layer', str(config['num_decoder_layer']),
        '--dropout', str(config['dropout']),
        '--compound_max_len', str(config['compound_max_len']),
        '--top_k', *map(str, config['top_k']),
        '--top_p', *map(str, config['top_p']),
        '--beam_size', *map(str, config['beam_size']),
        '--decode_methods', *config['decode_methods'],
        '--device', config['device'],
        '--batch_size', str(config['batch_size']),
    ]

    print("=== Start decoding ===")
    print(f"Data files: {config['data_pkl_path']}")
    print(f"Model file: {config['model_path']}")
    print(f"Test data: {config['test_data_path']}")
    print(f"Output File: {config['output_path']}")
    print(f"device: {config['device']}")
    print(f"Decoding method: {config['decode_methods']}")
    print(f"Top-K value: {config['top_k']}")
    print(f"Top-P value: {config['top_p']}")
    print(f"Beam size value: {config['beam_size']}")

    # A missing script does not raise FileNotFoundError from subprocess.run
    # (the interpreter starts and then exits non-zero), so check explicitly.
    if not os.path.isfile(script):
        print(f"The {script} file cannot be found. Please make sure the file is in the current directory.")
        return False

    try:
        # Run the decoding script
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Decoding failed: {e}")
        if e.stdout:
            print(f"Standard Output: {e.stdout}")
        print(f"Error Output: {e.stderr}")
        return False
    except FileNotFoundError:
        # Raised when the interpreter executable itself cannot be started.
        print(f"The Python interpreter '{interpreter}' cannot be found.")
        return False

    # Surface whatever the decoder printed instead of discarding it.
    if result.stdout:
        print(result.stdout, end='')
    print("Decoding completed successfully!")
    print(f"Results saved to: {config['output_path']}")
    return True
def check_files(required_files=None):
    """Check if necessary files exist.

    Args:
        required_files: Optional iterable of paths to check. When omitted,
            the decoding script and the data, model-weight, test-data and
            test-genome files named in the decoding configuration are
            checked, all relative to the current directory.

    Returns:
        bool: True if every path exists; otherwise False, after printing
        the list of missing paths.
    """
    if required_files is None:
        required_files = [
            'dmodel_256_decode_model.py',
            'data.pkl',
            'best_model.pt',
            'test_data.csv',
            'test_genome_dict.pkl',
        ]

    missing_files = [path for path in required_files if not os.path.exists(path)]
    if missing_files:
        print("Missing necessary documents:")
        for path in missing_files:
            print(f" - {path}")
        return False
    return True
def main():
    """Check prerequisites, run decoding, and report the outcome.

    Returns:
        int: Process exit status - 0 when decoding succeeded, 1 when the
        required files are missing or the decoding run failed.
    """
    print("ProtT5 Encoder-Decoder Model decoding tool")
    print("=" * 50)

    # Check files
    if not check_files():
        print("\nPlease make sure all necessary files are present before running the decoder.")
        return 1

    # Run decoding
    if not run_decoding():
        print("\nDecoding task failed, please check the error message.")
        return 1

    print("\nDecoding task completed!")
    print("You can view the decoded results in the output file.")
    return 0


if __name__ == "__main__":
    # Propagate the outcome as the exit status so shells and schedulers
    # can detect a failed run.
    sys.exit(main())