#!/usr/bin/env python3
# coding: utf-8
"""
ProtT5 Encoder-Decoder decoding
Support beam search and top-k/top-p sampling decoding methods
"""
import argparse
import pickle
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import ast
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.nn as nn
from model import Model
from datasets import DataProcessor, build_single_lazy_dataloader
import os


def get_key_padding_mask(tokens, pad):
    """Generate a key padding mask for a batch of token ids.

    Args:
        tokens: Long tensor of token ids, any shape.
        pad: The padding token id.

    Returns:
        Float tensor of the same shape: ``-inf`` at padding positions, 0
        elsewhere. Callers convert this to a boolean mask via
        ``~(mask != -inf)``.
    """
    key_padding_mask = torch.zeros(tokens.size())
    key_padding_mask[tokens == pad] = -float('inf')
    return key_padding_mask


def get_sequencing_mask(tgt, d_model):
    """Generate a causal (square subsequent) mask for the target sequence.

    Args:
        tgt: Target token tensor; only its last dimension (sequence length)
            is used.
        d_model: Unused; kept for backward compatibility with existing callers.

    Returns:
        Float tensor of shape (L, L) with ``-inf`` above the diagonal and 0
        elsewhere — the same values nn.Transformer's
        ``generate_square_subsequent_mask`` produces.
    """
    # FIX: the original built a complete nn.Transformer (all encoder/decoder
    # layers) on every call just to reach generate_square_subsequent_mask.
    # Construct the identical mask directly instead.
    size = tgt.size(-1)
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)


def top_k_top_p_filtering(probabilities, top_k=0, top_p=1.0, min_tokens_to_keep=1):
    """Filter a probability distribution using top-k and top-p (nucleus).

    The input is modified in place; zeroed entries simply never get sampled
    (``torch.multinomial`` accepts unnormalized weights).

    Args:
        probabilities: Probability distribution tensor (last dim = vocab).
        top_k: Keep only the ``top_k`` highest-probability tokens (0 = off).
        top_p: Keep the smallest set of tokens whose cumulative probability
            reaches ``top_p`` (1.0 = off). NOTE(review): the token that first
            crosses the threshold is removed rather than kept, which is
            slightly stricter than the common reference implementation —
            behavior preserved here.
        min_tokens_to_keep: Always keep at least this many tokens.

    Returns:
        The filtered probability distribution (same tensor as the input).
    """
    # Top-K filter
    if top_k > 0:
        top_k = min(top_k, probabilities.size(-1))
        values, indices = torch.topk(probabilities, top_k)
        min_threshold = values[..., -1, None]
        probabilities[probabilities < min_threshold] = 0

    # Top-P (nucleus) filter
    if top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # FIX: map the removal mask from sorted order back to vocabulary
        # order with scatter_. The original fancy indexing
        # (probabilities[sorted_indices[mask]] = 0) was only correct for 1-D
        # inputs; this form is identical for 1-D and also handles batches.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter_(
            -1, sorted_indices, sorted_indices_to_remove
        )
        probabilities[indices_to_remove] = 0
    return probabilities


