"""End-to-end inference pipeline.

For each genome: parse KO annotations (eggNOG-mapper or kofamscan), pair them
with protein sequences, embed the proteins with a ProtT5 encoder, and run a
trained encoder-decoder model to generate medium/compound predictions.
"""

import argparse
import logging
import os
import pickle
import time
import warnings

import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from torch.utils.data import DataLoader, Dataset
from transformers import T5EncoderModel, T5Tokenizer
from transformers import logging as hf_logging

warnings.filterwarnings("ignore", category=UserWarning,
                        module="torch.nn.modules.transformer")
hf_logging.set_verbosity_error()

# Assume these dependencies are properly implemented.
# Adjust imports according to actual file structure.
from model import Model
from dmodel_256_decode_model import evaluate_model
from datasets import ko_embedding_padding  # Assumes this function exists


class InferenceDataset(Dataset):
    """Dataset class for handling unseen data during inference."""

    def __init__(self, src, src_mask, genome):
        self.src = src            # Input features (embeddings)
        self.src_mask = src_mask  # Corresponding mask tensor
        self.genome = genome      # Genome identifier per sample

    def __getitem__(self, index):
        """Retrieve (features, mask, genome name) triple by index."""
        return self.src[index], self.src_mask[index], self.genome[index]

    def __len__(self):
        """Return total number of samples."""
        return len(self.src)


def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """
    Construct DataLoader for inference.

    Args:
        ko_token: Input feature tensors
        ko_mask: Corresponding mask tensors
        genome: Genome identifiers aligned with ko_token
        batch_size: Batch size for inference
        shuffle: Whether to shuffle data (typically False for inference)

    Returns:
        DataLoader instance
    """
    dataset = InferenceDataset(src=ko_token, src_mask=ko_mask, genome=genome)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def process_annotation_and_sequence(annotation_path, fasta_path, annotation_type,
                                    evalue_threshold=1e-5):
    """
    Process annotations (eggNOG-mapper or kofamscan) and protein sequences,
    then merge them.

    Args:
        annotation_path: Path to annotation file (eggNOG-mapper or kofamscan)
        fasta_path: Path to protein FASTA file
        annotation_type: Type of annotation file, must be 'emmapper' or 'kofamscan'
        evalue_threshold: E-value cutoff for filtering eggNOG-mapper annotations
            (only used if annotation_type is 'emmapper')

    Returns:
        DataFrame with merged gene-KO-sequence information

    Raises:
        ValueError: If annotation_type is not recognized.
    """
    if annotation_type not in ('emmapper', 'kofamscan'):
        raise ValueError("annotation_type must be either 'emmapper' or 'kofamscan'")

    if annotation_type == 'emmapper':
        annotation_df = (
            pd.read_csv(annotation_path, sep='\t', comment='#', header=None,
                        usecols=[0, 2, 11])
            .rename(columns={0: 'gene', 2: 'evalue', 11: 'KO'})
            # BUG FIX: in pandas query(), `&` binds tighter than comparisons,
            # so 'evalue < x & KO != "-"' mis-parses; parentheses are required.
            .query(f'(evalue < {evalue_threshold}) & (KO != "-")')  # Filter low-quality annotations
            .assign(KO=lambda x: x['KO'].str.split(','))            # Split multiple KOs
            .explode('KO')                                          # Expand KO list to rows
            .assign(KO=lambda x: x['KO'].str.split('ko:').str[1])   # Extract KO number
            .drop_duplicates(subset='gene', keep='first')           # Keep first KO per gene
            .sort_values('gene')                                    # Sort by gene ID
            .drop_duplicates(subset='KO', keep='first')             # Keep first gene per KO
            .drop(columns='evalue')                                 # Remove evalue column
        )
    else:  # 'kofamscan'
        annotation_df = pd.read_csv(annotation_path, sep='\t', header=None,
                                    names=['gene', 'KO'])
        annotation_df = (
            annotation_df[~annotation_df['KO'].isna()]     # Remove rows with missing KO
            .drop_duplicates(subset='gene', keep='first')  # Keep first KO per gene
            .sort_values('KO')                             # Sort by KO
            .drop_duplicates(subset='KO', keep='first')    # Keep first gene per KO
        )

    # Process FASTA sequences
    seq_df = pd.DataFrame(
        [(rec.id, str(rec.seq)) for rec in SeqIO.parse(fasta_path, "fasta")],
        columns=['gene', 'sequence']
    )

    # Merge annotations with sequences (inner join: keep genes present in both)
    return annotation_df.merge(seq_df, on='gene', how='inner')


