first commit
This commit is contained in:
342
models/datasets.py
Normal file
342
models/datasets.py
Normal file
@@ -0,0 +1,342 @@
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
import pandas as pd
|
||||
import torch
|
||||
import h5py
|
||||
import ast
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def ko_embedding_padding(embedding, max_len):
    """Pad (or truncate) a 2-D KO embedding matrix to exactly `max_len` rows.

    Args:
        embedding: np.ndarray of shape (n_rows, dim).
        max_len: target number of rows after padding/truncation.

    Returns:
        (padded, mask): `padded` has shape (max_len, dim) with zero rows
        appended after the data; `mask` is a boolean array of length
        max_len that is True at padded positions and False at real rows.
    """
    n_rows = embedding.shape[0]
    if n_rows > max_len:
        # Original code passed a negative pad width to np.pad and crashed
        # for over-long embeddings; truncate instead.
        embedding = embedding[:max_len]
        n_rows = max_len
    padded_embeddings = np.pad(
        embedding,
        pad_width=((0, max_len - n_rows), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    # True marks padding, False marks valid data rows.
    mask = np.ones(max_len, dtype=bool)
    mask[:n_rows] = False

    return padded_embeddings, mask
|
||||
|
||||
def get_input_token(inputs, vo_cab, max_len):
    """Map a list of compound names to a fixed-length token-id sequence.

    The sequence is vo_cab['</s>'] + ids of sorted(inputs) + vo_cab['<s>'],
    right-padded with vo_cab['blank'] up to max_len (longer sequences are
    returned as-is).

    NOTE(review): the leading '</s>' / trailing '<s>' looks inverted versus
    the usual BOS/EOS convention — confirm this matches the model vocabulary.
    """
    tokens = [vo_cab['</s>']]
    tokens.extend(vo_cab[item] for item in sorted(inputs))
    tokens.append(vo_cab['<s>'])
    pad_count = max_len - len(tokens)
    if pad_count > 0:
        tokens.extend([vo_cab['blank']] * pad_count)
    return tokens
|
||||
|
||||
def get_dict(data, colum):
    """Build a value -> dense-index mapping for one column of `data`.

    Duplicate values are dropped and the remainder sorted ascending, so
    indices run 0..n-1 in value order.
    """
    frame = pd.DataFrame(list(data[colum]), columns=[colum])
    frame = frame.drop_duplicates([colum]).sort_values(by=colum)
    frame = frame.reset_index(drop=True)
    return dict(zip(frame[colum], frame.index))
|
||||
|
||||
class DataProcessor:
    """End-to-end loader for genome/compound training data.

    Reads a KO-count table and a genome-to-compounds table, loads the
    compound vocabulary from a pickle, builds genome/medium index
    dictionaries, and serves per-row samples whose KO embeddings are read
    from an HDF5 file keyed by genome name.
    """

    def __init__(self, ko_count_path, data_path, embedding_h5_path, vo_cab_pkl_path, model_type, ko_max_len=4500, compound_max_len=98):
        # Input file locations.
        self.ko_count_path = ko_count_path
        self.data_path = data_path
        self.embedding_h5_path = embedding_h5_path
        self.vo_cab_pkl_path = vo_cab_pkl_path
        # Prefix used when persisting the genome/medium dictionaries.
        self.model_type = model_type
        # Padded KO embedding rows / compound token sequence length.
        self.ko_max_len = ko_max_len
        self.compound_max_len = compound_max_len

        # All of these are populated by load_data().
        self.KO_count = None
        self.data = None
        self.vo_cab = None
        self.genome_dict = None
        self.medium_dict = None

    def load_data(self, count_min=800, count_max=4500):
        """Load/filter the tables, load the vocabulary, and build (and
        persist) the genome/medium index dictionaries.

        Args:
            count_min: inclusive lower bound on the per-genome KO count.
            count_max: inclusive upper bound on the per-genome KO count.
        """
        self.KO_count = pd.read_csv(self.ko_count_path)
        self.KO_count = self.KO_count[(self.KO_count['count'] >= count_min) & (self.KO_count['count'] <= count_max)]

        self.data = pd.read_csv(self.data_path)
        # Keep only genomes present in both tables; drop duplicate rows.
        self.data = pd.merge(self.KO_count, self.data, how='inner', on='genome').drop_duplicates(['genome', 'compounds_list'])
        # 'compounds_list' is stored as a stringified Python literal.
        self.data['compounds_list'] = self.data['compounds_list'].apply(lambda x: list(ast.literal_eval(x)))

        # Load vo_cab: use pkl file
        # NOTE(review): if the pickle is missing or lacks 'compound_cab',
        # self.vo_cab stays None and get_sample_by_index will fail later
        # in get_input_token — confirm the file is always supplied.
        if self.vo_cab_pkl_path and os.path.exists(self.vo_cab_pkl_path):
            print(f"Loading predefined compound_cab from {self.vo_cab_pkl_path}")
            with open(self.vo_cab_pkl_path, 'rb') as f:
                data_pkl = pickle.load(f)
            if isinstance(data_pkl, dict) and 'compound_cab' in data_pkl:
                self.vo_cab = data_pkl['compound_cab']
                print(f"Successfully loaded compound_cab, vocabulary size: {len(self.vo_cab)}")

        # genome and medium dictionaries
        self.genome_dict = get_dict(self.data, 'genome')
        self.medium_dict = get_dict(self.data, 'media_name')

        # Persist genome_dict and medium_dict to files.
        genome_filename = f"{self.model_type}_genome_dict.pkl"
        medium_filename = f"{self.model_type}_medium_dict.pkl"
        # Save genome_dict.
        with open(genome_filename, 'wb') as f:
            pickle.dump(self.genome_dict, f)
        # Save medium_dict.
        with open(medium_filename, 'wb') as f:
            pickle.dump(self.medium_dict, f)

    def get_sample_by_index(self, index):
        """Fetch a single sample by row index.

        Returns:
            dict with keys 'ko_token', 'compound_token', 'ko_mask',
            'genome_token', 'medium_token' (all torch tensors), or None
            when the genome has no embedding in the HDF5 file.

        Raises:
            ValueError: load_data() has not been called yet.
            IndexError: index is beyond the loaded data.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")

        if index >= len(self.data):
            raise IndexError("索引超出数据范围")

        row = self.data.iloc[index]
        genome = row['genome']
        compounds = row['compounds_list']
        medium = row['media_name']

        # Read this genome's embedding matrix from the HDF5 file.
        with h5py.File(self.embedding_h5_path, "r") as hf:
            try:
                embeddings = hf[genome][:]
            except KeyError:
                print(f"⚠️ Genome {genome} not found in HDF5")
                return None

        # Build tokens and the padding mask.
        genome_token = self.genome_dict[genome]
        medium_token = self.medium_dict[medium]
        compound_token = get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len)

        embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)

        return {
            'ko_token': torch.from_numpy(embedding_pad).float(),
            'compound_token': torch.tensor(compound_token),
            'ko_mask': torch.tensor(mask),
            'genome_token': torch.tensor(genome_token),
            'medium_token': torch.tensor(medium_token)
        }

    def get_data_length(self):
        """Return the number of rows in the loaded dataset.

        Raises:
            ValueError: load_data() has not been called yet.
        """
        if self.data is None:
            raise ValueError("请先调用load_data()方法加载数据")
        return len(self.data)

    def process_test_data(self, genome_list=None):
        """Materialize stacked tensors for a set of genomes.

        Genomes missing from the HDF5 file are skipped with a warning.

        Args:
            genome_list: genomes to include; None means the whole dataset.

        Returns:
            (ko_token, compound_token, ko_mask, genome_token, medium_token)
            tensors stacked over all kept rows.
        """
        # When genome_list is None, use the entire dataset.
        if genome_list is None:
            test_data = self.data
        else:
            test_data = self.data[self.data['genome'].isin(genome_list)]

        ko_token = []
        compound_token = []
        medium_token = []
        genome_token = []
        ko_mask = []

        with h5py.File(self.embedding_h5_path, "r") as hf:
            for genome, compounds, medium in zip(test_data['genome'], test_data['compounds_list'], test_data['media_name']):
                try:
                    embeddings = hf[genome][:]
                except KeyError:
                    print(f"⚠️ Genome {genome} not found in HDF5, skipped")
                    continue  # Skip this genome.

                # Build tokens and the padding mask.
                genome_token.append(self.genome_dict[genome])
                medium_token.append(self.medium_dict[medium])
                compound_token.append(get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len))

                embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
                ko_token.append(embedding_pad)
                ko_mask.append(mask)

        # Convert the accumulated lists to tensors.
        medium_token = torch.tensor(medium_token)
        genome_token = torch.tensor(genome_token)
        ko_mask = torch.tensor(ko_mask)
        ko_token = torch.stack([torch.from_numpy(x).float() for x in ko_token])
        compound_token = torch.tensor(compound_token)

        return ko_token, compound_token, ko_mask, genome_token, medium_token
|
||||
|
||||
|
||||
class MyDataset(Dataset):
    """In-memory dataset over pre-built tensors.

    Each item is the tuple (src, trg, src_mask, genome, medium) sliced at
    the same index; length is determined by the target tensor.
    """

    def __init__(self, src, trg, src_mask, genome, medium):
        self.src = src
        self.trg = trg
        self.src_mask = src_mask
        self.genome = genome
        self.medium = medium

    def __getitem__(self, index):
        fields = (self.src, self.trg, self.src_mask, self.genome, self.medium)
        return tuple(field[index] for field in fields)

    def __len__(self):
        return self.trg.size(0)
|
||||
|
||||
|
||||
class LazyDataset(Dataset):
    """Dataset that fetches samples lazily from a DataProcessor.

    When the processor returns None for an index (e.g. a genome missing
    from the HDF5 file), the dataset falls forward to the next index that
    yields a valid sample; if none exists, it raises RuntimeError.
    """

    _FIELDS = ('ko_token', 'compound_token', 'ko_mask', 'genome_token', 'medium_token')

    def __init__(self, data_processor, indices=None):
        self.data_processor = data_processor
        # Default to the full index range of the processor's data.
        self.indices = list(range(data_processor.get_data_length())) if indices is None else indices

    def __getitem__(self, index):
        sample = self.data_processor.get_sample_by_index(self.indices[index])

        if sample is None:
            # Scan forward for the next index with a valid sample.
            for pos in range(index + 1, len(self.indices)):
                sample = self.data_processor.get_sample_by_index(self.indices[pos])
                if sample is not None:
                    break

        if sample is None:
            raise RuntimeError(f"无法获取索引 {index} 的数据")

        return tuple(sample[key] for key in self._FIELDS)

    def __len__(self):
        return len(self.indices)
|
||||
|
||||
# def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size):
|
||||
# dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
|
||||
# n = len(dataset)
|
||||
# train_size = int(n * 10 / 10)
|
||||
# # valid_size = int((n - train_size) / 2)
|
||||
# # test_size = n - train_size - valid_size
|
||||
|
||||
# # train_ds, valid_ds, test_ds = random_split(dataset, [train_size, valid_size, test_size])
|
||||
# train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
|
||||
# # train_loader_no_shuffle = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
|
||||
# # valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)
|
||||
# # test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
|
||||
|
||||
# # return train_loader, train_loader_no_shuffle, valid_loader, test_loader
|
||||
# return train_loader
|
||||
|
||||
def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size, shuffle):
    """Wrap the pre-built tensors in a MyDataset and return one DataLoader."""
    ds = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
|
||||
|
||||
|
||||
def build_single_lazy_dataloader(data_processor, batch_size=1, shuffle=False):
    """Create a lazy-loading DataLoader over the processor's full dataset."""
    return DataLoader(LazyDataset(data_processor), batch_size=batch_size, shuffle=shuffle)
|
||||
|
||||
|
||||
def build_lazy_dataloaders(data_processor, batch_size=1, split_ratio=8, shuffle=True):
    """Build train/valid/test lazy DataLoaders (kept for backward compatibility).

    `split_ratio` is the train share out of 10; the remaining indices are
    split evenly between validation and test.  When `shuffle` is True the
    index list is shuffled before splitting and the train loader shuffles
    each epoch.

    Returns:
        (train_loader, train_loader_no_shuffle, valid_loader, test_loader)
    """
    total = data_processor.get_data_length()
    indices = list(range(total))

    # Split sizes.
    n_train = int(total * split_ratio / 10)
    n_valid = int((total - n_train) / 2)

    if shuffle:
        import random
        random.shuffle(indices)

    # Partition the (possibly shuffled) indices.
    splits = (
        indices[:n_train],
        indices[n_train:n_train + n_valid],
        indices[n_train + n_valid:],
    )
    train_ds, valid_ds, test_ds = (LazyDataset(data_processor, idx) for idx in splits)

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle),
        DataLoader(train_ds, batch_size=batch_size, shuffle=False),
        DataLoader(valid_ds, batch_size=batch_size, shuffle=False),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )
