import os
|
|
import argparse
|
|
import pandas as pd
|
|
import numpy as np
|
|
import torch
|
|
import pickle
|
|
import time
|
|
from Bio import SeqIO
|
|
from transformers import T5EncoderModel, T5Tokenizer
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import logging
|
|
from transformers import logging as hf_logging
|
|
import warnings
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")
|
|
|
|
hf_logging.set_verbosity_error()
|
|
# Assume these dependencies are properly implemented
|
|
# Adjust imports according to actual file structure
|
|
from model import Model
|
|
from dmodel_256_decode_model import evaluate_model
|
|
from datasets import ko_embedding_padding # Assumes this function exists
|
|
|
|
class InferenceDataset(Dataset):
    """Torch dataset wrapping pre-computed inputs for inference on unseen data.

    Each sample is a (features, mask, genome) triple taken from the three
    parallel containers supplied at construction time.
    """

    def __init__(self, src, src_mask, genome):
        # src: input feature tensors (embeddings), one entry per sample
        # src_mask: padding/attention mask matching each entry of src
        # genome: genome identifier per sample
        self.src = src
        self.src_mask = src_mask
        self.genome = genome

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.src)

    def __getitem__(self, index):
        """Return the (features, mask, genome) triple at *index*."""
        return self.src[index], self.src_mask[index], self.genome[index]
|
|
|
|
def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """
    Wrap pre-computed features in a DataLoader for inference.

    Args:
        ko_token: Input feature tensors (embeddings).
        ko_mask: Corresponding mask tensors.
        genome: Genome identifier per sample.
        batch_size: Batch size for inference.
        shuffle: Whether to shuffle the data (normally False for inference).

    Returns:
        A DataLoader over an InferenceDataset built from the inputs.
    """
    return DataLoader(
        InferenceDataset(src=ko_token, src_mask=ko_mask, genome=genome),
        batch_size=batch_size,
        shuffle=shuffle,
    )
|
|
|
|
def process_annotation_and_sequence(annotation_path, fasta_path, annotation_type, evalue_threshold=1e-5):
    """
    Parse a functional-annotation file (eggNOG-mapper or kofamscan) plus a
    protein FASTA file and merge them into one gene/KO/sequence table.

    Args:
        annotation_path: Path to the annotation file.
        fasta_path: Path to the protein FASTA file.
        annotation_type: Type of annotation file, must be 'emmapper' or 'kofamscan'.
        evalue_threshold: E-value cutoff for filtering eggNOG-mapper hits
            (only used when annotation_type is 'emmapper').

    Returns:
        DataFrame with columns ['gene', 'KO', 'sequence']; each gene maps to
        at most one KO and each KO to at most one gene.

    Raises:
        ValueError: If annotation_type is not one of the supported values.
    """
    # Validate annotation type
    if annotation_type not in ('emmapper', 'kofamscan'):
        raise ValueError("annotation_type must be either 'emmapper' or 'kofamscan'")

    if annotation_type == 'emmapper':
        annotation_df = (
            pd.read_csv(annotation_path, sep='\t', comment='#', header=None, usecols=[0, 2, 11])
            .rename(columns={0: 'gene', 2: 'evalue', 11: 'KO'})
            # BUGFIX: in pandas.query, '&' binds tighter than '<'/'!=', so the
            # previous unparenthesized f-string 'evalue < {x} & KO != "-"' did
            # not express the intended filter.  Parenthesized comparisons
            # joined with 'and', plus '@' to reference the local threshold,
            # give the intended "significant hit AND annotated" filter.
            .query('(evalue < @evalue_threshold) and (KO != "-")')
            .assign(KO=lambda x: x['KO'].str.split(','))  # a hit may list several KOs
            .explode('KO')                                 # one KO per row
            .assign(KO=lambda x: x['KO'].str.split('ko:').str[1])  # strip 'ko:' prefix
            .drop_duplicates(subset='gene', keep='first')  # keep first KO per gene
            .sort_values('gene')                           # deterministic order
            .drop_duplicates(subset='KO', keep='first')    # keep first gene per KO
            .drop(columns='evalue')                        # evalue no longer needed
        )
    else:  # kofamscan: simple two-column gene<TAB>KO table
        annotation_df = (
            pd.read_csv(annotation_path, sep='\t', header=None, names=['gene', 'KO'])
            .dropna(subset=['KO'])                         # drop unannotated genes
            .drop_duplicates(subset='gene', keep='first')  # keep first KO per gene
            .sort_values('KO')
            .drop_duplicates(subset='KO', keep='first')    # keep first gene per KO
        )

    # FASTA records -> (gene id, amino-acid sequence) table
    seq_df = pd.DataFrame(
        [(rec.id, str(rec.seq)) for rec in SeqIO.parse(fasta_path, "fasta")],
        columns=['gene', 'sequence']
    )

    # Inner join keeps only genes that have both an annotation and a sequence
    return annotation_df.merge(seq_df, on='gene', how='inner')