def get_T5_model(local_model_dir, device):
    """
    Load T5 encoder model and tokenizer.

    Args:
        local_model_dir: Directory containing T5 model files
        device: Computation device (GPU/CPU)

    Returns:
        Tuple of (T5 model, tokenizer)
    """
    tokenizer = T5Tokenizer.from_pretrained(local_model_dir, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(local_model_dir)
    model = model.to(device).eval()  # Move to device and set to evaluation mode
    return model, tokenizer


def get_embeddings(model, tokenizer, seqs, device,
                   per_residue=False, per_protein=True, sec_struct=False,
                   max_residues=5000, max_seq_len=1000, max_batch=1):
    """
    Generate protein embeddings using the T5 encoder.

    Args:
        model: T5 encoder model
        tokenizer: T5 tokenizer
        seqs: Dictionary of {gene_id: sequence}
        device: Computation device
        per_residue: Whether to generate per-residue embeddings
            (currently unused; kept for interface compatibility)
        per_protein: Whether to generate per-protein (mean-pooled) embeddings
        sec_struct: Whether to predict secondary structure
            (currently unused; kept for interface compatibility)
        max_residues: Maximum total residues per batch
        max_seq_len: Maximum length of a single sequence before it is
            forced into its own batch
        max_batch: Maximum number of sequences per batch

    Returns:
        Tuple of (embedding results dict, invalid sequences dict)
    """
    results = {"residue_embs": dict(), "protein_embs": dict(), "sec_structs": dict()}

    # Filter invalid sequences (NaN values parsed from missing data are floats)
    invalid_items = {k: v for k, v in seqs.items() if isinstance(v, float)}
    seq_dict = [(k, v) for k, v in seqs.items() if not isinstance(v, float)]

    start = time.time()
    batch = []
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
        seq_len = len(seq)
        seq = ' '.join(list(seq))  # Format sequence for T5 tokenizer (space-separated residues)
        batch.append((pdb_id, seq, seq_len))

        # Check if batch needs processing.
        # NOTE(review): seq_len is added on top of a sum that already includes
        # the current sequence, so it is counted twice; this mirrors the
        # upstream ProtTrans embedding script and is kept as-is.
        n_res_batch = sum(s_len for _, _, s_len in batch) + seq_len
        if len(batch) >= max_batch or n_res_batch >= max_residues or \
                seq_idx == len(seq_dict) or seq_len > max_seq_len:
            pdb_ids, seqs_batch, seq_lens = zip(*batch)
            batch = []

            # Tokenize sequences (pad to longest in batch)
            token_encoding = tokenizer.batch_encode_plus(
                seqs_batch, add_special_tokens=True, padding="longest"
            )
            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

            try:
                with torch.no_grad():  # Disable gradient calculation for inference
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                # Typically CUDA OOM on long sequences; skip this batch entirely
                print(f"RuntimeError during embedding for {pdb_id} (length={seq_len})")
                continue

            # Process embeddings for each sequence in batch
            for batch_idx, identifier in enumerate(pdb_ids):
                s_len = seq_lens[batch_idx]
                # Truncate to the real sequence length (drop padding/special tokens)
                emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
                if per_protein:
                    # Average pooling over residues for per-protein embedding
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = \
                        protein_emb.detach().cpu().numpy().squeeze()

            del input_ids, attention_mask, embedding_repr
            torch.cuda.empty_cache()

    # Calculate timing statistics
    torch.cuda.empty_cache()
    passed_time = time.time() - start
    avg_time = passed_time / len(results["protein_embs"]) if results["protein_embs"] else 0
    print(f'Total per-protein embeddings generated: {len(results["protein_embs"])}')
    print(f"Embedding generation time: {passed_time/60:.1f}m ({avg_time:.3f}s/protein)")
    return results, invalid_items


def main():
    """Parse CLI arguments and run the full annotation → embedding → inference loop."""
    parser = argparse.ArgumentParser(
        description='Run model inference on annotation and sequence data')
    parser.add_argument('--annotation_dir', required=True,
                        help='Directory containing annotation files')
    # NOTE: `default=` is ignored by argparse when required=True, so the dead
    # defaults previously attached to required options have been removed.
    parser.add_argument('--startswith', required=True,
                        help='The file starts with [content]')
    parser.add_argument('--endswith', required=True,
                        help='The file end with [content]')
    parser.add_argument('--fasta_path', required=True,
                        help='Directory containing per-genome protein FASTA files (<genome>.faa)')
    parser.add_argument('--data_pkl', required=True,
                        help='Path to the data.pkl file containing metadata')
    parser.add_argument('--model_path', required=True,
                        help='Path to the trained model (.pt file)')
    parser.add_argument('--output_dir', required=True,
                        help='Output directory for inference results')
    parser.add_argument('--t5_model_dir', default='./scripts/model_para/',
                        help='Directory containing T5 model parameters (default: specified path)')
    parser.add_argument('--annotation_type', required=True,
                        choices=['emmapper', 'kofamscan'],
                        help='Type of annotation file: "emmapper" or "kofamscan"')
    args = parser.parse_args()

    # Set up computation device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using computation device: {device}")

    # Load metadata from data.pkl
    # NOTE(review): pickle.load on an arbitrary path executes embedded code;
    # only load data.pkl files from trusted sources.
    with open(args.data_pkl, 'rb') as f:
        loaded_data = pickle.load(f)
    compound_cab = loaded_data['compound_cab']
    compound_max = loaded_data['compound_max']

    # Initialize and load the main model
    model = Model(heads=8,
                  d_model=256,
                  num_encoder_layer=3,
                  num_decoder_layer=3,
                  dropout=0.1,
                  src_input_dim=1024,
                  first_linear_dim=512,
                  tgt_vocab_size=len(compound_cab),
                  compound_cab=compound_cab
                  ).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()  # Set to evaluation mode
    print(f"Successfully loaded model from: {args.model_path}")

    # Load T5 encoder model and tokenizer
    t5_encoder, t5_tokenizer = get_T5_model(args.t5_model_dir, device)

    # Find all valid annotation files in the specified directory
    annotation_files = []
    genomes = []
    for f in os.listdir(args.annotation_dir):
        start_match = (not args.startswith) or f.startswith(args.startswith)
        end_match = (not args.endswith) or f.endswith(args.endswith)
        if start_match and end_match:
            annotation_files.append(os.path.join(args.annotation_dir, f))
            # BUG FIX: str.split() on the prefix/suffix corrupts the genome
            # name whenever that substring also occurs inside the filename;
            # slicing off the matched prefix/suffix is always correct.
            if args.startswith:
                after_start = f[len(args.startswith):]
            else:
                after_start = f
            if args.endswith:
                genome_name = after_start[:-len(args.endswith)]
            else:
                genome_name = after_start
            genomes.append(genome_name)

    if not annotation_files:
        print(f"No valid annotation files found in {args.annotation_dir}. "
              f"Files must start with '{args.startswith}' and end with '{args.endswith}'.")
        return

    # Ensure the output directory exists before any results are written
    os.makedirs(args.output_dir, exist_ok=True)

    # Process each annotation file
    combined_results = pd.DataFrame()
    for idx, annotation_path in enumerate(annotation_files):
        genome_name = genomes[idx]
        protein_path = os.path.join(args.fasta_path, f"{genome_name}.faa")
        if not os.path.exists(protein_path):
            print(f"Error: Protein file not found - {protein_path}. Skipping this genome.")
            continue

        # Process annotation and sequence data
        try:
            geneKoSeq = process_annotation_and_sequence(
                annotation_path,
                protein_path,
                annotation_type=args.annotation_type,
                evalue_threshold=1e-5
            )
        except Exception as e:
            print(f"Error processing {genome_name}: {str(e)}. Skipping this genome.")
            continue

        if geneKoSeq.empty:
            print(f"No valid gene-KO-sequence entries for {genome_name}. Skipping.")
            continue

        # Prepare sequence dictionary
        seqs = {gene: seq for gene, seq in zip(geneKoSeq['gene'], geneKoSeq['sequence'])}

        # Generate protein embeddings
        results, invalid_items = get_embeddings(
            t5_encoder, t5_tokenizer, seqs, device,
            sec_struct=False, per_residue=False, per_protein=True
        )
        if not results['protein_embs']:
            print(f"No valid protein embeddings for {genome_name}. Skipping.")
            continue

        # Process embeddings (pad genome's KO-embedding matrix to fixed length 4500)
        combined_matrix = np.stack(list(results['protein_embs'].values()))
        embedding_pad, mask = ko_embedding_padding(combined_matrix, 4500)

        # Prepare data loader (single-genome batch)
        ko_token = torch.stack([torch.from_numpy(embedding_pad).float()])
        ko_mask = torch.stack([torch.from_numpy(mask).float()])
        current_genome = [genome_name]
        dataloader = build_inference_dataloader(ko_token=ko_token,
                                                ko_mask=ko_mask,
                                                genome=current_genome,
                                                batch_size=1,
                                                shuffle=False)

        # Model configuration passed through to the decoder/evaluator
        model_config = {
            'heads': 8,
            'd_model': 256,
            'src_input_dim': 1024,
            'first_linear_dim': 512,
            'num_encoder_layer': 3,
            'num_decoder_layer': 3,
            'dropout': 0.1,
            'compound_max_len': compound_max
        }

        # Run inference (beam search over the listed decoding hyper-parameters)
        medium_result = evaluate_model(
            model, dataloader, compound_cab, device, model_config,
            [8, 9, 10], [0.9], [4], ['beam_search'], args.output_dir
        )
        combined_results = pd.concat([combined_results, medium_result],
                                     ignore_index=True)
        print(f"Processed {genome_name}, accumulated {len(combined_results)} entries")

    output_file_name = os.path.join(args.output_dir, "generateMedia.csv")
    combined_results.to_csv(output_file_name, sep='\t', index=False)


if __name__ == "__main__":
    main()