|
||||
|
||||
# Dataset class specifically for model inference (using unseen data)
|
||||
class InferenceDataset(Dataset):
    """Dataset for running inference on unseen samples.

    Holds input features, their masks and genome identifiers, yielding one
    (src, src_mask, genome) triple per index.
    """

    def __init__(self, src, src_mask, genome):
        """Store the inference inputs.

        Args:
            src: input features, e.g. shape [num_samples, seq_len, d_model].
            src_mask: matching masks, shape [num_samples, seq_len].
            genome: per-sample genome identifier.
        """
        self.src = src          # Input features (e.g. ko_token embeddings)
        self.src_mask = src_mask  # Mask for input features
        self.genome = genome

    def __getitem__(self, index):
        """Return the (features, mask, genome) triple at `index`."""
        return self.src[index], self.src_mask[index], self.genome[index]

    def __len__(self):
        """Number of samples (length of the feature container)."""
        return len(self.src)
|
||||
|
||||
|
||||
def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """Build a DataLoader over unseen data for model inference.

    Args:
        ko_token: input features, shape [num_samples, seq_len, d_model].
        ko_mask: matching mask tensor, shape [num_samples, seq_len].
        genome: per-sample genome identifier.
        batch_size: inference batch size (default 1).
        shuffle: defaults to False so inference preserves input order.

    Returns:
        DataLoader over an InferenceDataset of the given inputs.
    """
    ds = InferenceDataset(src=ko_token, src_mask=ko_mask, genome=genome)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
|
||||
Reference in New Issue
Block a user