"""Dataset and DataLoader helpers built around per-genome embeddings stored in HDF5."""

import ast
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split

try:
    import h5py
except ImportError:
    # Only the methods that read embeddings need h5py; the padding/token
    # helpers and the Dataset classes stay importable without it.
    h5py = None


def _open_embedding_file(path):
    """Open the HDF5 embedding file read-only, failing clearly if h5py is missing."""
    if h5py is None:
        raise ImportError("h5py is required to read embeddings from an HDF5 file")
    return h5py.File(path, "r")


def ko_embedding_padding(embedding, max_len):
    """Zero-pad a 2-D embedding to ``max_len`` rows and build its padding mask.

    Args:
        embedding: 2-D NumPy array; rows are padded, columns are left as-is.
        max_len: Number of rows of the padded result.

    Returns:
        Tuple ``(padded, mask)``. ``padded`` has ``max_len`` rows; ``mask`` is
        a boolean array of length ``max_len`` that is ``False`` for the
        original rows and ``True`` for the rows added as padding.

    Raises:
        ValueError: If ``embedding`` has more than ``max_len`` rows.
    """
    n_rows = embedding.shape[0]
    if n_rows > max_len:
        # np.pad would also raise ValueError here, but with a message about
        # negative pad widths that hides the real cause.
        raise ValueError(
            f"embedding has {n_rows} rows, which exceeds max_len={max_len}"
        )
    padded = np.pad(
        embedding,
        pad_width=((0, max_len - n_rows), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    mask = np.ones(max_len, dtype=bool)
    mask[:n_rows] = False
    return padded, mask


def get_input_token(inputs, vo_cab, max_len):
    """Convert compound names into a token-id list padded to ``max_len``.

    The names are sorted, wrapped in a leading and a trailing sentinel token
    and right-padded with the ``'blank'`` token. A sequence that is already
    ``max_len`` or longer is returned unpadded and is NOT truncated.

    Args:
        inputs: Iterable of vocabulary keys.
        vo_cab: Mapping from key to token id; must contain the sentinel key
            and ``'blank'``.
        max_len: Minimum length of the returned list.

    Returns:
        List of token ids.

    Raises:
        KeyError: If a key is missing from ``vo_cab``.
    """
    ordered = sorted(inputs)
    # NOTE(review): both sentinel keys are the empty string here; they look
    # like start/end markers whose text may have been lost -- confirm against
    # the vocabulary file.
    tokens = [vo_cab['']]
    tokens.extend(vo_cab[name] for name in ordered)
    tokens.append(vo_cab[''])
    padding_len = max_len - len(tokens)
    if padding_len > 0:
        tokens += [vo_cab['blank']] * padding_len
    return tokens


def get_dict(data, colum):
    """Map each distinct value of ``data[colum]`` to its rank in sorted order.

    Args:
        data: DataFrame holding the column.
        colum: Column name.

    Returns:
        Dict from value to a 0-based integer index assigned in ascending
        order of the de-duplicated values.
    """
    values = pd.DataFrame(list(data[colum]), columns=[colum])
    unique_sorted = (
        values.drop_duplicates([colum])
        .sort_values(by=colum)
        .reset_index(drop=True)
    )
    return dict(zip(unique_sorted[colum], unique_sorted.index))


class DataProcessor:
    """Load the sample table and turn rows into model-ready tensors.

    ``load_data`` must be called before any of the sample-producing methods.
    """

    def __init__(self, ko_count_path, data_path, embedding_h5_path, vo_cab_pkl_path,
                 model_type, ko_max_len=4500, compound_max_len=98):
        """Store paths and size limits; no file is read here.

        Args:
            ko_count_path: CSV read by ``load_data`` (uses columns ``genome``
                and ``count``).
            data_path: CSV read by ``load_data`` (uses columns ``genome``,
                ``compounds_list`` and ``media_name``).
            embedding_h5_path: HDF5 file indexed by genome.
            vo_cab_pkl_path: Pickle file holding a dict with a
                ``'compound_cab'`` entry.
            model_type: Prefix of the dictionary pickle files written by
                ``load_data``.
            ko_max_len: Row count embeddings are padded to.
            compound_max_len: Length compound token lists are padded to.
        """
        self.ko_count_path = ko_count_path
        self.data_path = data_path
        self.embedding_h5_path = embedding_h5_path
        self.vo_cab_pkl_path = vo_cab_pkl_path
        self.model_type = model_type
        self.ko_max_len = ko_max_len
        self.compound_max_len = compound_max_len

        # Populated by load_data().
        self.KO_count = None
        self.data = None
        self.vo_cab = None
        self.genome_dict = None
        self.medium_dict = None

    def load_data(self, count_min=800, count_max=4500):
        """Read the CSV inputs, load the vocabulary and build the id dictionaries.

        Side effects: writes ``{model_type}_genome_dict.pkl`` and
        ``{model_type}_medium_dict.pkl`` into the current working directory.

        Args:
            count_min: Smallest ``count`` value (inclusive) a genome may have.
            count_max: Largest ``count`` value (inclusive) a genome may have.
        """
        ko_count = pd.read_csv(self.ko_count_path)
        in_range = (ko_count['count'] >= count_min) & (ko_count['count'] <= count_max)
        self.KO_count = ko_count[in_range]

        self.data = pd.read_csv(self.data_path)
        self.data = pd.merge(
            self.KO_count, self.data, how='inner', on='genome'
        ).drop_duplicates(['genome', 'compounds_list'])
        # compounds_list is stored as the text of a Python literal.
        self.data['compounds_list'] = self.data['compounds_list'].apply(
            lambda x: list(ast.literal_eval(x))
        )

        # Load vo_cab from the pickle file.
        if self.vo_cab_pkl_path and os.path.exists(self.vo_cab_pkl_path):
            print(f"Loading predefined compound_cab from {self.vo_cab_pkl_path}")
            # NOTE: pickle.load can execute arbitrary code; only point
            # vo_cab_pkl_path at files from a trusted source.
            with open(self.vo_cab_pkl_path, 'rb') as f:
                data_pkl = pickle.load(f)
            if isinstance(data_pkl, dict) and 'compound_cab' in data_pkl:
                self.vo_cab = data_pkl['compound_cab']
                print(f"Successfully loaded compound_cab, vocabulary size: {len(self.vo_cab)}")
            else:
                print(
                    f"Warning: no 'compound_cab' entry in {self.vo_cab_pkl_path}; "
                    "vo_cab was not loaded"
                )
        else:
            # Without a vocabulary every later tokenisation fails, so say so now.
            print(
                f"Warning: vocabulary file {self.vo_cab_pkl_path} not found; "
                "vo_cab was not loaded"
            )

        # Genome and medium dictionaries.
        self.genome_dict = get_dict(self.data, 'genome')
        self.medium_dict = get_dict(self.data, 'media_name')

        # Save genome_dict and medium_dict to files.
        genome_filename = f"{self.model_type}_genome_dict.pkl"
        medium_filename = f"{self.model_type}_medium_dict.pkl"
        with open(genome_filename, 'wb') as f:
            pickle.dump(self.genome_dict, f)
        with open(medium_filename, 'wb') as f:
            pickle.dump(self.medium_dict, f)

    def get_sample_by_index(self, index):
        """Build the tensors for the row at position ``index``.

        The HDF5 file is opened and closed on every call.

        Args:
            index: Positional row index into the loaded table.

        Returns:
            Dict with keys ``ko_token``, ``compound_token``, ``ko_mask``,
            ``genome_token`` and ``medium_token``, or ``None`` when the
            genome has no entry in the HDF5 file.

        Raises:
            ValueError: If ``load_data`` has not been called.
            IndexError: If ``index`` is not smaller than the table length.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")
        if index >= len(self.data):
            raise IndexError("索引超出数据范围")

        row = self.data.iloc[index]
        genome = row['genome']
        compounds = row['compounds_list']
        medium = row['media_name']

        # Read the embedding from the HDF5 file.
        with _open_embedding_file(self.embedding_h5_path) as hf:
            try:
                embeddings = hf[genome][:]
            except KeyError:
                print(f"⚠️ Genome {genome} not found in HDF5")
                return None

        # Build tokens and mask.
        genome_token = self.genome_dict[genome]
        medium_token = self.medium_dict[medium]
        compound_token = get_input_token(
            inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len
        )
        embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)

        return {
            'ko_token': torch.from_numpy(embedding_pad).float(),
            'compound_token': torch.tensor(compound_token),
            'ko_mask': torch.tensor(mask),
            'genome_token': torch.tensor(genome_token),
            'medium_token': torch.tensor(medium_token)
        }

    def get_data_length(self):
        """Return the number of rows in the loaded table.

        Raises:
            ValueError: If ``load_data`` has not been called.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")
        return len(self.data)

    def process_test_data(self, genome_list=None):
        """Eagerly build stacked tensors for the selected genomes.

        Genomes without an entry in the HDF5 file are reported and skipped.

        Args:
            genome_list: Genomes to keep; ``None`` selects the whole table.

        Returns:
            Tuple ``(ko_token, compound_token, ko_mask, genome_token,
            medium_token)`` of tensors stacked along a new first dimension.

        Raises:
            ValueError: If ``load_data`` has not been called.
            RuntimeError: If no selected genome could be loaded.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")

        # With no genome_list, use the entire table.
        if genome_list is None:
            test_data = self.data
        else:
            test_data = self.data[self.data['genome'].isin(genome_list)]

        ko_token = []
        compound_token = []
        medium_token = []
        genome_token = []
        ko_mask = []

        with _open_embedding_file(self.embedding_h5_path) as hf:
            for genome, compounds, medium in zip(test_data['genome'],
                                                 test_data['compounds_list'],
                                                 test_data['media_name']):
                try:
                    embeddings = hf[genome][:]
                except KeyError:
                    print(f"⚠️ Genome {genome} not found in HDF5, skipped")
                    continue  # skip this genome

                # Build tokens and mask.
                genome_token.append(self.genome_dict[genome])
                medium_token.append(self.medium_dict[medium])
                compound_token.append(get_input_token(
                    inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len
                ))
                embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
                ko_token.append(embedding_pad)
                ko_mask.append(mask)

        if not ko_token:
            # Stacking an empty list fails with an opaque message; be explicit.
            raise RuntimeError("no samples could be loaded for the requested genomes")

        # Convert to tensors. Stacking in NumPy first avoids the slow
        # element-wise path of torch.tensor() on a list of arrays.
        medium_token = torch.tensor(medium_token)
        genome_token = torch.tensor(genome_token)
        ko_mask = torch.from_numpy(np.stack(ko_mask))
        ko_token = torch.from_numpy(np.stack(ko_token)).float()
        compound_token = torch.tensor(compound_token)

        return ko_token, compound_token, ko_mask, genome_token, medium_token


class MyDataset(Dataset):
    """In-memory dataset over five parallel, index-aligned containers."""

    def __init__(self, src, trg, src_mask, genome, medium):
        """Store the containers; ``trg`` must support ``.size(0)``."""
        self.src = src
        self.trg = trg
        self.src_mask = src_mask
        self.genome = genome
        self.medium = medium

    def __getitem__(self, index):
        """Return ``(src, trg, src_mask, genome, medium)`` at ``index``."""
        return (self.src[index], self.trg[index], self.src_mask[index],
                self.genome[index], self.medium[index])

    def __len__(self):
        """Return the number of samples, taken from ``trg``."""
        return self.trg.size(0)


class LazyDataset(Dataset):
    """Dataset that asks a DataProcessor for each sample on demand."""

    def __init__(self, data_processor, indices=None):
        """Keep the processor and the row positions this dataset exposes.

        Args:
            data_processor: Object providing ``get_data_length()`` and
                ``get_sample_by_index(i)``.
            indices: Row positions to expose; ``None`` exposes every row.
        """
        self.data_processor = data_processor
        if indices is None:
            self.indices = list(range(data_processor.get_data_length()))
        else:
            self.indices = indices

    def __getitem__(self, index):
        """Return the sample at ``index``, substituting another if it is unavailable.

        When the processor returns ``None`` for the requested row, the
        following positions are tried in order, then the earlier ones, and
        the first available sample is returned instead.

        Raises:
            RuntimeError: If no position of this dataset yields a sample.
        """
        sample = self.data_processor.get_sample_by_index(self.indices[index])

        if sample is None:
            # Search forward first (the historical behaviour), then wrap
            # around so an unavailable sample near the end does not fail
            # while earlier samples are still usable.
            fallback_positions = (*range(index + 1, len(self.indices)), *range(index))
            for position in fallback_positions:
                sample = self.data_processor.get_sample_by_index(self.indices[position])
                if sample is not None:
                    break

        if sample is None:
            raise RuntimeError(f"无法获取索引 {index} 的数据")

        return (sample['ko_token'], sample['compound_token'], sample['ko_mask'],
                sample['genome_token'], sample['medium_token'])

    def __len__(self):
        """Return the number of exposed row positions."""
        return len(self.indices)


def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token,
                      batch_size, shuffle):
    """Wrap the given tensors in a ``MyDataset`` and return a single DataLoader."""
    dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return train_loader  # a single DataLoader


def build_single_lazy_dataloader(data_processor, batch_size=1, shuffle=False):
    """Create one lazily loading DataLoader over every row of ``data_processor``."""
    dataset = LazyDataset(data_processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader


def build_lazy_dataloaders(data_processor, batch_size=1, split_ratio=8, shuffle=True,
                           seed=None):
    """Split the rows into train/valid/test and build lazily loading DataLoaders.

    Args:
        data_processor: Object providing ``get_data_length()`` and
            ``get_sample_by_index(i)``.
        batch_size: Batch size of every loader.
        split_ratio: Training share expressed in tenths (8 means 8/10). Half
            of the remaining rows, rounded down, go to validation and the
            rest to test.
        shuffle: Shuffle the row order before splitting, and shuffle the
            first training loader.
        seed: Optional seed making the split reproducible. ``None`` uses the
            global ``random`` state, as before this parameter existed.

    Returns:
        Tuple ``(train_loader, train_loader_no_shuffle, valid_loader,
        test_loader)``.
    """
    total_length = data_processor.get_data_length()
    indices = list(range(total_length))

    # Split points.
    train_size = int(total_length * split_ratio / 10)
    valid_size = int((total_length - train_size) / 2)

    if shuffle:
        if seed is None:
            random.shuffle(indices)
        else:
            # A private generator keeps the global random state untouched.
            random.Random(seed).shuffle(indices)

    train_indices = indices[:train_size]
    valid_indices = indices[train_size:train_size + valid_size]
    test_indices = indices[train_size + valid_size:]

    # Datasets and loaders.
    train_dataset = LazyDataset(data_processor, train_indices)
    valid_dataset = LazyDataset(data_processor, valid_indices)
    test_dataset = LazyDataset(data_processor, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    train_loader_no_shuffle = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, train_loader_no_shuffle, valid_loader, test_loader


# Dataset class specifically for model inference (using unseen data)
class InferenceDataset(Dataset):
    """Dataset for inference: input features, their mask and a genome label."""

    def __init__(self, src, src_mask, genome):
        """Initialize the inference dataset.

        Args:
            src: Input features, indexable by sample and supporting ``len()``.
            src_mask: Mask for the input features, indexable by sample.
            genome: Per-sample genome identifiers, indexable by sample.
        """
        self.src = src  # Input features (e.g., ko_token embeddings)
        self.src_mask = src_mask  # Mask for input features
        self.genome = genome

    def __getitem__(self, index):
        """Return ``(src_feature, src_mask, genome)`` for the sample at ``index``."""
        return self.src[index], self.src_mask[index], self.genome[index]

    def __len__(self):
        """Return the number of samples (the length of the source features)."""
        return len(self.src)


def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """Build a DataLoader for model inference.

    Args:
        ko_token: Input features for inference, indexable by sample.
        ko_mask: Mask for the input features, indexable by sample.
        genome: Per-sample genome identifiers, indexable by sample.
        batch_size: Batch size for inference (default: 1).
        shuffle: Whether to shuffle the data (default: False, which keeps
            the sample order).

    Returns:
        DataLoader yielding ``(src, src_mask, genome)`` batches.
    """
    dataset = InferenceDataset(
        src=ko_token,
        src_mask=ko_mask,
        genome=genome
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle
    )