# labweb/models/datasets.py
import ast
import os
import pickle
import random

import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

def ko_embedding_padding(embedding, max_len):
    """Zero-pad a (n_ko, d) embedding matrix to (max_len, d) and build a
    boolean padding mask in which True marks padded rows."""
    pad_rows = max_len - embedding.shape[0]
    padded_embeddings = np.pad(
        embedding,
        pad_width=((0, pad_rows), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    mask = np.ones(max_len, dtype=bool)
    mask[:embedding.shape[0]] = False
    return padded_embeddings, mask
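
# Illustrative example (hypothetical shapes): padding 3 KO vectors of
# dimension 2 to max_len=5 yields a (5, 2) array plus the boolean mask
# [False, False, False, True, True], where True marks padded rows, the
# convention expected by PyTorch's key_padding_mask arguments.
#
#   emb = np.zeros((3, 2), dtype=np.float32)
#   padded, mask = ko_embedding_padding(emb, max_len=5)
#   assert padded.shape == (5, 2)
#   assert mask.tolist() == [False] * 3 + [True] * 2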

def get_input_token(inputs, vo_cab, max_len):
    """Map a sorted list of compound IDs to vocabulary tokens, wrap them in
    the vocabulary's '</s>' / '<s>' markers, and pad with 'blank' tokens up
    to max_len."""
    inputs = sorted(inputs)
    input_token = [vo_cab['</s>']]
    for j in inputs:
        input_token.append(vo_cab[j])
    input_token.append(vo_cab['<s>'])
    padding_len = max_len - len(input_token)
    if padding_len > 0:
        input_token += [vo_cab['blank']] * padding_len
    return input_token
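
# Illustrative example (hypothetical vocabulary): with
# vo_cab = {'</s>': 0, '<s>': 1, 'blank': 2, 'C00001': 3, 'C00002': 4},
# get_input_token(['C00002', 'C00001'], vo_cab, max_len=6) sorts the inputs
# and returns [0, 3, 4, 1, 2, 2], i.e. '</s>', the compound tokens, '<s>',
# then 'blank' padding. Note that this vocabulary uses '</s>' as the leading
# marker and '<s>' as the trailing one.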

def get_dict(data, column):
    """Return a {value: index} dict mapping the sorted unique values of
    `column` to consecutive integer IDs."""
    unique = (pd.DataFrame(list(data[column]), columns=[column])
              .drop_duplicates([column])
              .sort_values(by=column)
              .reset_index(drop=True))
    return dict(zip(unique[column], unique.index))
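
# Illustrative example (hypothetical values): for a DataFrame whose 'genome'
# column holds ['g2', 'g1', 'g1'], get_dict(df, 'genome') returns
# {'g1': 0, 'g2': 1}: sorted unique values mapped to consecutive indices.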

class DataProcessor:
    def __init__(self, ko_count_path, data_path, embedding_h5_path,
                 vo_cab_pkl_path, model_type, ko_max_len=4500,
                 compound_max_len=98):
        self.ko_count_path = ko_count_path
        self.data_path = data_path
        self.embedding_h5_path = embedding_h5_path
        self.vo_cab_pkl_path = vo_cab_pkl_path
        self.model_type = model_type
        self.ko_max_len = ko_max_len
        self.compound_max_len = compound_max_len
        self.KO_count = None
        self.data = None
        self.vo_cab = None
        self.genome_dict = None
        self.medium_dict = None

    def load_data(self, count_min=800, count_max=4500):
        self.KO_count = pd.read_csv(self.ko_count_path)
        self.KO_count = self.KO_count[(self.KO_count['count'] >= count_min)
                                      & (self.KO_count['count'] <= count_max)]
        self.data = pd.read_csv(self.data_path)
        self.data = (pd.merge(self.KO_count, self.data, how='inner', on='genome')
                     .drop_duplicates(['genome', 'compounds_list']))
        self.data['compounds_list'] = self.data['compounds_list'].apply(
            lambda x: list(ast.literal_eval(x)))
        # Load the compound vocabulary from the predefined pickle file
        if self.vo_cab_pkl_path and os.path.exists(self.vo_cab_pkl_path):
            print(f"Loading predefined compound_cab from {self.vo_cab_pkl_path}")
            with open(self.vo_cab_pkl_path, 'rb') as f:
                data_pkl = pickle.load(f)
            if isinstance(data_pkl, dict) and 'compound_cab' in data_pkl:
                self.vo_cab = data_pkl['compound_cab']
                print(f"Successfully loaded compound_cab, vocabulary size: {len(self.vo_cab)}")
        # Build the genome and medium dictionaries
        self.genome_dict = get_dict(self.data, 'genome')
        self.medium_dict = get_dict(self.data, 'media_name')
        # Persist genome_dict and medium_dict to disk
        genome_filename = f"{self.model_type}_genome_dict.pkl"
        medium_filename = f"{self.model_type}_medium_dict.pkl"
        with open(genome_filename, 'wb') as f:
            pickle.dump(self.genome_dict, f)
        with open(medium_filename, 'wb') as f:
            pickle.dump(self.medium_dict, f)

    def get_sample_by_index(self, index):
        """Fetch a single processed sample by row index, or None if the
        genome's embedding is missing from the HDF5 file."""
        if self.data is None:
            raise ValueError("Call load_data() before accessing samples")
        if index >= len(self.data):
            raise IndexError("Index out of range")
        row = self.data.iloc[index]
        genome = row['genome']
        compounds = row['compounds_list']
        medium = row['media_name']
        # Read this genome's KO embedding from the HDF5 file
        with h5py.File(self.embedding_h5_path, "r") as hf:
            try:
                embeddings = hf[genome][:]
            except KeyError:
                print(f"⚠️ Genome {genome} not found in HDF5")
                return None
        # Build the tokens and padding mask
        genome_token = self.genome_dict[genome]
        medium_token = self.medium_dict[medium]
        compound_token = get_input_token(inputs=compounds, vo_cab=self.vo_cab,
                                         max_len=self.compound_max_len)
        embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
        return {
            'ko_token': torch.from_numpy(embedding_pad).float(),
            'compound_token': torch.tensor(compound_token),
            'ko_mask': torch.tensor(mask),
            'genome_token': torch.tensor(genome_token),
            'medium_token': torch.tensor(medium_token),
        }

    def get_data_length(self):
        """Return the number of rows in the loaded dataset."""
        if self.data is None:
            raise ValueError("Call load_data() before querying the length")
        return len(self.data)

    def process_test_data(self, genome_list=None):
        """Build full-batch tensors for evaluation. If genome_list is None,
        the whole dataset is used."""
        if genome_list is None:
            test_data = self.data
        else:
            test_data = self.data[self.data['genome'].isin(genome_list)]
        ko_token = []
        compound_token = []
        medium_token = []
        genome_token = []
        ko_mask = []
        with h5py.File(self.embedding_h5_path, "r") as hf:
            for genome, compounds, medium in zip(test_data['genome'],
                                                 test_data['compounds_list'],
                                                 test_data['media_name']):
                try:
                    embeddings = hf[genome][:]
                except KeyError:
                    print(f"⚠️ Genome {genome} not found in HDF5, skipped")
                    continue  # skip this genome
                # Build the tokens and padding mask
                genome_token.append(self.genome_dict[genome])
                medium_token.append(self.medium_dict[medium])
                compound_token.append(get_input_token(
                    inputs=compounds, vo_cab=self.vo_cab,
                    max_len=self.compound_max_len))
                embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
                ko_token.append(embedding_pad)
                ko_mask.append(mask)
        # Convert to tensors
        medium_token = torch.tensor(medium_token)
        genome_token = torch.tensor(genome_token)
        ko_mask = torch.from_numpy(np.stack(ko_mask))
        ko_token = torch.stack([torch.from_numpy(x).float() for x in ko_token])
        compound_token = torch.tensor(compound_token)
        return ko_token, compound_token, ko_mask, genome_token, medium_token
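
# Minimal usage sketch for DataProcessor; all paths below are hypothetical
# placeholders, not files shipped with this module:
#
#   processor = DataProcessor(
#       ko_count_path="data/ko_count.csv",
#       data_path="data/growth_data.csv",
#       embedding_h5_path="data/ko_embeddings.h5",
#       vo_cab_pkl_path="data/compound_cab.pkl",
#       model_type="demo",
#   )
#   processor.load_data()
#   sample = processor.get_sample_by_index(0)  # dict of tensors, or None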

class MyDataset(Dataset):
    def __init__(self, src, trg, src_mask, genome, medium):
        self.src = src
        self.trg = trg
        self.src_mask = src_mask
        self.genome = genome
        self.medium = medium

    def __getitem__(self, index):
        return (self.src[index], self.trg[index], self.src_mask[index],
                self.genome[index], self.medium[index])

    def __len__(self):
        return self.trg.size(0)

class LazyDataset(Dataset):
    """Dataset that loads samples on demand from a DataProcessor."""

    def __init__(self, data_processor, indices=None):
        self.data_processor = data_processor
        if indices is None:
            self.indices = list(range(data_processor.get_data_length()))
        else:
            self.indices = indices

    def __getitem__(self, index):
        # Fetch the sample lazily from the DataProcessor
        actual_index = self.indices[index]
        sample = self.data_processor.get_sample_by_index(actual_index)
        if sample is None:
            # Fall back to the next valid sample if this one is missing
            for i in range(index + 1, len(self.indices)):
                actual_index = self.indices[i]
                sample = self.data_processor.get_sample_by_index(actual_index)
                if sample is not None:
                    break
        if sample is None:
            raise RuntimeError(f"Could not load a sample for index {index}")
        return (sample['ko_token'], sample['compound_token'], sample['ko_mask'],
                sample['genome_token'], sample['medium_token'])

    def __len__(self):
        return len(self.indices)

def build_dataloaders(ko_token, compound_token, ko_mask, genome_token,
                      medium_token, batch_size, shuffle):
    """Wrap pre-built tensors in a MyDataset and return a single DataLoader."""
    dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
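
# Usage sketch (assuming `processor` is a loaded DataProcessor, as in the
# sketch above):
#
#   tensors = processor.process_test_data()
#   loader = build_dataloaders(*tensors, batch_size=4, shuffle=False)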

def build_single_lazy_dataloader(data_processor, batch_size=1, shuffle=False):
    """Create a lazy-loading DataLoader over the full dataset."""
    dataset = LazyDataset(data_processor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def build_lazy_dataloaders(data_processor, batch_size=1, split_ratio=8, shuffle=True):
    """Build train/valid/test DataLoaders backed by LazyDataset (dynamic
    loading, backward-compatible return signature).

    split_ratio is the train share in tenths (e.g. 8 means 80% train); the
    remainder is split evenly between validation and test.
    """
    total_length = data_processor.get_data_length()
    indices = list(range(total_length))
    # Compute the split points
    train_size = int(total_length * split_ratio / 10)
    valid_size = int((total_length - train_size) / 2)
    # Shuffle before splitting so the splits are random
    if shuffle:
        random.shuffle(indices)
    train_indices = indices[:train_size]
    valid_indices = indices[train_size:train_size + valid_size]
    test_indices = indices[train_size + valid_size:]
    # Build the datasets and loaders
    train_dataset = LazyDataset(data_processor, train_indices)
    valid_dataset = LazyDataset(data_processor, valid_indices)
    test_dataset = LazyDataset(data_processor, test_indices)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    train_loader_no_shuffle = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, train_loader_no_shuffle, valid_loader, test_loader
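
# Usage sketch: an 80/10/10 split with per-sample lazy loading.
#
#   train_loader, train_eval_loader, valid_loader, test_loader = \
#       build_lazy_dataloaders(processor, batch_size=4, split_ratio=8)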

# Dataset class specifically for model inference (on unseen data)
class InferenceDataset(Dataset):
    """
    Dataset class for loading data during the inference phase.
    Designed to handle unseen data after model training.
    """

    def __init__(self, src, src_mask, genome):
        """
        Initialize the inference dataset.

        Args:
            src: Input features (e.g., embeddings) with shape [num_samples, seq_len, d_model]
            src_mask: Corresponding mask tensor with shape [num_samples, seq_len]
            genome: Genome identifiers, one per sample
        """
        self.src = src            # Input features (e.g., ko_token embeddings)
        self.src_mask = src_mask  # Mask for input features
        self.genome = genome

    def __getitem__(self, index):
        """
        Get a single sample from the dataset by index.

        Args:
            index: Index of the sample to retrieve

        Returns:
            Tuple of (src, src_mask, genome) for the specified index
        """
        return self.src[index], self.src_mask[index], self.genome[index]

    def __len__(self):
        """
        Return the total number of samples in the dataset (the length of the
        source features).
        """
        return len(self.src)

def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
    """
    Build a DataLoader for model inference on unseen data.

    Args:
        ko_token: Input features (embeddings), shape [num_samples, seq_len, d_model]
        ko_mask: Corresponding mask tensor, shape [num_samples, seq_len]
        genome: Genome identifiers, one per sample
        batch_size: Batch size for inference (default: 1)
        shuffle: Whether to shuffle data (default: False, preserving sample order)

    Returns:
        DataLoader: Ready-to-use DataLoader for inference
    """
    dataset = InferenceDataset(src=ko_token, src_mask=ko_mask, genome=genome)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
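
# Usage sketch (shapes are illustrative): ko_token [N, seq_len, d_model],
# ko_mask [N, seq_len], genome [N], e.g. the first, third, and fourth
# tensors returned by DataProcessor.process_test_data().
#
#   loader = build_inference_dataloader(ko_token, ko_mask, genome_token,
#                                       batch_size=8)
#   for src, src_mask, genome in loader:
#       pass  # run the model forward pass here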