add mole predict module

This commit is contained in:
2025-10-17 15:54:00 +08:00
parent ea218a3a39
commit 336bfe4b65
14 changed files with 1559 additions and 0 deletions

26
models/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
"""
SIME Models Package
This package contains models for antimicrobial activity prediction.
"""
from .broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult,
create_predictor,
predict_smiles,
predict_file
)
__all__ = [
'ParallelBroadSpectrumPredictor',
'PredictionConfig',
'MoleculeInput',
'BroadSpectrumResult',
'create_predictor',
'predict_smiles',
'predict_file'
]

View File

@@ -0,0 +1,567 @@
"""
并行广谱抗菌预测器模块
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
基于MolE分子表示和XGBoost模型进行预测。
"""
import os
import re
import pickle
import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder
from .mole_representation import process_representation
@dataclass
class PredictionConfig:
    """Configuration for broad-spectrum antimicrobial prediction.

    Every path left as ``None`` is filled in by ``__post_init__`` with a
    location under ``<project root>/Data/mole/pretrained_model/...``, where
    the project root is the parent of the directory containing this file.
    """

    # Pickled XGBoost classifier.
    xgboost_model_path: Optional[str] = None
    # Directory holding the pre-trained MolE model.
    mole_model_path: Optional[str] = None
    # Tab-separated screening table; its columns are used as strain names.
    strain_categories_path: Optional[str] = None
    # Excel sheet providing the "Gram stain" column per strain.
    gram_info_path: Optional[str] = None
    # Probability cut-off at or above which a strain counts as inhibited.
    app_threshold: float = 0.04374140128493309
    # Minimum number of inhibited strains for a "broad spectrum" call.
    min_nkill: int = 10
    # Number of (molecule, strain) rows handed to a worker at once.
    batch_size: int = 100
    # Worker process count; None means "use every CPU".
    n_workers: Optional[int] = None
    # "auto" picks "cuda:0" when available, otherwise "cpu".
    device: str = "auto"

    def __post_init__(self) -> None:
        """Fill every unset path with its default location."""
        # models/ -> project root
        project_root = Path(__file__).resolve().parent.parent
        data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"

        if self.mole_model_path is None:
            self.mole_model_path = str(data_dir)
        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")
        if self.strain_categories_path is None:
            self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")
        if self.gram_info_path is None:
            self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")
@dataclass
class MoleculeInput:
    """One molecule to score: a SMILES string plus an optional identifier."""

    # SMILES string describing the molecule.
    smiles: str
    # Caller-supplied identifier; may be left unset.
    chem_id: Optional[str] = None
@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum prediction for one compound."""

    chem_id: str
    apscore_total: float
    apscore_gnegative: float
    apscore_gpositive: float
    ginhib_total: int
    ginhib_gnegative: int
    ginhib_gpositive: int
    broad_spectrum: int

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the result as a plain dictionary keyed by field name."""
        ordered_keys = (
            'chem_id',
            'apscore_total',
            'apscore_gnegative',
            'apscore_gpositive',
            'ginhib_total',
            'ginhib_gnegative',
            'ginhib_gpositive',
            'broad_spectrum',
        )
        return {key: getattr(self, key) for key in ordered_keys}
