Files
labweb/models/dmodel_256_decode_model.py
2025-12-16 11:39:15 +08:00

489 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# coding: utf-8
"""
ProtT5 Encoder-Decoder decoding
Support beam search and top-k/top-p sampling decoding methods
"""
import argparse
import pickle
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import ast
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.nn as nn
from model import Model
from datasets import DataProcessor, build_single_lazy_dataloader
import os
def get_key_padding_mask(tokens, pad):
    """Build an additive key-padding mask: -inf at pad positions, 0 elsewhere.

    Args:
        tokens: Integer token tensor of any shape.
        pad: Padding token id.

    Returns:
        Float tensor of the same shape as `tokens`.
    """
    mask = torch.zeros(tokens.size())
    return mask.masked_fill(tokens == pad, -float('inf'))
def get_sequencing_mask(tgt, d_model):
    """Generate a causal (square subsequent) mask for the target sequence.

    Args:
        tgt: Target token tensor; only its last dimension (sequence length) is used.
        d_model: Unused; kept for backward compatibility with existing callers.

    Returns:
        A (seq_len, seq_len) float mask with -inf above the diagonal, 0 elsewhere.
    """
    # generate_square_subsequent_mask is a static method, so there is no need to
    # build a full (default 6+6 layer) nn.Transformer instance on every call.
    return nn.Transformer.generate_square_subsequent_mask(tgt.size(-1))
def top_k_top_p_filtering(probabilities, top_k=0, top_p=1.0, min_tokens_to_keep=1):
    """
    Filters the probability distribution using top-k and top-p (nucleus) filtering.

    Filtering is done in place: disallowed entries are zeroed, so the result is
    generally unnormalized (torch.multinomial accepts unnormalized weights).

    Args:
        probabilities: Probability tensor; the last dimension is the vocabulary.
        top_k: Keeps the top k tokens with the highest probability (0 disables).
        top_p: Keeps the tokens whose cumulative probability reaches top_p (1.0 disables).
        min_tokens_to_keep: Ensures that at least this many tokens are kept.

    Returns:
        Filtered probability distribution (same tensor object as the input).
    """
    # Top-K filter
    if top_k > 0:
        # Honor min_tokens_to_keep here too, as the docstring promises.
        top_k = min(max(top_k, min_tokens_to_keep), probabilities.size(-1))
        kth_value = torch.topk(probabilities, top_k)[0][..., -1, None]
        probabilities[probabilities < kth_value] = 0
    # Top-P (nucleus) filter
    if top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
        # Scatter the removal mask back to the original (unsorted) index order.
        # The previous advanced-indexing form was only correct for 1-D input;
        # this works for batched distributions as well.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        probabilities[indices_to_remove] = 0
    return probabilities
class Generator_topk(nn.Module):
    """Projects decoder hidden states to a softmax distribution over the vocabulary."""

    def __init__(self, d_model, vocab):
        super().__init__()
        # Final linear projection: model dimension -> vocabulary size.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return logits.softmax(dim=-1)
class TopKtopPGenerator:
    """Autoregressive decoder that samples the next token with top-k / top-p filtering."""

    def __init__(self, model, d_model, generator, device, tgt_mask_index,
                 repetition_penalty=1.3, start_symbol=None, max_length=98, top_k=10, top_p=0.9, eos_token=None):
        self.model = model                      # encoder-decoder model
        self.d_model = d_model                  # model hidden size
        self.generator = generator              # projection to vocabulary probabilities
        self.device = device
        self.tgt_mask_index = tgt_mask_index    # padding token id for target masks
        self.repetition_penalty = repetition_penalty
        self.start_symbol = start_symbol        # BOS token id
        self.max_length = max_length            # decoding step budget
        self.top_k = top_k
        self.top_p = top_p
        self.eos_token = eos_token              # optional stop token id

    def _model_decoder(self, ys, memory, d_model):
        """Run the transformer decoder on prefix `ys` and return vocab probabilities."""
        pad_mask = ~(get_key_padding_mask(ys, self.tgt_mask_index) != float('-inf')).to(self.device)
        causal_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        embedded = self.model.tgt_pos(self.model.tgt_embedding(ys))
        hidden = self.model.transformer.decoder(
            tgt=embedded, tgt_mask=causal_mask, tgt_key_padding_mask=pad_mask, memory=memory
        )
        return self.generator(hidden)

    def generate(self, src, src_pad_mask):
        """Sample one sequence for `src` with top-k/top-p filtering; returns token ids (1, L)."""
        self.model.eval()
        with torch.no_grad():
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)
            # Project source features into the model dimension before encoding.
            projected = self.model.src_linear2(
                self.model.src_dropout(
                    self.model.src_activation(
                        self.model.src_linear1(src))))
            memory = self.model.transformer.encoder(src=projected, src_key_padding_mask=src_pad_mask)
            # Seed the decode with the start symbol.
            ys = torch.full((1, 1), self.start_symbol, dtype=torch.long, device=self.device)
            history = []
            for _ in range(self.max_length):
                step_probs = self._model_decoder(ys, memory, self.d_model)[:, -1, :]
                # Divide down probabilities of already-emitted tokens; a token
                # emitted k times is penalized k times (matches prior behavior).
                for prev in history:
                    step_probs[0, prev] /= self.repetition_penalty
                filtered = top_k_top_p_filtering(step_probs.squeeze(0), top_k=self.top_k, top_p=self.top_p)
                nxt = torch.multinomial(filtered, num_samples=1).item()
                ys = torch.cat(
                    [ys, torch.full((1, 1), nxt, dtype=torch.long, device=self.device)], dim=1
                )
                history.append(nxt)
                # Stop as soon as the end symbol is produced.
                if self.eos_token is not None and nxt == self.eos_token:
                    break
            return ys
