first commit

This commit is contained in:
gzy
2025-12-16 11:39:15 +08:00
commit a3bdbee7c2
118 changed files with 34631 additions and 0 deletions

342
models/datasets.py Normal file
View File

@@ -0,0 +1,342 @@
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import torch
import h5py
import ast
import numpy as np
import pickle
import os
def ko_embedding_padding(embedding, max_len):
pad_rows = max_len - embedding.shape[0]
padded_embeddings = np.pad(
embedding,
pad_width=((0, pad_rows), (0, 0)),
mode="constant",
constant_values=0,
)
# mask = np.full(max_len, -np.inf, dtype=np.float32)
# mask[:embedding.shape[0]] = 0.0
mask = np.ones(max_len, dtype=bool)
mask[:embedding.shape[0]] = False
return padded_embeddings, mask
def get_input_token(inputs, vo_cab, max_len):
inputs = sorted(inputs)
input_token = [vo_cab['</s>']]
for j in inputs:
input_token.append(vo_cab[j])
input_token.append(vo_cab['<s>'])
padding_len = max_len - len(input_token)
if padding_len > 0:
input_token += [vo_cab['blank']] * padding_len
return input_token
def get_dict(data, colum):
all_list = list(data[colum])
all_list = (pd.DataFrame(all_list, columns=[colum])
.drop_duplicates([colum])
.sort_values(by=colum)
.reset_index(drop=True))
return dict(zip(all_list[colum], all_list.index))
class DataProcessor:
def __init__(self, ko_count_path, data_path, embedding_h5_path, vo_cab_pkl_path, model_type, ko_max_len=4500, compound_max_len=98):
self.ko_count_path = ko_count_path
self.data_path = data_path
self.embedding_h5_path = embedding_h5_path
self.vo_cab_pkl_path = vo_cab_pkl_path
self.model_type = model_type
self.ko_max_len = ko_max_len
self.compound_max_len = compound_max_len
self.KO_count = None
self.data = None
self.vo_cab = None
self.genome_dict = None
self.medium_dict = None
def load_data(self, count_min=800, count_max=4500):
self.KO_count = pd.read_csv(self.ko_count_path)
self.KO_count = self.KO_count[(self.KO_count['count'] >= count_min) & (self.KO_count['count'] <= count_max)]
self.data = pd.read_csv(self.data_path)
self.data = pd.merge(self.KO_count, self.data, how='inner', on='genome').drop_duplicates(['genome', 'compounds_list'])
self.data['compounds_list'] = self.data['compounds_list'].apply(lambda x: list(ast.literal_eval(x)))
# Load vo_cab: use pkl file
if self.vo_cab_pkl_path and os.path.exists(self.vo_cab_pkl_path):
print(f"Loading predefined compound_cab from {self.vo_cab_pkl_path}")
with open(self.vo_cab_pkl_path, 'rb') as f:
data_pkl = pickle.load(f)
if isinstance(data_pkl, dict) and 'compound_cab' in data_pkl:
self.vo_cab = data_pkl['compound_cab']
print(f"Successfully loaded compound_cab, vocabulary size: {len(self.vo_cab)}")
# genome and medium dictionaries
self.genome_dict = get_dict(self.data, 'genome')
self.medium_dict = get_dict(self.data, 'media_name')
# 保存 genome_dict 和 medium_dict 到文件
genome_filename = f"{self.model_type}_genome_dict.pkl"
medium_filename = f"{self.model_type}_medium_dict.pkl"
# 保存 genome_dict
with open(genome_filename, 'wb') as f:
pickle.dump(self.genome_dict, f)
# 保存 medium_dict
with open(medium_filename, 'wb') as f:
pickle.dump(self.medium_dict, f)
def get_sample_by_index(self, index):
"""根据索引获取单个样本的数据"""
if self.data is None:
raise ValueError("请先调用load_data()方法加载数据")
if index >= len(self.data):
raise IndexError("索引超出数据范围")
row = self.data.iloc[index]
genome = row['genome']
compounds = row['compounds_list']
medium = row['media_name']
# 从HDF5文件读取embedding
with h5py.File(self.embedding_h5_path, "r") as hf:
try:
embeddings = hf[genome][:]
except KeyError:
print(f"⚠️ Genome {genome} not found in HDF5")
return None
# 处理token和mask
genome_token = self.genome_dict[genome]
medium_token = self.medium_dict[medium]
compound_token = get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len)
embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
return {
'ko_token': torch.from_numpy(embedding_pad).float(),
'compound_token': torch.tensor(compound_token),
'ko_mask': torch.tensor(mask),
'genome_token': torch.tensor(genome_token),
'medium_token': torch.tensor(medium_token)
}
def get_data_length(self):
"""获取数据集长度"""
if self.data is None:
raise ValueError("请先调用load_data()方法加载数据")
return len(self.data)
def process_test_data(self, genome_list=None):
# 如果 genome_list 为空,就使用整个数据集
if genome_list is None:
test_data = self.data
else:
test_data = self.data[self.data['genome'].isin(genome_list)]
ko_token = []
compound_token = []
medium_token = []
genome_token = []
ko_mask = []
with h5py.File(self.embedding_h5_path, "r") as hf:
for genome, compounds, medium in zip(test_data['genome'], test_data['compounds_list'], test_data['media_name']):
try:
embeddings = hf[genome][:]
except KeyError:
print(f"⚠️ Genome {genome} not found in HDF5, skipped")
continue # 跳过这个 genome
# 处理 token 和 mask
genome_token.append(self.genome_dict[genome])
medium_token.append(self.medium_dict[medium])
compound_token.append(get_input_token(inputs=compounds, vo_cab=self.vo_cab, max_len=self.compound_max_len))
embedding_pad, mask = ko_embedding_padding(embeddings, self.ko_max_len)
ko_token.append(embedding_pad)
ko_mask.append(mask)
# 转成 tensor
medium_token = torch.tensor(medium_token)
genome_token = torch.tensor(genome_token)
ko_mask = torch.tensor(ko_mask)
ko_token = torch.stack([torch.from_numpy(x).float() for x in ko_token])
compound_token = torch.tensor(compound_token)
return ko_token, compound_token, ko_mask, genome_token, medium_token
class MyDataset(Dataset):
def __init__(self, src, trg, src_mask, genome, medium):
self.src = src
self.trg = trg
self.src_mask = src_mask
self.genome = genome
self.medium = medium
def __getitem__(self, index):
return self.src[index], self.trg[index], self.src_mask[index], self.genome[index], self.medium[index]
def __len__(self):
return self.trg.size(0)
class LazyDataset(Dataset):
"""支持动态加载数据的Dataset类"""
def __init__(self, data_processor, indices=None):
self.data_processor = data_processor
if indices is None:
self.indices = list(range(data_processor.get_data_length()))
else:
self.indices = indices
def __getitem__(self, index):
# 根据索引动态从DataProcessor获取数据
actual_index = self.indices[index]
sample = self.data_processor.get_sample_by_index(actual_index)
if sample is None:
# 如果无法获取数据,返回下一个有效样本
for i in range(index + 1, len(self.indices)):
actual_index = self.indices[i]
sample = self.data_processor.get_sample_by_index(actual_index)
if sample is not None:
break
if sample is None:
raise RuntimeError(f"无法获取索引 {index} 的数据")
return (sample['ko_token'], sample['compound_token'], sample['ko_mask'],
sample['genome_token'], sample['medium_token'])
def __len__(self):
return len(self.indices)
# def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size):
# dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
# n = len(dataset)
# train_size = int(n * 10 / 10)
# # valid_size = int((n - train_size) / 2)
# # test_size = n - train_size - valid_size
# # train_ds, valid_ds, test_ds = random_split(dataset, [train_size, valid_size, test_size])
# train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# # train_loader_no_shuffle = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
# # valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)
# # test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
# # return train_loader, train_loader_no_shuffle, valid_loader, test_loader
# return train_loader
def build_dataloaders(ko_token, compound_token, ko_mask, genome_token, medium_token, batch_size, shuffle):
dataset = MyDataset(ko_token, compound_token, ko_mask, genome_token, medium_token)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
return train_loader # 返回单个 DataLoader
def build_single_lazy_dataloader(data_processor, batch_size=1, shuffle=False):
"""为单个数据集创建懒加载数据加载器"""
dataset = LazyDataset(data_processor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
return dataloader
def build_lazy_dataloaders(data_processor, batch_size=1, split_ratio=8, shuffle=True):
"""使用LazyDataset构建数据加载器支持动态加载保持向后兼容"""
total_length = data_processor.get_data_length()
indices = list(range(total_length))
# 计算分割点
train_size = int(total_length * split_ratio / 10)
valid_size = int((total_length - train_size) / 2)
# 分割索引
if shuffle:
import random
random.shuffle(indices)
train_indices = indices[:train_size]
valid_indices = indices[train_size:train_size + valid_size]
test_indices = indices[train_size + valid_size:]
# 创建数据集和数据加载器
train_dataset = LazyDataset(data_processor, train_indices)
valid_dataset = LazyDataset(data_processor, valid_indices)
test_dataset = LazyDataset(data_processor, test_indices)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
train_loader_no_shuffle = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, train_loader_no_shuffle, valid_loader, test_loader
# Dataset class specifically for model inference (using unseen data)
class InferenceDataset(Dataset):
"""
Dataset class for loading data during inference phase.
Designed to handle unseen data after model training.
"""
def __init__(self, src, src_mask, genome):
"""
Initialize the inference dataset.
Args:
src: Input features (e.g., embeddings) with shape [num_samples, seq_len, d_model]
src_mask: Corresponding mask tensor with shape [num_samples, seq_len]
"""
self.src = src # Input features (e.g., ko_token embeddings)
self.src_mask = src_mask # Mask for input features
self.genome = genome
def __getitem__(self, index):
"""
Get a single sample from the dataset by index.
Args:
index: Index of the sample to retrieve
Returns:
Tuple containing (src_feature, src_mask) for the specified index
"""
return self.src[index], self.src_mask[index], self.genome[index]
def __len__(self):
"""
Return the total number of samples in the dataset.
Returns:
int: Number of samples (matches length of source features)
"""
return len(self.src)
def build_inference_dataloader(ko_token, ko_mask, genome, batch_size=1, shuffle=False):
"""
Build a DataLoader for model inference on unseen data.
Args:
ko_token: Input features (embeddings) for inference, shape [num_samples, seq_len, d_model]
ko_mask: Corresponding mask tensor, shape [num_samples, seq_len]
batch_size: Batch size for inference (default: 1)
shuffle: Whether to shuffle data (default: False for inference to preserve order)
Returns:
DataLoader: Ready-to-use DataLoader for inference
"""
# Initialize dataset with inference data
dataset = InferenceDataset(
src=ko_token,
src_mask=ko_mask,
genome=genome
)
# Create DataLoader with inference-optimized settings
return DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle
)