class Generator_topk(nn.Module):
    """Projection head mapping decoder states to a softmax distribution."""

    def __init__(self, d_model, vocab):
        super(Generator_topk, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        # Probabilities (not log-probs): consumed by torch.multinomial.
        return F.softmax(self.proj(x), dim=-1)


class TopKtopPGenerator:
    """Top-K and Top-P sampling generator."""

    def __init__(self, model, d_model, generator, device, tgt_mask_index,
                 repetition_penalty=1.3, start_symbol=None, max_length=98,
                 top_k=10, top_p=0.9, eos_token=None):
        """
        Args:
            model: Trained encoder-decoder model (provides src_linear1/2,
                src_activation, src_dropout, tgt_embedding, tgt_pos and
                transformer submodules).
            d_model: Model hidden dimension.
            generator: Module mapping decoder output to token probabilities.
            device: torch device to run on.
            tgt_mask_index: Target padding token id.
            repetition_penalty: Divisor applied to already-generated tokens.
            start_symbol: BOS token id.
            max_length: Maximum number of tokens to generate.
            top_k: Top-K filtering parameter.
            top_p: Top-P (nucleus) filtering parameter.
            eos_token: EOS token id; generation stops when it is sampled.
        """
        self.model = model
        self.d_model = d_model
        self.generator = generator
        self.device = device
        self.tgt_mask_index = tgt_mask_index
        self.repetition_penalty = repetition_penalty
        self.start_symbol = start_symbol
        self.max_length = max_length
        self.top_k = top_k
        self.top_p = top_p
        self.eos_token = eos_token

    def _model_decoder(self, ys, memory, d_model):
        """Run one decoder forward pass and return token probabilities."""
        # Convert the -inf float masks into boolean masks (True = masked).
        ys_pad_mask = ~(get_key_padding_mask(ys, self.tgt_mask_index) != float('-inf')).to(self.device)
        ys_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        ys = self.model.tgt_pos(self.model.tgt_embedding(ys))
        decoder_output = self.model.transformer.decoder(
            tgt=ys,
            tgt_mask=ys_mask,
            tgt_key_padding_mask=ys_pad_mask,
            memory=memory
        )
        out = self.generator(decoder_output)
        return out

    def generate(self, src, src_pad_mask):
        """Generate a sequence using Top-K and Top-P sampling.

        Args:
            src: Source embedding tensor.
            src_pad_mask: Source key-padding mask.

        Returns:
            Long tensor of shape (1, generated_length) including the start
            symbol (and the EOS token if one was sampled).
        """
        self.model.eval()
        with torch.no_grad():
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)

            # Convert the dimension of the src sequence
            src_transformed = self.model.src_linear1(src)
            src_transformed = self.model.src_activation(src_transformed)
            src_transformed = self.model.src_dropout(src_transformed)
            src_transformed = self.model.src_linear2(src_transformed)
            memory = self.model.transformer.encoder(src=src_transformed,
                                                    src_key_padding_mask=src_pad_mask)

            # Start from the start symbol
            ys = torch.ones(1, 1).fill_(self.start_symbol).long().to(self.device)
            generated_tokens = []

            for _ in range(self.max_length):
                prob = self._model_decoder(ys, memory, self.d_model)[:, -1, :]

                # Apply duplicate penalties
                for token in generated_tokens:
                    prob[0, token] /= self.repetition_penalty

                # Apply Top-K and Top-P filtering
                filtered_probs = top_k_top_p_filtering(prob.squeeze(0),
                                                       top_k=self.top_k,
                                                       top_p=self.top_p)

                # Sample from the filtered distribution
                next_token_id = torch.multinomial(filtered_probs, num_samples=1).item()

                # Add the sampled token to the sequence
                ys = torch.cat([ys, torch.ones(1, 1).fill_(next_token_id).long().to(self.device)], dim=1)
                generated_tokens.append(next_token_id)

                # Stop if the end symbol is generated
                if self.eos_token is not None and next_token_id == self.eos_token:
                    break
            return ys