class Beam_search(nn.Module):
    """Beam-search decoder with a token-frequency penalty and length-normalized
    final beam selection.

    Fixes over the previous version:
      * ``ans_idx`` / ``seq_lens`` are initialized before the decode loop, so the
        search no longer raises NameError when not every beam emits EOS within
        ``max_length`` steps.
      * The returned sequence is taken from the selected beam ``ans_idx`` instead
        of always beam 0 (which mixed one beam's tokens with another beam's length).

    Args:
        model: Encoder-decoder model exposing src_linear1/2, src_activation,
            src_dropout, tgt_embedding, tgt_pos, transformer and generator.
        trg_pad_idx: Target padding token id.
        trg_bos_idx: Target begin-of-sequence token id.
        trg_eos_idx: Target end-of-sequence token id.
        device: Torch device.
        d_model: Model hidden size.
        beam_size: Number of beams.
        max_length: Maximum decode length.
        alpha: Length-normalization exponent.
    """

    def __init__(self, model, trg_pad_idx, trg_bos_idx, trg_eos_idx, device,
                 d_model=256, beam_size=3, max_length=98, alpha=0.7):
        super(Beam_search, self).__init__()
        self.model = model.to(device)
        self.trg_pad_idx = trg_pad_idx
        self.trg_bos_idx = trg_bos_idx
        self.trg_eos_idx = trg_eos_idx
        self.device = device
        self.d_model = d_model
        self.beam_size = beam_size
        self.max_length = max_length
        self.alpha = alpha
        # (1, 1) seed sequence holding only the BOS token.
        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]).to(device))
        # (beam_size, max_length) template pre-filled with padding; column 0 = BOS.
        self.register_buffer(
            'blank_seqs',
            torch.full((beam_size, max_length), trg_pad_idx, dtype=torch.long).to(device)
        )
        self.blank_seqs[:, 0] = self.trg_bos_idx
        # (1, max_length) positions 1..max_length, used to locate the first EOS.
        self.register_buffer(
            'len_map',
            torch.arange(1, max_length + 1, dtype=torch.long).unsqueeze(0).to(device)
        )

    def frequency_penalty(self, frequency, alpha):
        """Linear penalty applied once a token has occurred more than once."""
        if frequency > 1:
            penalty = alpha * frequency
        else:
            penalty = 0
        return penalty

    def _model_decoder(self, model, ys, memory, d_model):
        """Run decoder + generator on prefix `ys`; returns per-position vocab scores."""
        ys_pad_mask = ~(get_key_padding_mask(ys, self.trg_pad_idx) != float('-inf')).to(self.device)
        ys_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        ys = model.tgt_pos(model.tgt_embedding(ys))
        decoder_output = model.transformer.decoder(
            tgt=ys, tgt_mask=ys_mask, tgt_key_padding_mask=ys_pad_mask, memory=memory
        )
        out = model.generator(decoder_output)
        return out

    def get_init_state(self, beam_size, model, ys, memory):
        """First decode step: expand BOS into the top `beam_size` hypotheses."""
        out = self._model_decoder(model=model, ys=ys, memory=memory, d_model=self.d_model)
        best_k_probs, best_k_idx = out[:, -1, :].topk(beam_size)
        scores = best_k_probs.view(beam_size)
        gen_seq = self.blank_seqs.clone().detach()
        gen_seq[:, 1] = best_k_idx[0]
        return gen_seq, scores

    def get_best_score_and_idx(self, beam_size, dec_output, scores, gen_seq, step):
        """Score next-token candidates of every beam and keep the best beam_size."""
        dec_output = dec_output[:, -1, :]
        # Subtract a frequency penalty for tokens each beam has already emitted.
        for beam in range(dec_output.size(0)):
            count = Counter(gen_seq[:, :step][beam].tolist())
            for com in count:
                penalty = self.frequency_penalty(count[com], 0.7)  # fixed penalty weight
                dec_output[beam][com] -= penalty
        # beam_size candidates per beam -> beam_size^2 combined scores.
        best_k2_probs, best_k2_idx = dec_output.topk(beam_size)
        scores = best_k2_probs.view(beam_size, -1) + scores.view(beam_size, 1)
        scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
        # Decode flat indices back into (source beam, candidate) coordinates.
        best_k_r_idxs = torch.div(best_k_idx_in_k2, beam_size, rounding_mode='trunc')
        best_k_c_idxs = best_k_idx_in_k2 % beam_size
        best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
        gen_seq[:, :step] = gen_seq[best_k_r_idxs, :step]
        gen_seq[:, step] = best_k_idx
        return gen_seq, scores

    def beam_search(self, src, src_pad_mask):
        """Decode a single source sequence; returns the chosen beam as a token-id list."""
        self.model.eval()
        trg_bos_idx, trg_eos_idx = self.trg_bos_idx, self.trg_eos_idx
        max_seq_len, beam_size, alpha = self.max_length, self.beam_size, self.alpha
        with torch.no_grad():
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)
            # Project source features to the model dimension, then encode (beam of 1).
            src_transformed = self.model.src_linear1(src)
            src_transformed = self.model.src_activation(src_transformed)
            src_transformed = self.model.src_dropout(src_transformed)
            src_transformed = self.model.src_linear2(src_transformed)
            memory = self.model.transformer.encoder(src=src_transformed, src_key_padding_mask=src_pad_mask)
            # First decode step from BOS gives the initial beam_size hypotheses.
            ys = torch.ones(1, 1).fill_(trg_bos_idx).type_as(src.data).long()
            gen_seq, scores = self.get_init_state(beam_size=beam_size, model=self.model, ys=ys, memory=memory)
            # Re-encode the source replicated across the beam dimension so every
            # beam attends to its own copy of the memory.
            src_expanded = src.repeat(beam_size, 1, 1)
            src_expanded_pad_mask = src_pad_mask.repeat(beam_size, 1)
            src_expanded_pad_mask = src_expanded_pad_mask.to(self.device)
            src_expanded_transformed = self.model.src_linear1(src_expanded)
            src_expanded_transformed = self.model.src_activation(src_expanded_transformed)
            src_expanded_transformed = self.model.src_dropout(src_expanded_transformed)
            src_expanded_transformed = self.model.src_linear2(src_expanded_transformed)
            memory_expanded = self.model.transformer.encoder(
                src=src_expanded_transformed, src_key_padding_mask=src_expanded_pad_mask
            )
            # Defaults in case the loop finishes without every beam reaching EOS
            # (previously this path raised NameError on ans_idx / seq_lens).
            ans_idx = 0
            seq_lens = torch.full((beam_size,), max_seq_len, dtype=torch.long, device=gen_seq.device)
            for step in range(2, max_seq_len):
                dec_output = self._model_decoder(
                    model=self.model, ys=gen_seq[:, :step], memory=memory_expanded, d_model=self.d_model
                )
                gen_seq, scores = self.get_best_score_and_idx(
                    beam_size=beam_size, dec_output=dec_output, scores=scores, step=step, gen_seq=gen_seq
                )
                # Per-beam length = position of the first EOS (max_seq_len if none).
                eos_locs = gen_seq == trg_eos_idx
                seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)
                # Stop once every beam contains an EOS; pick the winner by
                # length-normalized score.
                if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
                    _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
                    ans_idx = ans_idx.item()
                    break
            else:
                # Budget exhausted without all beams finishing: still pick the
                # best length-normalized beam instead of crashing.
                _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
                ans_idx = ans_idx.item()
        # Return the selected beam (was gen_seq[0], which mixed beams when ans_idx != 0).
        return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
