342 lines
13 KiB
Python
342 lines
13 KiB
Python
from torch.utils.data import Dataset, DataLoader, random_split
|
||
import pandas as pd
|
||
import torch
|
||
import h5py
|
||
import ast
|
||
import numpy as np
|
||
import pickle
|
||
import os
|
||
|
||
def ko_embedding_padding(embedding, max_len):
    """Zero-pad a 2-D embedding along its first axis to ``max_len`` rows.

    Args:
        embedding: 2-D array; rows are padded, the column count is unchanged.
        max_len: number of rows in the padded output.

    Returns:
        ``(padded_embeddings, mask)`` where ``padded_embeddings`` has shape
        ``(max_len, embedding.shape[1])`` and ``mask`` is a boolean vector of
        length ``max_len``: False for original rows, True for padding rows.

    Raises:
        ValueError: if ``embedding`` has more than ``max_len`` rows.
    """
    n_rows = embedding.shape[0]
    if n_rows > max_len:
        # A negative pad width would make np.pad fail with an unhelpful message.
        raise ValueError(
            f"embedding has {n_rows} rows, which exceeds max_len={max_len}"
        )

    padded_embeddings = np.pad(
        embedding,
        pad_width=((0, max_len - n_rows), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    mask = np.ones(max_len, dtype=bool)
    mask[:n_rows] = False

    return padded_embeddings, mask
|
||
|
||
def get_input_token(inputs, vo_cab, max_len):
    """Encode a collection of compounds as a padded list of vocabulary ids.

    The compounds are sorted, wrapped between the ``'</s>'`` and ``'<s>'``
    marker ids, then right-padded with the ``'blank'`` id up to ``max_len``.
    Sequences already at or above ``max_len`` are returned unpadded; nothing is
    ever truncated.

    NOTE(review): the leading marker is ``'</s>'`` and the trailing one is
    ``'<s>'``, which looks reversed. Confirm against how the vocabulary was
    built before changing it, because changing it alters the encoding.
    """
    ordered = sorted(inputs)

    tokens = [vo_cab['</s>']]
    tokens.extend(vo_cab[compound] for compound in ordered)
    tokens.append(vo_cab['<s>'])

    shortfall = max_len - len(tokens)
    if shortfall > 0:
        tokens.extend([vo_cab['blank']] * shortfall)
    return tokens
|
||
|
||
def get_dict(data, colum):
    """Map each distinct value of ``data[colum]`` to its rank in sorted order.

    Returns a dict ``{value: position}`` whose positions run 0..k-1 over the k
    distinct values, sorted ascending.
    """
    distinct_sorted = (
        pd.Series(list(data[colum]), name=colum)
        .drop_duplicates()
        .sort_values()
        .reset_index(drop=True)
    )
    return {value: position for position, value in distinct_sorted.items()}
|
||
|
||
class DataProcessor:
    """Load the genome / compound / medium table and convert rows to tensors.

    Call ``load_data()`` once after construction. Afterwards, samples can be
    fetched one at a time with ``get_sample_by_index()`` or all at once with
    ``process_test_data()``.
    """

    def __init__(self, ko_count_path, data_path, embedding_h5_path, vo_cab_pkl_path, model_type, ko_max_len=4500, compound_max_len=98):
        """Store paths and padding lengths; no file is read here.

        Args:
            ko_count_path: CSV with at least ``genome`` and ``count`` columns.
            data_path: CSV with at least ``genome``, ``compounds_list`` and
                ``media_name`` columns; inner-joined with the count CSV on
                ``genome``.
            embedding_h5_path: HDF5 file holding one dataset per genome name.
            vo_cab_pkl_path: optional pickle holding a dict with a
                ``'compound_cab'`` vocabulary; may be ``None``.
            model_type: filename prefix for the dictionaries saved by
                ``load_data()``.
            ko_max_len: number of rows every embedding is padded to.
            compound_max_len: length compound token lists are padded to
                (longer lists are not truncated).
        """
        self.ko_count_path = ko_count_path
        self.data_path = data_path
        self.embedding_h5_path = embedding_h5_path
        self.vo_cab_pkl_path = vo_cab_pkl_path
        self.model_type = model_type
        self.ko_max_len = ko_max_len
        self.compound_max_len = compound_max_len

        # All of the following are populated by load_data().
        self.KO_count = None
        self.data = None
        self.vo_cab = None
        self.genome_dict = None
        self.medium_dict = None

    def load_data(self, count_min=800, count_max=4500):
        """Read both CSVs, load the compound vocabulary and build index dicts.

        Only genomes whose ``count`` lies in ``[count_min, count_max]`` are
        kept. Duplicate ``(genome, compounds_list)`` rows are dropped.

        Side effect: writes ``{model_type}_genome_dict.pkl`` and
        ``{model_type}_medium_dict.pkl`` to the current working directory.
        """
        self.KO_count = pd.read_csv(self.ko_count_path)
        self.KO_count = self.KO_count[(self.KO_count['count'] >= count_min) & (self.KO_count['count'] <= count_max)]

        self.data = pd.read_csv(self.data_path)
        self.data = pd.merge(self.KO_count, self.data, how='inner', on='genome').drop_duplicates(['genome', 'compounds_list'])
        # compounds_list is stored as a Python-literal string; literal_eval only
        # parses literals, so this is safe on file content.
        self.data['compounds_list'] = self.data['compounds_list'].apply(lambda x: list(ast.literal_eval(x)))

        # Load vo_cab from the pickle file when one is configured.
        if self.vo_cab_pkl_path and os.path.exists(self.vo_cab_pkl_path):
            print(f"Loading predefined compound_cab from {self.vo_cab_pkl_path}")
            # SECURITY: pickle.load can execute arbitrary code; only point
            # vo_cab_pkl_path at trusted files.
            with open(self.vo_cab_pkl_path, 'rb') as f:
                data_pkl = pickle.load(f)
            if isinstance(data_pkl, dict) and 'compound_cab' in data_pkl:
                self.vo_cab = data_pkl['compound_cab']
                print(f"Successfully loaded compound_cab, vocabulary size: {len(self.vo_cab)}")
            else:
                print(f"⚠️ {self.vo_cab_pkl_path} has no 'compound_cab' entry; vo_cab not updated")
        if self.vo_cab is None:
            # Previously silent: tokenization later failed with a NoneType error.
            print("⚠️ vo_cab is not set; get_sample_by_index() and process_test_data() need it")

        # genome and medium dictionaries (value -> integer id)
        self.genome_dict = get_dict(self.data, 'genome')
        self.medium_dict = get_dict(self.data, 'media_name')

        # Persist genome_dict and medium_dict so other runs can reuse the ids.
        genome_filename = f"{self.model_type}_genome_dict.pkl"
        medium_filename = f"{self.model_type}_medium_dict.pkl"
        with open(genome_filename, 'wb') as f:
            pickle.dump(self.genome_dict, f)
        with open(medium_filename, 'wb') as f:
            pickle.dump(self.medium_dict, f)

    def get_sample_by_index(self, index):
        """Build the tensors for a single row of the loaded table.

        Returns:
            A dict with keys ``ko_token``, ``compound_token``, ``ko_mask``,
            ``genome_token`` and ``medium_token``, or ``None`` when the row's
            genome has no dataset in the HDF5 file.

        Raises:
            ValueError: if ``load_data()`` has not been called.
            IndexError: if ``index >= len(self.data)``.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")

        if index >= len(self.data):
            raise IndexError("索引超出数据范围")

        row = self.data.iloc[index]
        genome = row['genome']
        compounds = row['compounds_list']
        medium = row['media_name']

        # Read this genome's embedding from the HDF5 file.
        with h5py.File(self.embedding_h5_path, "r") as hf:
            try:
                embeddings = hf[genome][:]
            except KeyError:
                print(f"⚠️ Genome {genome} not found in HDF5")
                return None

        # Build tokens and the padding mask.
        genome_token = self.genome_dict[genome]
        medium_token = self.medium_dict[medium]
        compound_token = get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len)

        embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)

        return {
            'ko_token': torch.from_numpy(embedding_pad).float(),
            'compound_token': torch.tensor(compound_token),
            'ko_mask': torch.tensor(mask),
            'genome_token': torch.tensor(genome_token),
            'medium_token': torch.tensor(medium_token)
        }

    def get_data_length(self):
        """Return the number of rows in the loaded table.

        Raises:
            ValueError: if ``load_data()`` has not been called.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")
        return len(self.data)

    def process_test_data(self, genome_list=None):
        """Eagerly build batched tensors for all rows, or rows of given genomes.

        Args:
            genome_list: optional collection of genome names; when ``None`` the
                whole loaded table is used.

        Returns:
            ``(ko_token, compound_token, ko_mask, genome_token, medium_token)``.
            The first dimension of each is the number of rows whose genome was
            found in the HDF5 file; missing genomes are skipped with a warning.

        Raises:
            ValueError: if ``load_data()`` has not been called.
            RuntimeError: if none of the selected rows has an HDF5 embedding.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")

        # Use the whole dataset when no genome filter is given.
        if genome_list is None:
            test_data = self.data
        else:
            test_data = self.data[self.data['genome'].isin(genome_list)]

        ko_token = []
        compound_token = []
        medium_token = []
        genome_token = []
        ko_mask = []

        with h5py.File(self.embedding_h5_path, "r") as hf:
            for genome, compounds, medium in zip(test_data['genome'], test_data['compounds_list'], test_data['media_name']):
                try:
                    embeddings = hf[genome][:]
                except KeyError:
                    print(f"⚠️ Genome {genome} not found in HDF5, skipped")
                    continue  # skip this genome

                # Build tokens and the padding mask.
                genome_token.append(self.genome_dict[genome])
                medium_token.append(self.medium_dict[medium])
                compound_token.append(get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len))

                embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
                ko_token.append(embedding_pad)
                ko_mask.append(mask)

        if not ko_token:
            # torch.stack([]) would otherwise fail with an unhelpful message.
            raise RuntimeError("No samples built: none of the selected genomes was found in the HDF5 file")

        # Convert to tensors. Stack with NumPy first: creating a tensor from a
        # Python list of ndarrays is a slow element-wise path.
        medium_token = torch.tensor(medium_token)
        genome_token = torch.tensor(genome_token)
        ko_mask = torch.from_numpy(np.stack(ko_mask))
        ko_token = torch.from_numpy(np.stack(ko_token)).float()
        compound_token = torch.tensor(compound_token)

        return ko_token, compound_token, ko_mask, genome_token, medium_token
|
||
|
||
|
||
class MyDataset(Dataset):
    """In-memory dataset over five index-aligned collections.

    Item ``i`` is ``(src[i], trg[i], src_mask[i], genome[i], medium[i])``.
    The dataset length is the size of the first dimension of ``trg``.
    """

    def __init__(self, src, trg, src_mask, genome, medium):
        self.src, self.trg, self.src_mask = src, trg, src_mask
        self.genome, self.medium = genome, medium

    def __getitem__(self, index):
        fields = (self.src, self.trg, self.src_mask, self.genome, self.medium)
        return tuple(field[index] for field in fields)

    def __len__(self):
        length = self.trg.size(0)
        return length
|
||
|
||
|
||
class LazyDataset(Dataset):
    """Dataset that loads each sample on demand from a data processor.

    Args:
        data_processor: object exposing ``get_data_length()`` and
            ``get_sample_by_index(i)``. The latter returns a dict with keys
            ``ko_token``, ``compound_token``, ``ko_mask``, ``genome_token`` and
            ``medium_token``, or ``None`` when the row cannot be loaded.
        indices: optional list of row indices this dataset exposes; defaults
            to every row of ``data_processor``.
    """

    def __init__(self, data_processor, indices=None):
        self.data_processor = data_processor
        if indices is None:
            self.indices = list(range(data_processor.get_data_length()))
        else:
            self.indices = indices

    def __getitem__(self, index):
        """Return ``(ko_token, compound_token, ko_mask, genome_token, medium_token)``.

        If the requested row cannot be loaded, the following rows are tried in
        order, wrapping around to the start, and the first loadable one is
        returned instead.

        Raises:
            IndexError: if ``index`` is out of range for ``indices``.
            RuntimeError: if no row in ``indices`` can be loaded.
        """
        n = len(self.indices)
        # Load this index's data on demand from the processor.
        actual_index = self.indices[index]
        sample = self.data_processor.get_sample_by_index(actual_index)

        if sample is None:
            # Fall back to the next valid sample. Wrapping around means a bad
            # row near the end no longer fails while earlier rows are valid.
            for offset in range(1, n):
                actual_index = self.indices[(index + offset) % n]
                sample = self.data_processor.get_sample_by_index(actual_index)
                if sample is not None:
                    break

        if sample is None:
            raise RuntimeError(f"无法获取索引 {index} 的数据")

        return (sample['ko_token'], sample['compound_token'], sample['ko_mask'],
                sample['genome_token'], sample['medium_token'])

    def __len__(self):
        return len(self.indices)
|
||
|
||
# def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size):
|
||
# dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
|
||
# n = len(dataset)
|
||
# train_size = int(n * 10 / 10)
|
||
# # valid_size = int((n - train_size) / 2)
|
||
# # test_size = n - train_size - valid_size
|
||
|
||
# # train_ds, valid_ds, test_ds = random_split(dataset, [train_size, valid_size, test_size])
|
||
# train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
|
||
# # train_loader_no_shuffle = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
|
||
# # valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)
|
||
# # test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
|
||
|
||
# # return train_loader, train_loader_no_shuffle, valid_loader, test_loader
|
||
# return train_loader
|
||
|
||
def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size, shuffle):
    """Wrap pre-built, index-aligned tensors in a MyDataset and return a loader.

    Args:
        ko_token, compound_token, ko_mask, genome_token, medium_token: aligned
            collections passed to ``MyDataset`` in this order.
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader reshuffles each epoch.

    Returns:
        A single ``DataLoader`` over the whole dataset.
    """
    return DataLoader(
        MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token),
        batch_size=batch_size,
        shuffle=shuffle,
    )
|
||
|
||
|
||
def build_single_lazy_dataloader(data_processor, batch_size=1, shuffle=False):
    """Return a DataLoader that lazily reads every row of ``data_processor``.

    No train/valid/test split is made; the full dataset is wrapped in a
    ``LazyDataset`` whose samples are fetched on access.
    """
    lazy_dataset = LazyDataset(data_processor)
    return DataLoader(lazy_dataset, batch_size=batch_size, shuffle=shuffle)
|
||
|
||
|
||
def build_lazy_dataloaders(data_processor, batch_size=1, split_ratio=8, shuffle=True, seed=None):
    """Split rows into train/valid/test LazyDatasets and build DataLoaders.

    Kept for backward compatibility.

    Args:
        data_processor: object exposing ``get_data_length()`` and
            ``get_sample_by_index(i)``.
        batch_size: batch size used by all four loaders.
        split_ratio: training share in tenths (``8`` -> 80 %). The remaining
            rows are halved (rounded down) for validation; test gets the rest.
        shuffle: when True, row indices are shuffled before splitting and the
            training loader reshuffles every epoch.
        seed: optional seed for the pre-split shuffle, making the split
            reproducible. ``None`` (default) keeps the previous behaviour of
            using the global ``random`` module state.

    Returns:
        ``(train_loader, train_loader_no_shuffle, valid_loader, test_loader)``
    """
    import random

    total_length = data_processor.get_data_length()
    indices = list(range(total_length))

    # Compute split points.
    train_size = int(total_length * split_ratio / 10)
    valid_size = int((total_length - train_size) / 2)

    # Shuffle before splitting so the three subsets are random samples.
    if shuffle:
        if seed is None:
            random.shuffle(indices)
        else:
            # Private generator: reproducible without touching global state.
            random.Random(seed).shuffle(indices)

    train_indices = indices[:train_size]
    valid_indices = indices[train_size:train_size + valid_size]
    test_indices = indices[train_size + valid_size:]

    # Build datasets and data loaders.
    train_dataset = LazyDataset(data_processor, train_indices)
    valid_dataset = LazyDataset(data_processor, valid_indices)
    test_dataset = LazyDataset(data_processor, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    train_loader_no_shuffle = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, train_loader_no_shuffle, valid_loader, test_loader
|
||
|
||
# Dataset class specifically for model inference (using unseen data)
|
||
class InferenceDataset(Dataset):
    """
    Dataset for the inference phase on data not seen during training.

    Each item is ``(src[i], src_mask[i], genome[i])``; no targets are stored.
    """

    def __init__(self, src, src_mask, genome):
        """
        Initialize the inference dataset.

        Args:
            src: indexable input features (e.g. padded embeddings); the
                builder docs describe shape [num_samples, seq_len, d_model].
            src_mask: mask aligned with ``src``, e.g. [num_samples, seq_len].
            genome: indexable collection aligned with ``src``; the element at
                each index is returned unchanged alongside the features.
        """
        self.src = src  # Input features (e.g., ko_token embeddings)
        self.src_mask = src_mask  # Mask for input features
        self.genome = genome  # Per-sample values returned with each item

    def __getitem__(self, index):
        """
        Get a single sample by index.

        Args:
            index: position of the sample to retrieve.

        Returns:
            Tuple ``(src_feature, src_mask, genome)`` for that index.
        """
        return self.src[index], self.src_mask[index], self.genome[index]

    def __len__(self):
        """
        Return the number of samples.

        Returns:
            int: ``len(src)``; ``src_mask`` and ``genome`` are assumed to have
            the same length.
        """
        return len(self.src)
|
||
|
||
|
||
def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """
    Build a DataLoader for model inference on unseen data.

    Args:
        ko_token: input features (embeddings), shape [num_samples, seq_len, d_model].
        ko_mask: mask aligned with ``ko_token``, shape [num_samples, seq_len].
        genome: indexable collection aligned with ``ko_token``; each batch
            carries the matching entries.
        batch_size: batch size for inference (default: 1).
        shuffle: whether to shuffle (default: False, preserving input order so
            outputs line up with ``genome``).

    Returns:
        DataLoader yielding ``(ko_token, ko_mask, genome)`` batches.
    """
    # Wrap the inference inputs in a dataset.
    dataset = InferenceDataset(
        src=ko_token,
        src_mask=ko_mask,
        genome=genome
    )

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )