489 lines
20 KiB
Python
489 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
# coding: utf-8
|
||
|
||
"""
|
||
ProtT5 Encoder-Decoder decoding
|
||
Support beam search and top-k/top-p sampling decoding methods
|
||
"""
|
||
|
||
import argparse
|
||
import pickle
|
||
import torch
|
||
import pandas as pd
|
||
import numpy as np
|
||
from tqdm import tqdm
|
||
from collections import Counter
|
||
import ast
|
||
from torch.cuda.amp import autocast
|
||
import torch.nn.functional as F
|
||
import torch.nn as nn
|
||
from model import Model
|
||
from datasets import DataProcessor, build_single_lazy_dataloader
|
||
import os
|
||
|
||
|
||
def get_key_padding_mask(tokens, pad):
    """Build an additive key-padding mask for *tokens*.

    Args:
        tokens: Integer token tensor of any shape.
        pad: Token id that marks padding positions.

    Returns:
        Float tensor of the same shape as ``tokens`` with ``0`` at real-token
        positions and ``-inf`` at padding positions.
    """
    # Allocate on the same device as the input: the original used a CPU
    # tensor unconditionally, so indexing it with the (possibly CUDA)
    # boolean mask below failed for GPU inputs.
    key_padding_mask = torch.zeros(tokens.size(), device=tokens.device)
    key_padding_mask[tokens == pad] = -float('inf')
    return key_padding_mask
|
||
|
||
|
||
def get_sequencing_mask(tgt, d_model):
    """Build a causal (look-ahead) mask for a target sequence.

    Args:
        tgt: Target token tensor; only its last dimension (the sequence
            length) is used.
        d_model: Unused; kept for backward compatibility with callers.

    Returns:
        An ``(L, L)`` float tensor with ``0`` on and below the diagonal and
        ``-inf`` above it, where ``L = tgt.size(-1)``.
    """
    # Build the square subsequent mask directly instead of constructing a
    # throwaway nn.Transformer on every call just to reach its
    # generate_square_subsequent_mask helper (same values, no model alloc).
    size = tgt.size(-1)
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
|
||
|
||
|
||
def top_k_top_p_filtering(probabilities, top_k=0, top_p=1.0, min_tokens_to_keep=1):
    """
    Filters the probability distribution using top-k and top-p.

    The input is modified in place and also returned (as in the original).

    Args:
        probabilities: Probability distribution tensor; the filtering is
            applied along the last dimension.
        top_k: Keeps the top k tokens with the highest probability
        top_p: Keeps tokens with cumulative probability reaching top_p
        min_tokens_to_keep: Ensures that at least this many tokens are kept

    Returns:
        Filtered probability distribution (not renormalized; callers such as
        torch.multinomial accept unnormalized weights).
    """
    # Top-K filter: zero everything below the k-th largest probability.
    if top_k > 0:
        top_k = min(top_k, probabilities.size(-1))
        values, indices = torch.topk(probabilities, top_k)
        min_threshold = values[..., -1, None]
        probabilities[probabilities < min_threshold] = 0

    # Top-P (nucleus) filter
    if top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0

        # Map the removal decisions from sorted order back to vocabulary
        # order with scatter. The original used boolean advanced indexing
        # (probabilities[sorted_indices[mask]] = 0), which is only correct
        # for 1-D tensors; scatter preserves that behavior and also works
        # on batched inputs.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        probabilities[indices_to_remove] = 0

    return probabilities