def keys_valus(tgt_dict, key):
    """Reverse lookup: return every key in `tgt_dict` whose value equals `key`."""
    matches = []
    for name, idx in tgt_dict.items():
        if idx == key:
            matches.append(name)
    return matches
def medium(b, tgt_dict):
    """Map a sequence of token ids back to compound names, dropping special tokens.

    Args:
        b: Iterable of token ids.
        tgt_dict: Vocabulary dict mapping name -> id.

    Returns:
        List of names with '</s>', '<s>' and 'blank' removed.
    """
    specials = ('</s>', '<s>', 'blank')
    names = []
    for token in b:
        # Reverse-lookup every vocabulary entry mapped to this id (inlined helper).
        names.extend(name for name, idx in tgt_dict.items() if idx == token)
    return [name for name in names if name not in specials]
def load_model(model_path, device, model_config, compound_cab):
    """Instantiate Model from `model_config`, load weights from `model_path`,
    move it to `device` and switch to eval mode."""
    print("=== Load model ===")
    net = Model(
        heads=model_config['heads'],
        d_model=model_config['d_model'],
        tgt_vocab_size=len(compound_cab),
        num_encoder_layer=model_config['num_encoder_layer'],
        num_decoder_layer=model_config['num_decoder_layer'],
        dropout=model_config['dropout'],
        compound_cab=compound_cab,
        src_input_dim=model_config['src_input_dim'],
        first_linear_dim=model_config['first_linear_dim'],
    ).to(device)
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    print(f"Model loaded: {model_path}")
    print(f"Model parameters: {sum(p.numel() for p in net.parameters())}")
    return net
