# --- Source listing metadata (from repository web view) ---
# labweb/models/Inference.py
# 2025-12-16 11:39:15 +08:00
# 338 lines, 15 KiB, Python
import os
import argparse
import pandas as pd
import numpy as np
import torch
import pickle
import time
from Bio import SeqIO
from transformers import T5EncoderModel, T5Tokenizer
from torch.utils.data import Dataset, DataLoader
import logging
from transformers import logging as hf_logging
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")
hf_logging.set_verbosity_error()
# Assume these dependencies are properly implemented
# Adjust imports according to actual file structure
from model import Model
from dmodel_256_decode_model import evaluate_model
from datasets import ko_embedding_padding # Assumes this function exists
class InferenceDataset(Dataset):
    """Torch Dataset pairing pre-computed embeddings, masks, and genome IDs for inference."""

    def __init__(self, src, src_mask, genome):
        # Parallel containers, indexed together by sample position.
        self.src = src            # input feature tensors (embeddings)
        self.src_mask = src_mask  # attention masks aligned with src
        self.genome = genome      # genome identifiers aligned with src

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.src)

    def __getitem__(self, index):
        """Return the (embedding, mask, genome) triple at *index*."""
        return self.src[index], self.src_mask[index], self.genome[index]
def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """Wrap inference inputs in a torch DataLoader.

    Args:
        ko_token: Input feature tensors (one entry per genome).
        ko_mask: Attention-mask tensors aligned with ko_token.
        genome: Genome identifiers aligned with ko_token.
        batch_size: Samples per batch (default 1).
        shuffle: Whether to shuffle samples; keep False for inference.

    Returns:
        A DataLoader over an InferenceDataset built from the inputs.
    """
    return DataLoader(
        InferenceDataset(src=ko_token, src_mask=ko_mask, genome=genome),
        batch_size=batch_size,
        shuffle=shuffle,
    )
def process_annotation_and_sequence(annotation_path, fasta_path, annotation_type, evalue_threshold=1e-5):
    """Merge KO annotations (eggNOG-mapper or kofamscan) with protein sequences.

    Args:
        annotation_path: Path to the annotation file.
        fasta_path: Path to the protein FASTA file.
        annotation_type: Annotation format, either 'emmapper' or 'kofamscan'.
        evalue_threshold: E-value cutoff for eggNOG-mapper rows
            (only consulted when annotation_type is 'emmapper').

    Returns:
        DataFrame with columns ['gene', 'KO', 'sequence'], inner-joined on
        gene ID, with at most one gene per KO and one KO per gene.

    Raises:
        ValueError: If annotation_type is not a supported value.
    """
    if annotation_type not in ('emmapper', 'kofamscan'):
        raise ValueError("annotation_type must be either 'emmapper' or 'kofamscan'")

    if annotation_type == 'emmapper':
        # Columns 0/2/11 correspond to query gene, e-value, and KEGG KO list.
        raw = pd.read_csv(annotation_path, sep='\t', comment='#', header=None, usecols=[0, 2, 11])
        raw = raw.rename(columns={0: 'gene', 2: 'evalue', 11: 'KO'})
        filtered = raw.query(f'evalue < {evalue_threshold} & KO != "-"')  # drop weak / missing hits
        # One row per KO: split comma-joined lists, then strip the 'ko:' prefix.
        expanded = filtered.assign(KO=lambda x: x['KO'].str.split(',')).explode('KO')
        expanded = expanded.assign(KO=lambda x: x['KO'].str.split('ko:').str[1])
        # Enforce a 1:1 gene<->KO mapping: first KO per gene, then first gene per KO.
        one_ko_per_gene = expanded.drop_duplicates(subset='gene', keep='first').sort_values('gene')
        annotation_df = one_ko_per_gene.drop_duplicates(subset='KO', keep='first').drop(columns='evalue')
    else:  # kofamscan: two tab-separated columns, gene and KO
        raw = pd.read_csv(annotation_path, sep='\t', header=None, names=['gene', 'KO'])
        with_ko = raw[~raw['KO'].isna()]  # discard genes without a KO hit
        one_ko_per_gene = with_ko.drop_duplicates(subset='gene', keep='first')
        annotation_df = one_ko_per_gene.sort_values('KO').drop_duplicates(subset='KO', keep='first')

    # Read protein sequences into a gene -> sequence table.
    records = [(rec.id, str(rec.seq)) for rec in SeqIO.parse(fasta_path, "fasta")]
    seq_df = pd.DataFrame(records, columns=['gene', 'sequence'])

    # Keep only genes present in both the annotation and the FASTA file.
    return annotation_df.merge(seq_df, on='gene', how='inner')