class BroadSpectrumPredictor:
    """
    Base class for broad-spectrum antimicrobial prediction.

    Holds the configuration, the shared strain tables, and the helpers that
    build per-strain features from MolE representations and aggregate
    per-strain predictions into per-compound scores.
    """

    # Captures the "NT<digits>" identifier written in parentheses in a strain name.
    _NT_PATTERN = re.compile(r".*?\((NT\d+)\)")

    def __init__(self, config: Optional["PredictionConfig"] = None) -> None:
        """
        Initialise the predictor.

        Args:
            config: Prediction settings; the default configuration is used
                when this is None.

        Raises:
            ValueError: A required path is None.
            FileNotFoundError: A required path does not exist.
            RuntimeError: The shared strain data could not be loaded.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()
        # Fail early on missing files, then pre-load the shared tables.
        self._validate_paths()
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Check that every required file or directory path is set and exists."""
        required_files = {
            "mole_model": self.config.mole_model_path,
            "xgboost_model": self.config.xgboost_model_path,
            "strain_categories": self.config.strain_categories_path,
            "gram_info": self.config.gram_info_path,
        }
        for name, file_path in required_files.items():
            if file_path is None:
                raise ValueError(f"{name} is None! Check __post_init__ configuration")
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required {name} not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load the strain screening table, its one-hot encoding and the Gram-stain sheet."""
        try:
            # Strain screening data; its columns are the strain names.
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )
            # One-hot encoding of the strains.
            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
            # Gram-stain information, indexed by the "NT data base" column.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )
        except Exception as e:
            # Chain explicitly so the root cause stays attached to the traceback.
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """
        Build the one-hot encoding of the strains.

        Args:
            categories: Strain names.

        Returns:
            Square DataFrame indexed and labelled by ``categories``.
        """
        try:
            # Newer scikit-learn releases use ``sparse_output``.
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn releases use ``sparse``.
            ohe = OneHotEncoder(sparse=False)
        ohe.fit(pd.DataFrame(categories))
        # NOTE(review): the encoder orders its output columns by its own
        # category order, while the frame is labelled with ``categories`` as
        # given; presumably this matches how the model was trained - confirm
        # before changing.
        cat_ohe = pd.DataFrame(
            ohe.transform(pd.DataFrame(categories)),
            columns=categories,
            index=categories
        )
        return cat_ohe

    def _get_mole_representation(self, molecules: List["MoleculeInput"]) -> pd.DataFrame:
        """
        Compute the MolE representation of the given molecules.

        Args:
            molecules: Molecules to encode; those without a ``chem_id`` get
                ``mol<position>`` (1-based).

        Returns:
            DataFrame returned by ``process_representation``.
        """
        df = pd.DataFrame([
            {"smiles": mol.smiles, "chem_id": mol.chem_id or f"mol{i+1}"}
            for i, mol in enumerate(molecules)
        ])
        # Resolve the device.
        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """
        Cross every compound with every strain (Cartesian product).

        Args:
            chemfeats_df: Chemical features indexed by compound ID.

        Returns:
            Feature DataFrame indexed by ``"<chem_id>:<strain_name>"``.
        """
        chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)
        sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})
        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
        xpred = xpred.set_index("pred_id")
        xpred = xpred.drop(columns=["chem_id", "strain_name"])
        return xpred

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """
        Annotate each row with the Gram stain of its strain.

        Args:
            label_df: DataFrame with a ``strain_name`` column.

        Returns:
            Copy of ``label_df`` with ``nt_number`` and ``gram_stain`` columns
            added (None where no NT number or no stain entry is found).
        """
        df_label = label_df.copy()
        pattern = self._NT_PATTERN

        def _nt_number(strain_name):
            # Search once per row instead of twice.
            match = pattern.search(strain_name)
            return match.group(1) if match else None

        df_label["nt_number"] = df_label["strain_name"].apply(_nt_number)
        gram_dict = self.maier_strains["Gram stain"].to_dict()
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)
        return df_label

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregate per-strain predictions into per-compound scores.

        Args:
            score_df: DataFrame with columns ``pred_id``
                (``"<chem_id>:<strain_name>"``), ``"1"`` (predicted
                probability) and ``growth_inhibition``. It is not modified.

        Returns:
            DataFrame indexed by ``chem_id`` with the log geometric-mean
            scores (``apscore_*``) and inhibited-strain counts (``ginhib_*``),
            overall and per Gram stain; NaN values are replaced by 0.
        """
        # Work on a copy so the caller's frame does not gain extra columns.
        pred_df = score_df.copy()
        # Split once, from the right, so a ":" inside the compound ID is kept
        # (assumes strain names themselves contain no ":").
        id_parts = pred_df["pred_id"].str.rsplit(":", n=1, expand=True)
        pred_df["chem_id"] = id_parts[0]
        pred_df["strain_name"] = id_parts[1]
        pred_df = self._gram_stain(pred_df)

        # Antimicrobial potential: log of the geometric mean probability.
        apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
            columns={"1": "apscore_total"}
        )
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])

        # Same score, split by Gram stain.
        apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])

        # Number of inhibited strains, overall and by Gram stain.
        inhibited_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
            columns={"growth_inhibition": "ginhib_total"}
        )
        inhibited_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )

        agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)
        return agg_pred.fillna(0)
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
model_path: str,
app_threshold: float) -> Tuple[int, pd.DataFrame]:
"""
批次预测工作函数(用于多进程)
Args:
batch_data: (特征数据, 批次ID)
model_path: XGBoost模型路径
app_threshold: 抑制阈值
Returns:
(批次ID, 预测结果DataFrame)
"""
import warnings
# 忽略所有XGBoost版本相关的警告
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
X_input, batch_id = batch_data
# 加载模型
with open(model_path, "rb") as file:
model = pickle.load(file)
# 修复特征名称兼容性问题
# 原因:模型使用旧版 XGBoost 保存时,特征列为元组格式(如 "('bacteria_name',)"
# 新版 XGBoost 严格检查特征名称匹配,导致预测失败。
# 解决:清除 XGBoost 内部的特征名称验证,直接使用输入特征进行预测
# 注意:此操作不改变模型权重和预测逻辑,只禁用格式检查,预测结果保持一致
if hasattr(model, 'get_booster'):
model.get_booster().feature_names = None
# 进行预测
y_pred = model.predict_proba(X_input)
pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
# 二值化预测结果
pred_df["growth_inhibition"] = pred_df["1"].apply(
lambda x: 1 if x >= app_threshold else 0
)
return batch_id, pred_df
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """
    Parallel broad-spectrum antimicrobial predictor.

    Extends ``BroadSpectrumPredictor`` with multi-process batch prediction
    for scoring many molecules at once.
    """

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """
        Predict the broad-spectrum activity of one molecule.

        Args:
            molecule: Molecule to score.

        Returns:
            The prediction for that molecule.

        Raises:
            ValueError: No prediction could be produced for the molecule.
        """
        results = self.predict_batch([molecule])
        if not results:
            raise ValueError(f"No prediction produced for SMILES: {molecule.smiles!r}")
        return results[0]

    def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
        """
        Predict the broad-spectrum activity of many molecules.

        Args:
            molecules: Molecules to score.

        Returns:
            One result per distinct compound ID that received a
            representation, in order of first appearance in the input.
            Empty when ``molecules`` is empty or nothing could be encoded.
        """
        if not molecules:
            return []

        print(f"Processing {len(molecules)} molecules...")
        mole_representation = self._get_mole_representation(molecules)
        if mole_representation.empty:
            # Nothing could be encoded, so there is nothing to predict.
            return []

        print("Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)

        print(f"Starting parallel prediction with {self.n_workers} workers...")
        step = self.config.batch_size
        batches = [
            (X_input.iloc[start:start + step], start // step)
            for start in range(0, len(X_input), step)
        ]

        # Never start more workers than there are batches to process.
        pool_size = max(1, min(self.n_workers, len(batches)))
        results = {}
        with ProcessPoolExecutor(max_workers=pool_size) as executor:
            futures = {
                executor.submit(_predict_batch_worker, (batch_data, batch_id),
                                self.config.xgboost_model_path,
                                self.config.app_threshold): batch_id
                for batch_data, batch_id in batches
            }
            for future in as_completed(futures):
                batch_id, pred_df = future.result()
                results[batch_id] = pred_df
                print(f"Batch {batch_id} completed")

        print("Merging prediction results...")
        all_pred_df = pd.concat([results[i] for i in sorted(results)])

        print("Calculating antimicrobial potential scores...")
        all_pred_df = all_pred_df.reset_index()
        agg_df = self._antimicrobial_potential(all_pred_df)

        # Broad spectrum: at least ``min_nkill`` strains inhibited.
        agg_df["broad_spectrum"] = (agg_df["ginhib_total"] >= self.config.min_nkill).astype(int)

        # The aggregation sorts by compound ID; restore input order. IDs that
        # cannot be matched keep their relative order at the end.
        input_rank: Dict[str, int] = {}
        for cid in mole_representation.index.astype(str):
            input_rank.setdefault(cid, len(input_rank))
        ordered_ids = sorted(agg_df.index, key=lambda cid: input_rank.get(cid, len(input_rank)))
        agg_df = agg_df.loc[ordered_ids]

        return [
            BroadSpectrumResult(
                chem_id=chem_id,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"])
            )
            for chem_id, row in agg_df.iterrows()
        ]

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """
        Predict from a list of SMILES strings.

        Args:
            smiles_list: SMILES strings.
            chem_ids: Compound IDs; generated as ``mol1``, ``mol2``, ... when None.

        Returns:
            List of prediction results.

        Raises:
            ValueError: ``smiles_list`` and ``chem_ids`` differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")
        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
        return self.predict_batch(molecules)

    def predict_from_file(self,
                          file_path: Union[str, Path],
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """
        Predict for every molecule listed in a CSV/TSV file.

        Args:
            file_path: Input file (string or ``Path``). Files whose name ends
                in ``.tsv``, optionally followed by a compression suffix, are
                read as tab-separated; anything else as comma-separated.
            smiles_column: Name of the SMILES column (case-insensitive).
            id_column: Name of the ID column (case-insensitive); IDs
                ``mol1``, ``mol2``, ... are generated when it is absent.

        Returns:
            List of prediction results.

        Raises:
            ValueError: The SMILES column is not present in the file.
        """
        # Look past a compression suffix so "data.tsv.gz" is still a TSV.
        name = str(file_path).lower()
        for ext in (".gz", ".bz2", ".zip", ".xz", ".zst"):
            if name.endswith(ext):
                name = name[:-len(ext)]
                break
        separator = '\t' if name.endswith('.tsv') else ','
        df = pd.read_csv(file_path, sep=separator)

        # Case-insensitive column lookup.
        columns_lower = {col.lower(): col for col in df.columns}
        smiles_col_actual = columns_lower.get(smiles_column.lower())
        if smiles_col_actual is None:
            raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")

        id_col_actual = columns_lower.get(id_column.lower())
        if id_col_actual is None:
            df[id_column] = [f"mol{i+1}" for i in range(len(df))]
            id_col_actual = id_column

        molecules = [
            MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
            for _, row in df.iterrows()
        ]
        return self.predict_batch(molecules)
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """
    Build a parallel broad-spectrum predictor.

    Args:
        config: Prediction settings, passed straight to the predictor.

    Returns:
        A new ``ParallelBroadSpectrumPredictor`` instance.
    """
    predictor = ParallelBroadSpectrumPredictor(config)
    return predictor
# Convenience functions
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    One-call helper: predict broad-spectrum activity from SMILES strings.

    Args:
        smiles_list: SMILES strings.
        chem_ids: Compound IDs matching ``smiles_list``.
        config: Prediction settings.

    Returns:
        List of prediction results.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)
def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    One-call helper: predict broad-spectrum activity for a file of molecules.

    Args:
        file_path: Input file path.
        smiles_column: Name of the SMILES column.
        id_column: Name of the ID column.
        config: Prediction settings.

    Returns:
        List of prediction results.
    """
    return create_predictor(config).predict_from_file(file_path, smiles_column, id_column)

View File

@@ -0,0 +1,179 @@
import os
import yaml
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset, Batch
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
BT.SINGLE,
BT.DOUBLE,
BT.TRIPLE,
BT.AROMATIC
]
BONDDIR_LIST = [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT
]
class MoleculeDataset(Dataset):
    """
    Dataset that turns SMILES strings into molecular graphs.

    Attributes:
    - smiles_data (list): SMILES strings taken from the input DataFrame.
    - id_data (list): Molecule IDs, aligned with ``smiles_data``.
    """

    def __init__(self, smile_df, smile_column, id_column):
        """
        Parameters:
        - smile_df (pandas.DataFrame): DataFrame containing SMILES data.
        - smile_column (str): Name of the column containing SMILES strings.
        - id_column (str): Name of the column containing molecule IDs.
        """
        # NOTE(review): super(Dataset, self) starts the lookup *after*
        # Dataset, so Dataset.__init__ itself is not run. Kept as in the
        # original - confirm this is intentional before changing it.
        super(Dataset, self).__init__()
        # Gather the SMILES and the corresponding IDs
        self.smiles_data = smile_df[smile_column].tolist()
        self.id_data = smile_df[id_column].tolist()

    def __getitem__(self, index):
        """
        Build the graph for molecule ``index``.

        Returns:
        - Data with ``x`` (atom type and chirality indices per atom),
          ``edge_index``, ``edge_attr`` (bond type and bond direction indices,
          one row per directed edge) and ``chem_id``.

        Raises:
        - ValueError: the SMILES cannot be parsed, or an atom, chirality tag,
          bond type or bond direction is not in the module's lookup lists.
        """
        smiles = self.smiles_data[index]
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Fail with a readable message instead of an RDKit error on None.
            raise ValueError(
                f"Invalid SMILES for molecule {self.id_data[index]!r}: {smiles!r}"
            )
        mol = Chem.AddHs(mol)

        # Atom features
        type_idx = []
        chirality_idx = []
        for atom in mol.GetAtoms():
            atomic_num = atom.GetAtomicNum()
            if atomic_num == 0:
                # Report the offending molecule; the lookup below then raises.
                print(self.id_data[index])
            type_idx.append(ATOM_LIST.index(atomic_num))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        # Bond features: every bond yields two directed edges sharing features.
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            row += [start, end]
            col += [end, start]
            edge_feat += [feat, feat]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # Explicit shape keeps edge_attr two-dimensional (0 x 2) for a
        # molecule without bonds.
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long).reshape(len(edge_feat), 2)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    chem_id=self.id_data[index])
        return data

    def __len__(self):
        """Return the number of molecules."""
        return len(self.smiles_data)

    def get(self, index):
        """Alias of ``__getitem__``."""
        return self.__getitem__(index)

    def len(self):
        """Alias of ``__len__``."""
        return self.__len__()
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
    """
    Generate molecular representations using a Deep Learning model.

    Parameters:
    - smile_df (pandas.DataFrame): DataFrame containing SMILES data.
    - dl_model: Deep Learning model; called on a graph batch and expected to
      return a pair whose first element is the representation.
    - column_str (str): Name of the column containing SMILES strings.
    - id_str (str): Name of the column containing molecule IDs.
    - batch_size (int, optional): Molecules per forward pass (default is 10,000).
    - id_is_str (bool, optional): Accepted for compatibility; not used here.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - chem_representation (pandas.DataFrame): One row per molecule, indexed by
      molecule ID. Empty when ``smile_df`` has no rows.

    Raises:
    - ValueError: ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # First we create a list of graphs
    molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
    graph_list = [g for g in molecular_graph_dataset]

    # Nothing to encode: return an empty frame instead of failing in pd.concat.
    if not graph_list:
        return pd.DataFrame()

    # Inference mode only needs to be set once.
    dl_model.eval()

    batch_dataframes = []
    with torch.no_grad():
        # Slicing past the end is safe, so the final partial batch needs no
        # special case.
        for start in range(0, len(graph_list), batch_size):
            graph_batch = Batch.from_data_list(graph_list[start:start + batch_size])
            graph_batch = graph_batch.to(device)
            h_representation, _ = dl_model(graph_batch)
            batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=graph_batch.chem_id)
            batch_dataframes.append(batch_df)

    # Concatenate the dataframes
    return pd.concat(batch_dataframes)

164
models/ginet_concat.py Normal file
View File

@@ -0,0 +1,164 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
num_atom_type = 119 # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 5 # including aromatic and self-loop edge
num_bond_direction = 3
class GINEConv(MessagePassing):
    """Message-passing layer that adds bond embeddings to neighbour features and applies an MLP."""

    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = 2 * emb_dim
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
            nn.ReLU()
        )
        # One embedding table per edge feature: bond type and bond direction.
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        for table in (self.edge_embedding1, self.edge_embedding2):
            nn.init.xavier_uniform_(table.weight.data)

    def forward(self, x, edge_index, edge_attr):
        n_nodes = x.size(0)
        # Add one self-loop per node to the edge list ...
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_nodes)
        # ... and matching edge features: bond type 4 marks a self-loop edge.
        loop_attr = torch.zeros(n_nodes, 2)
        loop_attr[:, 0] = 4
        loop_attr = loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, loop_attr), dim=0)

        bond_type_emb = self.edge_embedding1(edge_attr[:, 0])
        bond_dir_emb = self.edge_embedding2(edge_attr[:, 1])
        return self.propagate(edge_index, x=x, edge_attr=bond_type_emb + bond_dir_emb)

    def message(self, x_j, edge_attr):
        # Neighbour features plus the embedding of the connecting edge.
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)
class GINet(nn.Module):
    """
    GIN encoder from MolE.

    Args:
        num_layer (int): Number of GNN layers.
        emb_dim (int): Dimensionality of embeddings for each graph layer.
        feat_dim (int): Dimensionality of embedding vector.
        drop_ratio (float): Dropout rate.
        pool (str): Graph-level pooling applied to each layer's node
            features ('mean', 'max', or 'add').

    Output:
        h_global_embedding: Graph-level representation (pooled outputs of all
            layers concatenated, ``num_layer * emb_dim`` wide)
        out: Final embedding vector

    Raises:
        ValueError: ``pool`` is not one of 'mean', 'max', 'add'.
    """

    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
        super(GINet, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        self.concat_dim = num_layer * emb_dim
        if self.concat_dim != self.feat_dim:
            print(f"Representation dimension ({self.concat_dim}) - Embedding dimension ({self.feat_dim})")

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # List of GNN layers
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        # Reject unknown pooling names here rather than failing later in
        # forward() on a missing attribute.
        pool_functions = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool,
        }
        if pool not in pool_functions:
            raise ValueError(f"Unknown pool '{pool}'; expected one of {sorted(pool_functions)}")
        self.pool = pool_functions[pool]

        self.feat_lin = nn.Linear(self.concat_dim, self.feat_dim)
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim),  # Is not reduced to half size!
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim)
        )

    def forward(self, data):
        """Encode a graph batch; returns ``(h_global_embedding, out)``."""
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        h = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        # Perform the convolutions. Each layer consumes the previous layer's
        # output (the initial embedding for the first layer), which also makes
        # a single-layer network valid. ReLU is skipped on the last layer.
        layer_outputs = []
        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            h = self.gnns[layer](h, edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            layer_outputs.append(h)

        # Graph representation: pool every layer, then concatenate.
        h_list_pooled = [self.pool(h_layer, data.batch) for h_layer in layer_outputs]
        h_global_embedding = torch.cat(h_list_pooled, dim=1)
        assert h_global_embedding.shape[1] == self.concat_dim

        # Projection
        h_expansion = self.feat_lin(h_global_embedding)
        out = self.out_lin(h_expansion)
        return h_global_embedding, out

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model, ignoring unknown names."""
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            # NOTE(review): the original nesting of this print is ambiguous;
            # it is kept at loop level (one line per copied entry) - confirm.
            print(name)
            own_state[name].copy_(param)