def evaluate_model(model, test_loader, compound_cab, device, model_config,
                   top_k_values, top_p_values, beam_size_values, decode_methods, output_path):
    """Evaluate the model with different decoding strategies; return combined DataFrame.

    Args:
        model: Trained encoder-decoder model.
        test_loader: Iterable yielding batches (src, src_pad_mask, genome_ids);
            assumed batch size 1 — TODO confirm against the loader.
        compound_cab: Vocabulary dict mapping token string -> id.
        device: Torch device for decoding.
        model_config: Dict with at least 'd_model' and 'compound_max_len'.
        top_k_values, top_p_values: Grids of sampling parameters.
        beam_size_values: Beam widths to evaluate.
        decode_methods: Subset of {'beam_search', 'top_k_top_p'}.
        output_path: Output directory (per-configuration CSV saves are commented out).

    Returns:
        DataFrame with one 'genome' column plus one column per decoding configuration.
    """
    print("=== Start decoding ===")
    combined_df = pd.DataFrame()
    # Beam-search decoding
    if "beam_search" in decode_methods:
        for beam_size in beam_size_values:
            results = {'genome': []}
            column_name = f"beam_{beam_size}"
            results[column_name] = []
            # NOTE(review): '</s>' is used as BOS and '<s>' as EOS — looks swapped
            # relative to the usual convention; presumably it matches how the
            # vocabulary was built. Confirm against the training pipeline.
            beamsearch = Beam_search(
                model=model,
                trg_pad_idx=compound_cab['blank'],
                trg_bos_idx=compound_cab['</s>'],
                trg_eos_idx=compound_cab['<s>'],
                device=device,
                d_model=model_config['d_model'],
                beam_size=beam_size,
                max_length=model_config['compound_max_len'],
                alpha=0.7
            )
            for b in tqdm(test_loader, desc=f"Beam {beam_size} decoding"):
                results['genome'].append(b[2][0])
                src_pad_mask = b[1]
                generate = beamsearch.beam_search(b[0], src_pad_mask)
                pred_medium = medium(generate, compound_cab)
                results[column_name].append(pred_medium)
            beam_df = pd.DataFrame.from_dict(results, orient="index").transpose()
            if combined_df.empty:
                combined_df = beam_df
            else:
                combined_df = pd.merge(combined_df, beam_df, on='genome', how='outer')
            # output_file_name = os.path.join(output_path, f"beam_{beam_size}.csv")
            # beam_df.to_csv(output_file_name, index=False)
            torch.cuda.empty_cache()
    # Top-k / top-p sampling decoding
    if "top_k_top_p" in decode_methods:
        for top_k in top_k_values:
            for top_p in top_p_values:
                results = {'genome': []}
                column_name = f"top_{top_k}_top_{top_p}"
                results[column_name] = []
                # NOTE(review): this Generator_topk is freshly initialized with random
                # weights and never loaded from the trained checkpoint — it presumably
                # should reuse the model's trained output head. Verify before trusting
                # the top-k/top-p results.
                generator = Generator_topk(model_config['d_model'], len(compound_cab)).to(device)
                topKtopP = TopKtopPGenerator(
                    model=model,
                    d_model=model_config['d_model'],
                    generator=generator,
                    device=device,
                    tgt_mask_index=compound_cab['blank'],
                    repetition_penalty=1.3,
                    start_symbol=compound_cab['</s>'],
                    # Fixed: the loop's top_k/top_p and the configured max length were
                    # never forwarded, so every column actually used the defaults
                    # (top_k=10, top_p=0.9, max_length=98) despite its label.
                    max_length=model_config['compound_max_len'],
                    top_k=top_k,
                    top_p=top_p,
                    eos_token=compound_cab['<s>']
                )
                for b in tqdm(test_loader, desc=f"Top-k {top_k} Top-p {top_p} decoding"):
                    results['genome'].append(b[2][0])  # keep genome info aligned with beam search
                    src_pad_mask = b[1]
                    generate = topKtopP.generate(b[0], src_pad_mask)
                    # Fixed: generate() returns a (1, L) tensor; passing it directly to
                    # medium() iterated rows (not ids) and made `v == key` an ambiguous
                    # multi-element bool tensor. Convert to a flat list of ids first.
                    pred_medium = medium(generate[0].tolist(), compound_cab)
                    results[column_name].append(pred_medium)
                topk_df = pd.DataFrame.from_dict(results, orient="index").transpose()
                if combined_df.empty:
                    combined_df = topk_df
                else:
                    combined_df = pd.merge(combined_df, topk_df, on='genome', how='outer')
                # Save the current top-k/top-p results
                # output_file_name = os.path.join(output_path, f"top_{top_k}_top_{top_p}.csv")
                # topk_df.to_csv(output_file_name, index=False)
                torch.cuda.empty_cache()
    return combined_df