class Beam_search(nn.Module):
    """Beam search decoder with a frequency-based repetition penalty."""

    def __init__(self, model, trg_pad_idx, trg_bos_idx, trg_eos_idx, device,
                 d_model=256, beam_size=3, max_length=98, alpha=0.7):
        """
        Args:
            model: Trained encoder-decoder model.
            trg_pad_idx: Target padding token id.
            trg_bos_idx: Target BOS token id.
            trg_eos_idx: Target EOS token id.
            device: torch device to run on.
            d_model: Model hidden dimension.
            beam_size: Number of beams kept per step.
            max_length: Maximum decoded sequence length.
            alpha: Length-normalization exponent for final beam ranking.
        """
        super(Beam_search, self).__init__()
        self.model = model.to(device)
        self.trg_pad_idx = trg_pad_idx
        self.trg_bos_idx = trg_bos_idx
        self.trg_eos_idx = trg_eos_idx
        self.device = device
        self.d_model = d_model
        self.beam_size = beam_size
        self.max_length = max_length
        self.alpha = alpha
        # Initialize the sequence buffers.
        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]).to(device))
        self.register_buffer(
            'blank_seqs',
            torch.full((beam_size, max_length), trg_pad_idx, dtype=torch.long).to(device)
        )
        self.blank_seqs[:, 0] = self.trg_bos_idx
        self.register_buffer(
            'len_map',
            torch.arange(1, max_length + 1, dtype=torch.long).unsqueeze(0).to(device)
        )

    def frequency_penalty(self, frequency, alpha):
        """Penalty proportional to a token's frequency once it repeats."""
        if frequency > 1:
            penalty = alpha * frequency
        else:
            penalty = 0
        return penalty

    def _model_decoder(self, model, ys, memory, d_model):
        """Run one decoder forward pass using the model's own generator head."""
        ys_pad_mask = ~(get_key_padding_mask(ys, self.trg_pad_idx) != float('-inf')).to(self.device)
        ys_mask = ~(get_sequencing_mask(ys, d_model) != float('-inf')).to(self.device)
        ys = model.tgt_pos(model.tgt_embedding(ys))
        decoder_output = model.transformer.decoder(
            tgt=ys,
            tgt_mask=ys_mask,
            tgt_key_padding_mask=ys_pad_mask,
            memory=memory
        )
        out = model.generator(decoder_output)
        return out

    def get_init_state(self, beam_size, model, ys, memory):
        """Expand the first decoding step into the initial beam_size beams."""
        out = self._model_decoder(model=model, ys=ys, memory=memory, d_model=self.d_model)
        best_k_probs, best_k_idx = out[:, -1, :].topk(beam_size)
        scores = best_k_probs.view(beam_size)
        gen_seq = self.blank_seqs.clone().detach()
        gen_seq[:, 1] = best_k_idx[0]
        return gen_seq, scores

    def get_best_score_and_idx(self, beam_size, dec_output, scores, gen_seq, step):
        """Select the best beam_size continuations for the current step."""
        dec_output = dec_output[:, -1, :]
        # Apply frequency penalty to tokens already present in each beam.
        for beam in range(dec_output.size(0)):
            count = Counter(gen_seq[:, :step][beam].tolist())
            for com in count:
                penalty = self.frequency_penalty(count[com], 0.7)
                dec_output[beam][com] -= penalty
        best_k2_probs, best_k2_idx = dec_output.topk(beam_size)
        # Accumulate scores and pick the global top beam_size of the k*k grid.
        scores = best_k2_probs.view(beam_size, -1) + scores.view(beam_size, 1)
        scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
        best_k_r_idxs = torch.div(best_k_idx_in_k2, beam_size, rounding_mode='trunc')
        best_k_c_idxs = best_k_idx_in_k2 % beam_size
        best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
        gen_seq[:, :step] = gen_seq[best_k_r_idxs, :step]
        gen_seq[:, step] = best_k_idx
        return gen_seq, scores

    def beam_search(self, src, src_pad_mask):
        """Perform a beam search.

        Args:
            src: Source embedding tensor (batch of 1).
            src_pad_mask: Source key-padding mask.

        Returns:
            List of token ids for the best (length-normalized) beam.
        """
        self.model.eval()
        trg_bos_idx, trg_eos_idx = self.trg_bos_idx, self.trg_eos_idx
        max_seq_len, beam_size, alpha = self.max_length, self.beam_size, self.alpha
        with torch.no_grad():
            # calculate memory
            src = src.to(self.device)
            src_pad_mask = src_pad_mask.to(self.device)

            # Convert the dimension of the src sequence
            src_transformed = self.model.src_linear1(src)
            src_transformed = self.model.src_activation(src_transformed)
            src_transformed = self.model.src_dropout(src_transformed)
            src_transformed = self.model.src_linear2(src_transformed)
            memory = self.model.transformer.encoder(src=src_transformed,
                                                    src_key_padding_mask=src_pad_mask)

            # Starting character calculation
            ys = torch.ones(1, 1).fill_(trg_bos_idx).type_as(src.data).long()
            gen_seq, scores = self.get_init_state(beam_size=beam_size, model=self.model,
                                                  ys=ys, memory=memory)

            # Expand src to beam_size and re-encode for the loop.
            src_expanded = src.repeat(beam_size, 1, 1)
            src_expanded_pad_mask = src_pad_mask.repeat(beam_size, 1)
            src_expanded_pad_mask = src_expanded_pad_mask.to(self.device)
            src_expanded_transformed = self.model.src_linear1(src_expanded)
            src_expanded_transformed = self.model.src_activation(src_expanded_transformed)
            src_expanded_transformed = self.model.src_dropout(src_expanded_transformed)
            src_expanded_transformed = self.model.src_linear2(src_expanded_transformed)
            memory_expanded = self.model.transformer.encoder(
                src=src_expanded_transformed,
                src_key_padding_mask=src_expanded_pad_mask
            )

            # FIX: ans_idx (and seq_lens) were only assigned inside the
            # early-exit branch; if no beam ever produced <eos> before
            # max_seq_len, the return line raised NameError. Default to the
            # top-scoring beam (row 0) at full length.
            ans_idx = 0
            seq_lens = torch.full((beam_size,), max_seq_len, dtype=torch.long,
                                  device=gen_seq.device)

            for step in range(2, max_seq_len):
                dec_output = self._model_decoder(
                    model=self.model,
                    ys=gen_seq[:, :step],
                    memory=memory_expanded,
                    d_model=self.d_model
                )
                gen_seq, scores = self.get_best_score_and_idx(
                    beam_size=beam_size,
                    dec_output=dec_output,
                    scores=scores,
                    step=step,
                    gen_seq=gen_seq
                )
                # Determine whether the end character is encountered
                eos_locs = gen_seq == trg_eos_idx
                seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)
                if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
                    # All beams finished: pick the best length-normalized score.
                    _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
                    ans_idx = ans_idx.item()
                    break
        # FIX: the original returned gen_seq[0] while truncating with
        # seq_lens[ans_idx] — the sequence and its length came from different
        # beams. Return the beam that ans_idx actually selects.
        return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()