26
models/mole.yaml Normal file
View File

@@ -0,0 +1,26 @@
name: mole
channels:
- pytorch
- nvidia
- rmg
- conda-forge
- rdkit
- defaults
dependencies:
- python=3.8
- pytorch=2.2.1
- pytorch-cuda=11.8
- rdkit=2022.3.3
- pip
- pip:
- xgboost==1.6.2
- pandas==2.0.3
- PyYAML==6.0.1
- torch_geometric==2.5.0
- openpyxl
- pubchempy==1.0.4
- matplotlib==3.7.5
- seaborn==0.13.2
- tqdm
- scikit-learn==1.0.2
- umap-learn==0.5.5

View File

@@ -0,0 +1,128 @@
"""
MolE Representation Module
This module provides functions to generate MolE molecular representations.
"""
import os
import yaml
import torch
import pandas as pd
from rdkit import Chem
from rdkit import RDLogger
from .dataset_representation import batch_representation
from .ginet_concat import GINet
RDLogger.DisableLog('rdApp.*')
def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to a tab- or comma-separated file,
      or a DataFrame containing SMILES data.
    - smile_col (str, optional): Name of the column containing SMILES strings
      (matched case-insensitively).
    - id_col (str, optional): Name of the column containing molecule IDs
      (matched case-insensitively); IDs ``mol1``, ``mol2``, ... are generated
      when the column is absent.

    Returns:
    - smile_df (pandas.DataFrame): Columns ``smile_col`` and ``id_col`` (IDs as
      str), without missing values and without SMILES that RDKit cannot parse.

    Raises:
    - ValueError: the SMILES column cannot be found.
    """
    def _has_smiles_column(frame):
        return smile_col.lower() in {str(col).lower() for col in frame.columns}

    # Read the data
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        try:
            smile_df = pd.read_csv(data_path, sep='\t')
        except pd.errors.ParserError:
            smile_df = pd.read_csv(data_path)
        else:
            # A comma-separated file parses without error under a tab
            # separator (as one wide column), so choose the comma parse when
            # only that one exposes the SMILES column.
            if not _has_smiles_column(smile_df):
                try:
                    comma_df = pd.read_csv(data_path)
                except pd.errors.ParserError:
                    comma_df = None
                if comma_df is not None and _has_smiles_column(comma_df):
                    smile_df = comma_df

    # Check if columns exist, handle case-insensitive matching
    columns_lower = {str(col).lower(): col for col in smile_df.columns}
    smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
    id_col_actual = columns_lower.get(id_col.lower(), id_col)
    if smile_col_actual not in smile_df.columns:
        raise ValueError(f"Column '{smile_col}' not found in data. Available columns: {list(smile_df.columns)}")

    # Select columns; copy so the assignments below act on an independent frame.
    if id_col_actual in smile_df.columns:
        smile_df = smile_df[[smile_col_actual, id_col_actual]].copy()
        smile_df.columns = [smile_col, id_col]
    else:
        # Create ID column if not exists
        smile_df = smile_df[[smile_col_actual]].copy()
        smile_df.columns = [smile_col]
        smile_df[id_col] = [f"mol{i+1}" for i in range(len(smile_df))]

    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)

    # Remove NaN
    smile_df = smile_df.dropna()

    # Remove invalid smiles
    smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)]
    return smile_df
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Parameters:
    - pretrained_model_dir (str): Directory containing ``config.yaml`` (with a
      ``model`` section of ``GINet`` keyword arguments) and ``model.pth``.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: ``GINet`` instance on ``device`` with the stored weights copied in.
    """
    # Read model configuration; the context manager closes the file handle.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(f"Loading model from: {model_pth_path}")
    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)
    return model
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Turn a SMILES dataset into MolE molecular representations.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame.
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Path to the pre-trained model directory.
    - device (str): Device to use for computation. Can be "cpu", "cuda:0", etc.

    Returns:
    - udl_representation (pandas.DataFrame): DataFrame containing molecular representations.
    """
    # Step 1: read and clean the SMILES table.
    cleaned_smiles = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)
    # Step 2: load the pre-trained encoder.
    encoder = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)
    # Step 3: encode the molecules.
    return batch_representation(cleaned_smiles, encoder, smile_column_str, id_column_str, device=device)