def main():
    """Parse CLI arguments, load data and model, then run the decoding evaluation."""
    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder model decoder')
    # Data path parameters
    parser.add_argument('--data_pkl_path', type=str, required=True, help='data.pkl file path')
    parser.add_argument('--model_path', type=str, required=True, help='Model weight file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--output_path', type=str, required=True, help='Output result CSV file dir')
    parser.add_argument('--test_genome_pkl_path', type=str, required=True, help='test_genome_dict.pkl file path')
    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=64, help='Model Dimensions')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input Dimension')
    parser.add_argument('--first_linear_dim', type=int, default=256, help='The first linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum length of nutrient sequence')
    # Decoder parameters
    parser.add_argument('--top_k', nargs='+', type=int, default=[8, 9, 10], help='Top-K value list')
    parser.add_argument('--top_p', nargs='+', type=float, default=[0.9], help='Top-P value list')
    parser.add_argument('--beam_size', nargs='+', type=int, default=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], help='Beam size value list')
    parser.add_argument('--decode_methods', nargs='+', choices=['top_k_top_p', 'beam_search'],
                        default=['top_k_top_p', 'beam_search'], help='Decoding method')
    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    args = parser.parse_args()
    # Setting the device
    device = torch.device(args.device)
    print(f"Using the device: {device}")
    # Load data
    print("=== Load data ===")
    # NOTE(review): the code that builds `test_loader` and `compound_cab`
    # (presumably from args.data_pkl_path / args.test_data_path via
    # DataProcessor / build_single_lazy_dataloader) is missing here — both names
    # are used below but never defined. Restore it before running this script.
    # Creating a Data Loader
    print(f"Number of test set batches: {len(test_loader)}")
    # genome dict
    with open(args.test_genome_pkl_path, 'rb') as t:
        genome_cab = pickle.load(t)
    # Model Configuration
    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'compound_max_len': args.compound_max_len
    }
    # Load model
    model = load_model(args.model_path, device, model_config, compound_cab)
    # Generate
    # Fixed: the old call passed `genome_cab` as an extra positional argument,
    # which shifted every later argument into the wrong parameter of
    # evaluate_model (e.g. genome_cab landed in `device`).
    combined_df = evaluate_model(
        model, test_loader, compound_cab, device, model_config,
        args.top_k, args.top_p, args.beam_size, args.decode_methods, args.output_path
    )
    # Persist the combined results (previously they were computed and discarded).
    combined_df.to_csv(os.path.join(args.output_path, "decode_results.csv"), index=False)
    print("Decoding complete!")
if __name__ == "__main__":
    main()