def get_T5_model(local_model_dir, device):
    """Load a T5 encoder and its tokenizer from a local directory.

    Args:
        local_model_dir: Directory holding the pretrained T5 model files.
        device: Computation device (GPU/CPU) the encoder should live on.

    Returns:
        Tuple of (encoder in eval mode on *device*, tokenizer).
    """
    encoder = T5EncoderModel.from_pretrained(local_model_dir)
    encoder = encoder.to(device)
    encoder.eval()  # inference only: disable dropout etc.
    tokenizer = T5Tokenizer.from_pretrained(local_model_dir, do_lower_case=False)
    return encoder, tokenizer
def get_embeddings(model, tokenizer, seqs, device, per_residue=False, per_protein=True, sec_struct=False,
                   max_residues=5000, max_seq_len=1000, max_batch=1):
    """
    Generate protein embeddings with a T5 encoder, batching sequences greedily.

    Args:
        model: T5 encoder model (expected on *device*, in eval mode).
        tokenizer: Matching T5 tokenizer.
        seqs: Dict of {gene_id: amino-acid sequence string}. Float values
            (e.g. NaN from pandas) are treated as invalid and skipped.
        device: Computation device for the input tensors.
        per_residue: Accepted for API compatibility but never used here —
            the "residue_embs" result dict is never populated in this function.
        per_protein: If True, store a mean-pooled per-protein embedding.
        sec_struct: Accepted for API compatibility but never used here —
            "sec_structs" is never populated in this function.
        max_residues: Flush the current batch once its residue total reaches this.
        max_seq_len: A sequence longer than this forces an immediate flush.
        max_batch: Maximum number of sequences per batch.

    Returns:
        Tuple of (results dict whose "protein_embs" maps id -> 1D numpy array,
        dict of the invalid {id: value} entries that were skipped).
    """
    results = {"residue_embs": dict(), "protein_embs": dict(), "sec_structs": dict()}
    # Partition input: pandas NaN sequences arrive as floats and are set aside.
    invalid_items = {k: v for k, v in seqs.items() if isinstance(v, float)}
    seq_dict = [(k, v) for k, v in seqs.items() if not isinstance(v, float)]
    start = time.time()
    batch = []
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
        seq_len = len(seq)
        seq = ' '.join(list(seq))  # T5 protein tokenizers expect space-separated residues
        batch.append((pdb_id, seq, seq_len))
        # Decide whether to flush the accumulated batch.
        # NOTE(review): seq_len is added again although the current sequence is
        # already included in `batch`, so the budget check effectively uses
        # total + current length — confirm this mirrors the intended upstream
        # batching behavior rather than an off-by-one.
        n_res_batch = sum([s_len for _, _, s_len in batch]) + seq_len
        if len(batch) >= max_batch or n_res_batch >= max_residues or \
                seq_idx == len(seq_dict) or seq_len > max_seq_len:
            pdb_ids, seqs_batch, seq_lens = zip(*batch)
            batch = []
            # Tokenize the whole batch, padding to its longest sequence.
            token_encoding = tokenizer.batch_encode_plus(
                seqs_batch,
                add_special_tokens=True,
                padding="longest"
            )
            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            try:
                with torch.no_grad():  # inference only — no gradients needed
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                # NOTE(review): on failure (typically CUDA OOM) the whole batch
                # is skipped, and the `del`/empty_cache below does not run for it.
                print(f"RuntimeError during embedding for {pdb_id} (length={seq_len})")
                continue
            # Collect per-sequence embeddings from the padded batch output.
            for batch_idx, identifier in enumerate(pdb_ids):
                s_len = seq_lens[batch_idx]
                # Trim padding/special-token positions back to the residue count.
                emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
                if per_protein:
                    # Mean over residues -> one fixed-size vector per protein.
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()
            # Free batch tensors before accumulating the next batch.
            del input_ids, attention_mask, embedding_repr
            torch.cuda.empty_cache()
    # Report timing statistics.
    torch.cuda.empty_cache()
    passed_time = time.time() - start
    avg_time = passed_time / len(results["protein_embs"]) if results["protein_embs"] else 0
    print(f'Total per-protein embeddings generated: {len(results["protein_embs"])}')
    print(f"Embedding generation time: {passed_time/60:.1f}m ({avg_time:.3f}s/protein)")
    return results, invalid_items
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='Run model inference on annotation and sequence data')
parser.add_argument('--annotation_dir', required=True,
help='Directory containing annotation files')
parser.add_argument('--startswith', required=True,
default='',
help='The file starts with [content]')
parser.add_argument('--endswith', required=True,
default='.emapper.annotations',
help='The file end with [content]')
parser.add_argument('--fasta_path', required=True,
help='Path to the protein FASTA file')
parser.add_argument('--data_pkl', required=True,
default='./data_pkl',
help='Path to the data.pkl file containing metadata')
parser.add_argument('--model_path', required=True,
help='Path to the trained model (.pt file)')
parser.add_argument('--output_dir', required=True,
help='Output directory for inference results')
parser.add_argument('--t5_model_dir',
default='./scripts/model_para/',
help='Directory containing T5 model parameters (default: specified path)')
parser.add_argument('--annotation_type',
required=True,
choices=['emmapper', 'kofamscan'],
help='Type of annotation file: "emmapper" or "kofamscan"')
args = parser.parse_args()
# Set up computation device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using computation device: {device}")
# Load metadata from data.pkl
with open(args.data_pkl, 'rb') as f:
loaded_data = pickle.load(f)
compound_cab = loaded_data['compound_cab']
compound_max = loaded_data['compound_max']
# Initialize and load the main model
model = Model(heads=8, d_model=256, num_encoder_layer=3, num_decoder_layer=3, dropout=0.1,
src_input_dim=1024, first_linear_dim=512, tgt_vocab_size=len(compound_cab), compound_cab=compound_cab
).to(device)
model.load_state_dict(torch.load(args.model_path, map_location=device))
model.eval() # Set to evaluation mode
print(f"Successfully loaded model from: {args.model_path}")
# Load T5 encoder model and tokenizer
t5_encoder, t5_tokenizer = get_T5_model(args.t5_model_dir, device)
# Find all valid annotation files in the specified directory
annotation_files = []
genomes = []
for f in os.listdir(args.annotation_dir):
start_match = (not args.startswith) or f.startswith(args.startswith)
end_match = (not args.endswith) or f.endswith(args.endswith)
if start_match and end_match:
annotation_files.append(os.path.join(args.annotation_dir, f))
if args.startswith:
after_start = f.split(args.startswith)[1]
else:
after_start = f
if args.endswith:
genome_name = after_start.split(args.endswith)[0]
else:
genome_name = after_start
genomes.append(genome_name)
if not annotation_files:
print(f"No valid annotation files found in {args.annotation_dir}. "
f"Files must start with '{args.startswith}' and end with '{args.endswith}'.")
return
# Process each annotation file
combined_results = pd.DataFrame()
for idx, annotation_path in enumerate(annotation_files):
genome_name = genomes[idx]
protein_path = os.path.join(args.fasta_path, f"{genome_name}.faa")
if not os.path.exists(protein_path):
print(f"Error: Protein file not found - {protein_path}. Skipping this genome.")
continue
# Process annotation and sequence data
try:
geneKoSeq = process_annotation_and_sequence(
annotation_path,
protein_path,
annotation_type=args.annotation_type,
evalue_threshold=1e-5
)
except Exception as e:
print(f"Error processing {genome_name}: {str(e)}. Skipping this genome.")
continue
if geneKoSeq.empty:
print(f"No valid gene-KO-sequence entries for {genome_name}. Skipping.")
continue
# Prepare sequence dictionary
seqs = {gene: seq for gene, seq in zip(geneKoSeq['gene'], geneKoSeq['sequence'])}
# Generate protein embeddings
results, invalid_items = get_embeddings(
t5_encoder, t5_tokenizer, seqs, device,
sec_struct=False, per_residue=False, per_protein=True
)
if not results['protein_embs']:
print(f"No valid protein embeddings for {genome_name}. Skipping.")
continue
# Process embeddings (padding to fixed length)
combined_matrix = np.stack(list(results['protein_embs'].values()))
embedding_pad, mask = ko_embedding_padding(combined_matrix, 4500)
# Prepare data loader
ko_token = [embedding_pad]
ko_mask = [mask]
ko_token = torch.stack([torch.from_numpy(x).float() for x in ko_token])
ko_mask = torch.stack([torch.from_numpy(x).float() for x in ko_mask])
current_genome = [genome_name]
dataloader = build_inference_dataloader(ko_token=ko_token, ko_mask=ko_mask, genome=current_genome, batch_size=1, shuffle=False)
# Model configuration
model_config = {
'heads': 8,
'd_model': 256,
'src_input_dim': 1024,
'first_linear_dim': 512,
'num_encoder_layer': 3,
'num_decoder_layer': 3,
'dropout': 0.1,
'compound_max_len': compound_max
}
# Run inference
medium_result = evaluate_model(
model, dataloader, compound_cab, device, model_config,
[8, 9, 10], [0.9], [4], ['beam_search'], args.output_dir
)
combined_results = pd.concat([combined_results, medium_result], ignore_index=True)
print(f"Processed {genome_name}, accumulated {len(combined_results)} entries")
output_file_name = os.path.join(args.output_dir, "generateMedia.csv")
combined_results.to_csv(output_file_name, sep='\t', index=False)
# Script entry point: run the full inference pipeline when executed directly.
if __name__ == "__main__":
    main()