|
||
|
||
|
||
class Generator_topk(nn.Module):
    """Projects decoder hidden states to a probability distribution over
    the target vocabulary (linear projection followed by softmax)."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return F.softmax(logits, dim=-1)
|
||
|
||
|
||
class TopKtopPGenerator:
    """Autoregressive decoder that samples the next token with Top-K /
    Top-P filtering and a repetition penalty."""

    def __init__(self, model, d_model, generator, device, tgt_mask_index,
                 repetition_penalty=1.3, start_symbol=None, max_length=98, top_k=10, top_p=0.9, eos_token=None):
        self.model = model
        self.d_model = d_model
        self.generator = generator
        self.device = device
        self.tgt_mask_index = tgt_mask_index
        self.repetition_penalty = repetition_penalty
        self.start_symbol = start_symbol
        self.max_length = max_length
        self.top_k = top_k
        self.top_p = top_p
        self.eos_token = eos_token

    def _model_decoder(self, ys, memory, d_model):
        """Run one decoder pass over the partial sequence *ys* and return
        the vocabulary distribution for every position."""
        # Boolean masks: True marks positions to ignore (padding / future).
        pad_mask = ~(get_key_padding_mask(ys, self.tgt_mask_index) != float('-inf')).to(self.device)
        causal_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        embedded = self.model.tgt_pos(self.model.tgt_embedding(ys))

        hidden = self.model.transformer.decoder(
            tgt=embedded, tgt_mask=causal_mask, tgt_key_padding_mask=pad_mask, memory=memory
        )
        return self.generator(hidden)

    def generate(self, src, src_pad_mask):
        """Sample one output sequence (token-id tensor of shape (1, L))
        for a single source example using Top-K / Top-P sampling."""
        self.model.eval()
        with torch.no_grad():
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)

            # Project the source features into the transformer's model dim.
            hidden = self.model.src_linear1(src)
            hidden = self.model.src_activation(hidden)
            hidden = self.model.src_dropout(hidden)
            hidden = self.model.src_linear2(hidden)

            memory = self.model.transformer.encoder(src=hidden, src_key_padding_mask=src_pad_mask)

            # The sequence starts with just the start symbol.
            ys = torch.full((1, 1), self.start_symbol, dtype=torch.long, device=self.device)
            history = []

            for _ in range(self.max_length):
                step_probs = self._model_decoder(ys, memory, self.d_model)[:, -1, :]

                # Penalize tokens that were already generated.
                for prev_id in history:
                    step_probs[0, prev_id] /= self.repetition_penalty

                # Restrict the distribution with Top-K / Top-P filtering.
                filtered = top_k_top_p_filtering(step_probs.squeeze(0), top_k=self.top_k, top_p=self.top_p)

                # Draw the next token from the filtered distribution.
                next_id = torch.multinomial(filtered, num_samples=1).item()

                ys = torch.cat(
                    [ys, torch.full((1, 1), next_id, dtype=torch.long, device=self.device)], dim=1
                )
                history.append(next_id)

                # Stop as soon as the end symbol is produced.
                if self.eos_token is not None and next_id == self.eos_token:
                    break

            return ys
|
||
|
||
|
||
class Beam_search(nn.Module):
    """Beam search decoder for the encoder-decoder model.

    Args:
        model: Trained encoder-decoder model.
        trg_pad_idx: Target padding token id.
        trg_bos_idx: Target begin-of-sequence token id.
        trg_eos_idx: Target end-of-sequence token id.
        device: Torch device to run on.
        d_model: Model (embedding) dimension.
        beam_size: Number of beams kept per step.
        max_length: Maximum decoded sequence length.
        alpha: Length-normalization exponent used to rank finished beams.
    """

    def __init__(self, model, trg_pad_idx, trg_bos_idx, trg_eos_idx, device,
                 d_model=256, beam_size=3, max_length=98, alpha=0.7):
        super(Beam_search, self).__init__()
        self.model = model.to(device)
        self.trg_pad_idx = trg_pad_idx
        self.trg_bos_idx = trg_bos_idx
        self.trg_eos_idx = trg_eos_idx
        self.device = device
        self.d_model = d_model
        self.beam_size = beam_size
        self.max_length = max_length
        self.alpha = alpha

        # Template sequences: every beam starts with BOS followed by padding.
        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]).to(device))
        self.register_buffer(
            'blank_seqs',
            torch.full((beam_size, max_length), trg_pad_idx, dtype=torch.long).to(device)
        )
        self.blank_seqs[:, 0] = self.trg_bos_idx

        # Position map used to measure each beam's length up to its first EOS.
        self.register_buffer(
            'len_map',
            torch.arange(1, max_length + 1, dtype=torch.long).unsqueeze(0).to(device)
        )

    def frequency_penalty(self, frequency, alpha):
        """Return ``alpha * frequency`` for tokens seen more than once, else 0."""
        if frequency > 1:
            penalty = alpha * frequency
        else:
            penalty = 0
        return penalty

    def _model_decoder(self, model, ys, memory, d_model):
        """Decode the partial sequences *ys* against encoder *memory* and
        return per-position vocabulary scores."""
        # Boolean masks: True marks positions to ignore (padding / future).
        ys_pad_mask = ~(get_key_padding_mask(ys, self.trg_pad_idx) != float('-inf')).to(self.device)
        ys_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        ys = model.tgt_pos(model.tgt_embedding(ys))

        decoder_output = model.transformer.decoder(
            tgt=ys, tgt_mask=ys_mask, tgt_key_padding_mask=ys_pad_mask, memory=memory
        )
        out = model.generator(decoder_output)
        return out

    def get_init_state(self, beam_size, model, ys, memory):
        """Expand the first decoding step into *beam_size* candidate beams."""
        out = self._model_decoder(model=model, ys=ys, memory=memory, d_model=self.d_model)
        best_k_probs, best_k_idx = out[:, -1, :].topk(beam_size)
        scores = best_k_probs.view(beam_size)
        gen_seq = self.blank_seqs.clone().detach()
        gen_seq[:, 1] = best_k_idx[0]

        return gen_seq, scores

    def get_best_score_and_idx(self, beam_size, dec_output, scores, gen_seq, step):
        """Select the best *beam_size* continuations across all beams."""
        dec_output = dec_output[:, -1, :]
        # Apply a frequency penalty so beams do not keep repeating tokens.
        for beam in range(dec_output.size(0)):
            count = Counter(gen_seq[:, :step][beam].tolist())
            for com in count:
                penalty = self.frequency_penalty(count[com], 0.7)
                dec_output[beam][com] -= penalty

        # Combine each beam's running score with its k best continuations,
        # then keep the overall best k of the k*k candidates.
        best_k2_probs, best_k2_idx = dec_output.topk(beam_size)
        scores = best_k2_probs.view(beam_size, -1) + scores.view(beam_size, 1)
        scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
        best_k_r_idxs = torch.div(best_k_idx_in_k2, beam_size, rounding_mode='trunc')
        best_k_c_idxs = best_k_idx_in_k2 % beam_size
        best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
        gen_seq[:, :step] = gen_seq[best_k_r_idxs, :step]
        gen_seq[:, step] = best_k_idx

        return gen_seq, scores

    def beam_search(self, src, src_pad_mask):
        """Run beam search for a single source example and return the best
        token-id sequence as a Python list, truncated at its EOS."""
        self.model.eval()

        trg_bos_idx, trg_eos_idx = self.trg_bos_idx, self.trg_eos_idx
        max_seq_len, beam_size, alpha = self.max_length, self.beam_size, self.alpha

        with torch.no_grad():
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)

            # Project the source features into the transformer's model dim.
            src_transformed = self.model.src_linear1(src)
            src_transformed = self.model.src_activation(src_transformed)
            src_transformed = self.model.src_dropout(src_transformed)
            src_transformed = self.model.src_linear2(src_transformed)

            memory = self.model.transformer.encoder(src=src_transformed, src_key_padding_mask=src_pad_mask)

            # First step: expand the BOS symbol into beam_size candidates.
            ys = torch.ones(1, 1).fill_(trg_bos_idx).type_as(src.data).long()
            gen_seq, scores = self.get_init_state(beam_size=beam_size, model=self.model, ys=ys, memory=memory)

            # Replicate the source beam_size times and re-encode so every
            # beam decodes against its own copy of the memory.
            src_expanded = src.repeat(beam_size, 1, 1)
            src_expanded_pad_mask = src_pad_mask.repeat(beam_size, 1)
            src_expanded_pad_mask = src_expanded_pad_mask.to(self.device)

            src_expanded_transformed = self.model.src_linear1(src_expanded)
            src_expanded_transformed = self.model.src_activation(src_expanded_transformed)
            src_expanded_transformed = self.model.src_dropout(src_expanded_transformed)
            src_expanded_transformed = self.model.src_linear2(src_expanded_transformed)

            memory_expanded = self.model.transformer.encoder(
                src=src_expanded_transformed, src_key_padding_mask=src_expanded_pad_mask
            )

            # Fall back to beam 0 at full length if not every beam produces
            # an EOS within max_seq_len (the original left ans_idx undefined
            # on that path and raised NameError at the return).
            ans_idx = 0
            seq_lens = torch.full((beam_size,), max_seq_len, dtype=torch.long, device=self.device)

            for step in range(2, max_seq_len):
                dec_output = self._model_decoder(
                    model=self.model, ys=gen_seq[:, :step], memory=memory_expanded, d_model=self.d_model
                )

                gen_seq, scores = self.get_best_score_and_idx(
                    beam_size=beam_size, dec_output=dec_output, scores=scores, step=step, gen_seq=gen_seq
                )

                # Per-beam length up to (and including) the first EOS.
                eos_locs = gen_seq == trg_eos_idx
                seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)

                # Stop once every beam contains an EOS; pick the beam with
                # the best length-normalized score.
                if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
                    _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
                    ans_idx = ans_idx.item()
                    break

            # Return the selected beam (the original returned gen_seq[0]
            # regardless of which beam won, while truncating with the
            # winner's length — an inconsistent pairing).
            return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
|
||
|
||
|
||
def keys_valus(tgt_dict, key):
    """Reverse dictionary lookup: return every key of *tgt_dict* whose
    value equals *key* (possibly an empty list)."""
    return [name for name, idx in tgt_dict.items() if idx == key]
|
||
|
||
|
||
def medium(b, tgt_dict):
    """Convert a sequence of token ids into compound names.

    Each id in *b* is looked up in *tgt_dict* by value (reverse lookup);
    the special tokens '</s>', '<s>' and 'blank' are dropped from the
    result.
    """
    names = []
    for token_id in b:
        names.extend(keys_valus(tgt_dict, token_id))
    return [name for name in names if name not in ('</s>', '<s>', 'blank')]
|
||
|
||
|
||
def load_model(model_path, device, model_config, compound_cab):
    """Construct the Model from *model_config*, load its weights from
    *model_path*, move it to *device*, and return it in eval mode."""
    print("=== Load model ===")

    model = Model(
        heads=model_config['heads'],
        d_model=model_config['d_model'],
        tgt_vocab_size=len(compound_cab),
        num_encoder_layer=model_config['num_encoder_layer'],
        num_decoder_layer=model_config['num_decoder_layer'],
        dropout=model_config['dropout'],
        compound_cab=compound_cab,
        src_input_dim=model_config['src_input_dim'],
        first_linear_dim=model_config['first_linear_dim']
    ).to(device)

    # NOTE(review): torch.load unpickles the checkpoint — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    print(f"Model loaded: {model_path}")
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count}")

    return model
|
||
|
||
|
||
def evaluate_model(model, test_loader, compound_cab, device, model_config,
                   top_k_values, top_p_values, beam_size_values, decode_methods, output_path):
    """Evaluate the model and generate results with different decoding strategies, return combined final_df.

    Args:
        model: Trained encoder-decoder model.
        test_loader: Dataloader yielding batches ``(src, src_pad_mask, genome)``
            (batch size 1 is assumed — presumably; confirm against the caller).
        compound_cab: Target vocabulary mapping token name -> id.
        device: Torch device to run on.
        model_config: Dict of model hyper-parameters.
        top_k_values: Top-K values to sweep for sampling decoding.
        top_p_values: Top-P values to sweep for sampling decoding.
        beam_size_values: Beam sizes to sweep for beam-search decoding.
        decode_methods: Subset of {'beam_search', 'top_k_top_p'}.
        output_path: Directory for per-strategy CSVs (saving currently disabled).

    Returns:
        DataFrame with a 'genome' column and one column per decoding
        configuration, merged on 'genome'.
    """
    print("=== Start decoding ===")
    combined_df = pd.DataFrame()

    # Beam-search decoding
    if "beam_search" in decode_methods:
        for beam_size in beam_size_values:
            results = {'genome': []}
            column_name = f"beam_{beam_size}"
            results[column_name] = []

            # NOTE(review): trg_bos_idx is taken from '</s>' and trg_eos_idx
            # from '<s>' — this looks swapped relative to the conventional
            # meaning of those tokens; confirm against the vocabulary build.
            beamsearch = Beam_search(
                model=model,
                trg_pad_idx=compound_cab['blank'],
                trg_bos_idx=compound_cab['</s>'],
                trg_eos_idx=compound_cab['<s>'],
                device=device,
                d_model=model_config['d_model'],
                beam_size=beam_size,
                max_length=model_config['compound_max_len'],
                alpha=0.7
            )

            for b in tqdm(test_loader, desc=f"Beam {beam_size} decoding"):
                results['genome'].append(b[2][0])
                src_pad_mask = b[1]
                # beam_search returns a plain list of token ids.
                generate = beamsearch.beam_search(b[0], src_pad_mask)
                pred_medium = medium(generate, compound_cab)
                results[column_name].append(pred_medium)

            beam_df = pd.DataFrame.from_dict(results, orient="index").transpose()
            if combined_df.empty:
                combined_df = beam_df
            else:
                combined_df = pd.merge(combined_df, beam_df, on='genome', how='outer')

            # output_file_name = os.path.join(output_path, f"beam_{beam_size}.csv")
            # beam_df.to_csv(output_file_name, index=False)
            torch.cuda.empty_cache()

    # Top-k / top-p sampling decoding
    if "top_k_top_p" in decode_methods:
        for top_k in top_k_values:
            for top_p in top_p_values:
                results = {'genome': []}
                column_name = f"top_{top_k}_top_{top_p}"
                results[column_name] = []

                # NOTE(review): this Generator_topk is freshly initialized
                # (random weights) rather than loaded from the checkpoint —
                # confirm whether the trained output projection (e.g.
                # model.generator) should be used here instead.
                generator = Generator_topk(model_config['d_model'], len(compound_cab)).to(device)
                topKtopP = TopKtopPGenerator(
                    model=model,
                    d_model=model_config['d_model'],
                    generator=generator,
                    device=device,
                    tgt_mask_index=compound_cab['blank'],
                    repetition_penalty=1.3,
                    start_symbol=compound_cab['</s>'],
                    eos_token=compound_cab['<s>']
                )

                for b in tqdm(test_loader, desc=f"Top-k {top_k} Top-p {top_p} decoding"):
                    # Keep the genome identifier (aligned with beam search).
                    results['genome'].append(b[2][0])
                    src_pad_mask = b[1]
                    generate = topKtopP.generate(b[0], src_pad_mask)
                    # generate() returns a (1, L) tensor; flatten it to a
                    # list of ids so medium() looks each token up
                    # individually. The original passed the 2-D tensor,
                    # which made keys_valus compare a dict value against a
                    # whole row tensor and raise at runtime.
                    pred_medium = medium(generate.squeeze(0).tolist(), compound_cab)
                    results[column_name].append(pred_medium)

                topk_df = pd.DataFrame.from_dict(results, orient="index").transpose()
                if combined_df.empty:
                    combined_df = topk_df
                else:
                    combined_df = pd.merge(combined_df, topk_df, on='genome', how='outer')

                # Saving of per-configuration results is currently disabled.
                # output_file_name = os.path.join(output_path, f"top_{top_k}_top_{top_p}.csv")
                # topk_df.to_csv(output_file_name, index=False)
                torch.cuda.empty_cache()

    return combined_df
|
||
|
||
|
||
def main():
    """Parse CLI arguments, load data and model, and run decoding."""
    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder model decoder')

    # Data path parameters
    parser.add_argument('--data_pkl_path', type=str, required=True, help='data.pkl file path')
    parser.add_argument('--model_path', type=str, required=True, help='Model weight file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--output_path', type=str, required=True, help='Output result CSV file dir')
    parser.add_argument('--test_genome_pkl_path', type=str, required=True, help='test_genome_dict.pkl file path')

    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=64, help='Model Dimensions')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input Dimension')
    parser.add_argument('--first_linear_dim', type=int, default=256, help='The first linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum length of nutrient sequence')

    # Decoder parameters
    parser.add_argument('--top_k', nargs='+', type=int, default=[8, 9, 10], help='Top-K value list')
    parser.add_argument('--top_p', nargs='+', type=float, default=[0.9], help='Top-P value list')
    parser.add_argument('--beam_size', nargs='+', type=int, default=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], help='Beam size value list')
    parser.add_argument('--decode_methods', nargs='+', choices=['top_k_top_p', 'beam_search'],
                        default=['top_k_top_p', 'beam_search'], help='Decoding method')

    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')

    args = parser.parse_args()

    # Setting the device
    device = torch.device(args.device)
    print(f"Using the device: {device}")

    # Load data
    print("=== Load data ===")

    # NOTE(review): the dataloader / vocabulary construction is missing —
    # `test_loader` and `compound_cab` are referenced below but never
    # defined. Presumably they are built from `args.data_pkl_path` /
    # `args.test_data_path` via DataProcessor / build_single_lazy_dataloader;
    # restore that setup before running this script.
    print(f"Number of test set batches: {len(test_loader)}")

    # genome dict (currently unused by evaluate_model; kept for parity with
    # the original script)
    with open(args.test_genome_pkl_path, 'rb') as t:
        genome_cab = pickle.load(t)

    # Model Configuration
    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'compound_max_len': args.compound_max_len
    }

    # Load model
    model = load_model(args.model_path, device, model_config, compound_cab)

    # Generate. The original call passed `genome_cab` as an extra positional
    # argument (11 args against evaluate_model's 10-parameter signature),
    # which raised TypeError; it is dropped here to match the signature.
    combined_df = evaluate_model(
        model, test_loader, compound_cab, device, model_config,
        args.top_k, args.top_p, args.beam_size, args.decode_methods, args.output_path
    )

    print("Decoding complete!")


if __name__ == "__main__":
    main()
|