|
|
|
|
def get_T5_model(local_model_dir, device):
    """
    Load the T5 encoder and its tokenizer from a local directory.

    Args:
        local_model_dir: Directory containing the T5 model files.
        device: Computation device (GPU/CPU) to place the model on.

    Returns:
        Tuple of (T5 encoder model in eval mode on *device*, tokenizer).
    """
    encoder = T5EncoderModel.from_pretrained(local_model_dir)
    # eval() disables dropout etc. -- this function is inference-only
    encoder = encoder.to(device).eval()
    tok = T5Tokenizer.from_pretrained(local_model_dir, do_lower_case=False)
    return encoder, tok
|
|
|
|
def get_embeddings(model, tokenizer, seqs, device, per_residue=False, per_protein=True, sec_struct=False,
                   max_residues=5000, max_seq_len=1000, max_batch=1):
    """
    Generate protein embeddings using T5 encoder.

    Args:
        model: T5 encoder model (already on *device*, in eval mode)
        tokenizer: T5 tokenizer
        seqs: Dictionary of {gene_id: sequence}; float values (e.g. NaN
            from a pandas merge) are treated as invalid and skipped
        device: Computation device
        per_residue: Accepted but currently unused -- only per-protein
            pooling is implemented below
        per_protein: Whether to generate mean-pooled per-protein embeddings
        sec_struct: Accepted but currently unused (no secondary-structure
            prediction is performed in this function)
        max_residues: Maximum total residues per batch
        max_seq_len: Maximum length of single sequence before forcing a flush
        max_batch: Maximum number of sequences per batch

    Returns:
        Tuple of (results dict with key "protein_embs" mapping gene_id ->
        1-D numpy embedding, dict of invalid input items that were skipped)
    """
    results = {"residue_embs": dict(), "protein_embs": dict(), "sec_structs": dict()}
    # Filter invalid sequences (e.g., NaN values) -- NaN is a float, real
    # sequences are strings
    invalid_items = {k: v for k, v in seqs.items() if isinstance(v, float)}
    seq_dict = [(k, v) for k, v in seqs.items() if not isinstance(v, float)]

    start = time.time()
    batch = []
    # seq_idx starts at 1 so `seq_idx == len(seq_dict)` flushes the final batch
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
        seq_len = len(seq)
        seq = ' '.join(list(seq))  # T5 tokenizer expects space-separated residues
        batch.append((pdb_id, seq, seq_len))

        # Decide whether the accumulated batch must be flushed now.
        # NOTE(review): the current sequence is already in `batch`, so adding
        # seq_len again double-counts it; this over-estimate only makes the
        # residue budget more conservative (same quirk as the ProtTrans
        # reference embedding script) -- kept as-is.
        n_res_batch = sum([s_len for _, _, s_len in batch]) + seq_len
        if len(batch) >= max_batch or n_res_batch >= max_residues or \
                seq_idx == len(seq_dict) or seq_len > max_seq_len:
            pdb_ids, seqs_batch, seq_lens = zip(*batch)
            batch = []

            # Tokenize the whole batch, padding to the longest sequence
            token_encoding = tokenizer.batch_encode_plus(
                seqs_batch,
                add_special_tokens=True,
                padding="longest"
            )
            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

            try:
                with torch.no_grad():  # Disable gradient calculation for inference
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                # Typically an OOM on long sequences.  NOTE(review): this
                # skips the ENTIRE flushed batch, not just the offending
                # sequence, and pdb_id/seq_len here refer to the last
                # sequence appended, not necessarily the one that failed.
                print(f"RuntimeError during embedding for {pdb_id} (length={seq_len})")
                continue

            # Process embeddings for each sequence in batch
            for batch_idx, identifier in enumerate(pdb_ids):
                s_len = seq_lens[batch_idx]
                # Truncate to the real sequence length (drops padding and
                # the special token positions beyond s_len)
                emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
                if per_protein:
                    # Mean over residues -> one fixed-size vector per protein
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()
            # Free GPU memory eagerly between batches; empty_cache is a
            # harmless no-op when CUDA is not in use
            del input_ids, attention_mask, embedding_repr
            torch.cuda.empty_cache()

    # Final cleanup + timing statistics
    torch.cuda.empty_cache()
    passed_time = time.time() - start
    # Guard against division by zero when nothing was embedded
    avg_time = passed_time / len(results["protein_embs"]) if results["protein_embs"] else 0
    print(f'Total per-protein embeddings generated: {len(results["protein_embs"])}')
    print(f"Embedding generation time: {passed_time/60:.1f}m ({avg_time:.3f}s/protein)")
    return results, invalid_items