def keys_valus(tgt_dict, key):
    """Search for keys in tgt_dict whose value equals `key` (reverse lookup)."""
    value = [k for k, v in tgt_dict.items() if v == key]
    return value


def medium(b, tgt_dict):
    """Convert a sequence of token ids to compound names via reverse lookup.

    Args:
        b: Iterable of token ids (ints or 0-dim tensors).
        tgt_dict: Vocabulary mapping name -> id.

    Returns:
        List of compound names with special tokens filtered out.
    """
    l = []
    for i in b:
        tensor = i
        l.extend(keys_valus(tgt_dict, tensor))
    # NOTE(review): the two empty strings below were apparently angle-bracket
    # special tokens (e.g. <bos>/<eos>) stripped during file extraction —
    # confirm against the vocabulary; kept byte-for-byte here.
    l = [item for item in l if item not in ['', '', 'blank']]
    return l


def load_model(model_path, device, model_config, compound_cab):
    """Load the trained model weights and return the model in eval mode.

    Args:
        model_path: Path to the saved state_dict.
        device: torch device to load onto.
        model_config: Dict of model hyperparameters.
        compound_cab: Compound vocabulary (name -> id).

    Returns:
        The loaded Model instance in eval mode.
    """
    print("=== Load model ===")
    model = Model(
        heads=model_config['heads'],
        d_model=model_config['d_model'],
        tgt_vocab_size=len(compound_cab),
        num_encoder_layer=model_config['num_encoder_layer'],
        num_decoder_layer=model_config['num_decoder_layer'],
        dropout=model_config['dropout'],
        compound_cab=compound_cab,
        src_input_dim=model_config['src_input_dim'],
        first_linear_dim=model_config['first_linear_dim']
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Model loaded: {model_path}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
    return model


def evaluate_model(model, test_loader, compound_cab, genome_cab, device, model_config,
                   top_k_values, top_p_values, beam_size_values, decode_methods,
                   output_path):
    """Evaluate the model and generate results with different decoding strategies.

    Writes one CSV per decoding configuration into `output_path`, each with
    columns: real (ground-truth medium), genome, and the predicted medium.
    """
    print("=== Start decoding ===")

    # First loop: iterate over different beam sizes
    if "beam_search" in decode_methods:
        for beam_size in beam_size_values:
            results = {}
            results['real'] = []
            results['genome'] = []
            column_name = f"beam_{beam_size}"
            results[column_name] = []

            # Initialize beam search generator.
            # NOTE(review): trg_bos_idx and trg_eos_idx both look up
            # compound_cab[''] — the special-token names were apparently lost
            # to extraction garbling; confirm the intended keys.
            beamsearch = Beam_search(
                model=model,
                trg_pad_idx=compound_cab['blank'],
                trg_bos_idx=compound_cab[''],
                trg_eos_idx=compound_cab[''],
                device=device,
                d_model=model_config['d_model'],
                beam_size=beam_size,
                max_length=model_config['compound_max_len'],
                alpha=0.7  # Length normalization coefficient
            )

            # Iterate over the entire dataloader for this beam size
            for b in tqdm(test_loader, desc=f"Beam {beam_size} decoding"):
                # genome info
                genome = medium(b[3], genome_cab)
                results['genome'].append(genome)
                # Real culture medium
                zhenshi_medium = medium(b[1][0], compound_cab)
                results['real'].append(zhenshi_medium)
                # Get source padding mask (assuming b[2] contains it)
                src_pad_mask = b[2]
                # Generate predictions using beam search
                generate = beamsearch.beam_search(b[0], src_pad_mask)
                # Convert predictions to readable format
                pred_medium = medium(generate, compound_cab)
                results[column_name].append(pred_medium)

            final_df = pd.DataFrame.from_dict(results, orient="index").transpose()
            output_file_name = os.path.join(output_path, f"beam_{beam_size}.csv")
            final_df.to_csv(output_file_name, index=False)
            torch.cuda.empty_cache()

    # Second loop: handle top-k and top-p sampling (if enabled)
    if "top_k_top_p" in decode_methods:
        for top_k in top_k_values:
            for top_p in top_p_values:
                results = {}
                results['real'] = []
                results['genome'] = []
                column_name = f"top_{top_k}_top_{top_p}"
                results[column_name] = []

                # Initialize top-k top-p generator
                generator = Generator_topk(model_config['d_model'], len(compound_cab)).to(device)
                topKtopP = TopKtopPGenerator(
                    model=model,
                    d_model=model_config['d_model'],
                    generator=generator,
                    device=device,
                    tgt_mask_index=compound_cab['blank'],  # Target padding index
                    repetition_penalty=1.3,  # Penalty for repeated tokens
                    start_symbol=compound_cab[''],  # Start token
                    eos_token=compound_cab[''],  # End token
                    # FIX: the loop variables were never forwarded, so every
                    # top_k/top_p combination silently ran with the defaults
                    # (10 / 0.9) while the output files claimed otherwise.
                    top_k=top_k,
                    top_p=top_p,
                    # Keep the length limit consistent with the beam path.
                    max_length=model_config['compound_max_len']
                )

                # Iterate over dataloader for this top-k/top-p combination
                for b in tqdm(test_loader, desc=f"Top-k {top_k} Top-p {top_p} decoding"):
                    # genome info
                    genome = medium(b[3], genome_cab)
                    results['genome'].append(genome)
                    # Real culture medium
                    zhenshi_medium = medium(b[1][0], compound_cab)
                    results['real'].append(zhenshi_medium)
                    src_pad_mask = b[2]
                    generate = topKtopP.generate(b[0], src_pad_mask)
                    # FIX: generate() returns a (1, L) tensor; iterating it
                    # directly handed 1-D tensors to keys_valus, whose
                    # multi-element `v == key` comparison raises a RuntimeError.
                    # Pass the single row instead.
                    pred_medium = medium(generate[0], compound_cab)
                    results[column_name].append(pred_medium)

                # save
                final_df = pd.DataFrame.from_dict(results, orient="index").transpose()
                output_file_name = os.path.join(output_path, f"top_{top_k}_top_{top_p}.csv")
                final_df.to_csv(output_file_name, index=False)
                torch.cuda.empty_cache()


def main():
    parser = argparse.ArgumentParser(description='ProtT5 Encoder-Decoder model decoder')

    # Data path parameters
    parser.add_argument('--data_pkl_path', type=str, required=True, help='data.pkl file path')
    parser.add_argument('--model_path', type=str, required=True, help='Model weight file path')
    parser.add_argument('--test_data_path', type=str, required=True, help='Test data CSV file path')
    parser.add_argument('--output_path', type=str, required=True, help='Output result CSV file dir')
    parser.add_argument('--test_genome_pkl_path', type=str, required=True, help='test_genome_dict.pkl file path')

    # Model parameters
    parser.add_argument('--heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--d_model', type=int, default=64, help='Model Dimensions')
    parser.add_argument('--src_input_dim', type=int, default=1024, help='Input Dimension')
    parser.add_argument('--first_linear_dim', type=int, default=256, help='The first linear dimension')
    parser.add_argument('--num_encoder_layer', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layer', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
    parser.add_argument('--compound_max_len', type=int, default=98, help='Maximum length of nutrient sequence')

    # Decoder parameters
    parser.add_argument('--top_k', nargs='+', type=int, default=[8, 9, 10], help='Top-K value list')
    parser.add_argument('--top_p', nargs='+', type=float, default=[0.9], help='Top-P value list')
    parser.add_argument('--beam_size', nargs='+', type=int,
                        default=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], help='Beam size value list')
    parser.add_argument('--decode_methods', nargs='+', choices=['top_k_top_p', 'beam_search'],
                        default=['top_k_top_p', 'beam_search'], help='Decoding method')

    # Other parameters
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    args = parser.parse_args()

    # Setting the device
    device = torch.device(args.device)
    print(f"Using the device: {device}")

    # Load data
    print("=== Load data ===")
    with open(args.data_pkl_path, 'rb') as f:
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']

    # TODO(review): ko_count_path and embedding_h5_path are hard-coded
    # machine-specific paths — consider promoting them to CLI arguments.
    dp_test = DataProcessor(
        ko_count_path='/home/zzhang/gzy/Uncultured/generative_ML/data/gene_ko_protein/KO_count.csv',
        data_path=args.test_data_path,
        embedding_h5_path='/home/zzhang/gzy/Uncultured/ko_pre_train/ProtT5_Encoder/all_genome_remove_nan_gene_ProtT5_embeddings.h5',
        vo_cab_pkl_path=args.data_pkl_path,
        model_type='test',
        ko_max_len=4500,
        compound_max_len=args.compound_max_len
    )
    dp_test.load_data(count_min=800, count_max=4500)

    # Creating a Data Loader
    test_loader = build_single_lazy_dataloader(dp_test, batch_size=args.batch_size, shuffle=False)
    print(f"Number of test set batches: {len(test_loader)}")

    # genome dict
    with open(args.test_genome_pkl_path, 'rb') as t:
        genome_cab = pickle.load(t)

    # Model Configuration
    model_config = {
        'heads': args.heads,
        'd_model': args.d_model,
        'src_input_dim': args.src_input_dim,
        'first_linear_dim': args.first_linear_dim,
        'num_encoder_layer': args.num_encoder_layer,
        'num_decoder_layer': args.num_decoder_layer,
        'dropout': args.dropout,
        'compound_max_len': args.compound_max_len
    }

    # Load model
    model = load_model(args.model_path, device, model_config, compound_cab)

    # Generate
    evaluate_model(
        model, test_loader, compound_cab, genome_cab, device, model_config,
        args.top_k, args.top_p, args.beam_size, args.decode_methods, args.output_path
    )
    print("Decoding complete!")


if __name__ == "__main__":
    main()