|
|
|
|
def main():
|
|
# Parse command line arguments
|
|
parser = argparse.ArgumentParser(description='Run model inference on annotation and sequence data')
|
|
parser.add_argument('--annotation_dir', required=True,
|
|
help='Directory containing annotation files')
|
|
parser.add_argument('--startswith', required=True,
|
|
default='',
|
|
help='The file starts with [content]')
|
|
parser.add_argument('--endswith', required=True,
|
|
default='.emapper.annotations',
|
|
help='The file end with [content]')
|
|
parser.add_argument('--fasta_path', required=True,
|
|
help='Path to the protein FASTA file')
|
|
parser.add_argument('--data_pkl', required=True,
|
|
default='./data_pkl',
|
|
help='Path to the data.pkl file containing metadata')
|
|
parser.add_argument('--model_path', required=True,
|
|
help='Path to the trained model (.pt file)')
|
|
parser.add_argument('--output_dir', required=True,
|
|
help='Output directory for inference results')
|
|
parser.add_argument('--t5_model_dir',
|
|
default='./scripts/model_para/',
|
|
help='Directory containing T5 model parameters (default: specified path)')
|
|
parser.add_argument('--annotation_type',
|
|
required=True,
|
|
choices=['emmapper', 'kofamscan'],
|
|
help='Type of annotation file: "emmapper" or "kofamscan"')
|
|
args = parser.parse_args()
|
|
|
|
# Set up computation device
|
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
print(f"Using computation device: {device}")
|
|
|
|
# Load metadata from data.pkl
|
|
with open(args.data_pkl, 'rb') as f:
|
|
loaded_data = pickle.load(f)
|
|
compound_cab = loaded_data['compound_cab']
|
|
compound_max = loaded_data['compound_max']
|
|
|
|
# Initialize and load the main model
|
|
model = Model(heads=8, d_model=256, num_encoder_layer=3, num_decoder_layer=3, dropout=0.1,
|
|
src_input_dim=1024, first_linear_dim=512, tgt_vocab_size=len(compound_cab), compound_cab=compound_cab
|
|
).to(device)
|
|
model.load_state_dict(torch.load(args.model_path, map_location=device))
|
|
model.eval() # Set to evaluation mode
|
|
print(f"Successfully loaded model from: {args.model_path}")
|
|
|
|
# Load T5 encoder model and tokenizer
|
|
t5_encoder, t5_tokenizer = get_T5_model(args.t5_model_dir, device)
|
|
|
|
# Find all valid annotation files in the specified directory
|
|
annotation_files = []
|
|
genomes = []
|
|
for f in os.listdir(args.annotation_dir):
|
|
start_match = (not args.startswith) or f.startswith(args.startswith)
|
|
end_match = (not args.endswith) or f.endswith(args.endswith)
|
|
if start_match and end_match:
|
|
annotation_files.append(os.path.join(args.annotation_dir, f))
|
|
if args.startswith:
|
|
after_start = f.split(args.startswith)[1]
|
|
else:
|
|
after_start = f
|
|
if args.endswith:
|
|
genome_name = after_start.split(args.endswith)[0]
|
|
else:
|
|
genome_name = after_start
|
|
genomes.append(genome_name)
|
|
|
|
if not annotation_files:
|
|
print(f"No valid annotation files found in {args.annotation_dir}. "
|
|
f"Files must start with '{args.startswith}' and end with '{args.endswith}'.")
|
|
return
|
|
|
|
# Process each annotation file
|
|
combined_results = pd.DataFrame()
|
|
for idx, annotation_path in enumerate(annotation_files):
|
|
genome_name = genomes[idx]
|
|
protein_path = os.path.join(args.fasta_path, f"{genome_name}.faa")
|
|
|
|
if not os.path.exists(protein_path):
|
|
print(f"Error: Protein file not found - {protein_path}. Skipping this genome.")
|
|
continue
|
|
|
|
# Process annotation and sequence data
|
|
try:
|
|
geneKoSeq = process_annotation_and_sequence(
|
|
annotation_path,
|
|
protein_path,
|
|
annotation_type=args.annotation_type,
|
|
evalue_threshold=1e-5
|
|
)
|
|
except Exception as e:
|
|
print(f"Error processing {genome_name}: {str(e)}. Skipping this genome.")
|
|
continue
|
|
|
|
if geneKoSeq.empty:
|
|
print(f"No valid gene-KO-sequence entries for {genome_name}. Skipping.")
|
|
continue
|
|
|
|
# Prepare sequence dictionary
|
|
seqs = {gene: seq for gene, seq in zip(geneKoSeq['gene'], geneKoSeq['sequence'])}
|
|
|
|
# Generate protein embeddings
|
|
results, invalid_items = get_embeddings(
|
|
t5_encoder, t5_tokenizer, seqs, device,
|
|
sec_struct=False, per_residue=False, per_protein=True
|
|
)
|
|
if not results['protein_embs']:
|
|
print(f"No valid protein embeddings for {genome_name}. Skipping.")
|
|
continue
|
|
|
|
# Process embeddings (padding to fixed length)
|
|
combined_matrix = np.stack(list(results['protein_embs'].values()))
|
|
embedding_pad, mask = ko_embedding_padding(combined_matrix, 4500)
|
|
|
|
# Prepare data loader
|
|
ko_token = [embedding_pad]
|
|
ko_mask = [mask]
|
|
ko_token = torch.stack([torch.from_numpy(x).float() for x in ko_token])
|
|
ko_mask = torch.stack([torch.from_numpy(x).float() for x in ko_mask])
|
|
current_genome = [genome_name]
|
|
dataloader = build_inference_dataloader(ko_token=ko_token, ko_mask=ko_mask, genome=current_genome, batch_size=1, shuffle=False)
|
|
|
|
# Model configuration
|
|
model_config = {
|
|
'heads': 8,
|
|
'd_model': 256,
|
|
'src_input_dim': 1024,
|
|
'first_linear_dim': 512,
|
|
'num_encoder_layer': 3,
|
|
'num_decoder_layer': 3,
|
|
'dropout': 0.1,
|
|
'compound_max_len': compound_max
|
|
}
|
|
|
|
# Run inference
|
|
medium_result = evaluate_model(
|
|
model, dataloader, compound_cab, device, model_config,
|
|
[8, 9, 10], [0.9], [4], ['beam_search'], args.output_dir
|
|
)
|
|
combined_results = pd.concat([combined_results, medium_result], ignore_index=True)
|
|
print(f"Processed {genome_name}, accumulated {len(combined_results)} entries")
|
|
|
|
output_file_name = os.path.join(args.output_dir, "generateMedia.csv")
|
|
combined_results.to_csv(output_file_name, sep='\t', index=False)
|
|
|
|
if __name__ == "__main__":
|
|
main() |