add mole predict module

This commit is contained in:
2025-10-17 15:54:00 +08:00
parent ea218a3a39
commit 336bfe4b65
14 changed files with 1559 additions and 0 deletions

Data/mole/README.md Normal file

@@ -0,0 +1,163 @@
## Convert the old XGBoost pickle format
```bash
cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
ipython
```
```python
import xgboost as xgb
import pickle
from pathlib import Path
ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl')
out_ckpt = Path('./')
# Load the old model
with open(ckpt, 'rb') as f:
model = pickle.load(f)
# Save in the new format (recommended)
model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json'))
# Or keep using pickle, but with cleared feature names
booster = model.get_booster()
booster.feature_names = None
with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f:
pickle.dump(model, f)
```
## Full prediction pipeline
```mermaid
flowchart TD
    A["SMILES input (CSV file)"] --> B["MolE model<br>config.yaml (model config)<br>model.pth (model weights)"]
    B --> C["Molecular feature representation (1000-dim vector)"]
    C --> D["Build molecule-strain pairs (Cartesian product)<br>maier_screening_results.tsv.gz (strain list)"]
    D --> E["XGBoost model<br>MolE-XGBoost-08.03.2025_10.17.json (or .pkl)"]
    E --> F["Per-pair prediction: growth inhibited?"]
    F --> G["Raw predictions (one per strain)"]
    G --> H["Aggregation<br>maier_screening_results.tsv.gz (strain list)<br>strain_info_SF2.xlsx (Gram stain info)"]
    H --> I["Final predictions (output CSV)"]
```
## Required files
| Step | File | Purpose | Notes |
|------|------|---------|-------|
| **MolE model** | `config.yaml` | Defines the MolE network architecture | YAML config file |
| | `model.pth` | MolE model weights | PyTorch format |
| **Build strain pairs** | `maier_screening_results.tsv.gz` | Provides the list of 40 strains | Gzipped TSV file |
| **XGBoost prediction** | `MolE-XGBoost-08.03.2025_10.17.json` | Predicts molecule-strain pairs | JSON or PKL format |
| **Aggregation** | `maier_screening_results.tsv.gz` | Strain names and statistics | Reused (same file as for strain pairs) |
| | `strain_info_SF2.xlsx` | Gram stain classification | Excel format |
## File locations
All files should be placed under:
```
Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
├── config.yaml
├── model.pth
├── MolE-XGBoost-08.03.2025_10.17.json
├── maier_screening_results.tsv.gz
└── strain_info_SF2.xlsx
```
## Mapping in code
```python
# Configuration in PredictionConfig
@dataclass
class PredictionConfig:
xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json"
mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # directory containing config.yaml + model.pth
strain_categories_path = "maier_screening_results.tsv.gz"
gram_info_path = "strain_info_SF2.xlsx"
```
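The real `PredictionConfig` fills in any unset paths in `__post_init__`. The pattern can be sketched with a simplified stand-in (`MiniConfig` is illustrative, not part of the module):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class MiniConfig:
    data_dir: str = "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
    xgboost_model_path: Optional[str] = None

    def __post_init__(self):
        # Any path left as None is resolved relative to the model directory
        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(
                Path(self.data_dir) / "MolE-XGBoost-08.03.2025_10.17.json"
            )

cfg = MiniConfig()
assert cfg.xgboost_model_path.endswith("MolE-XGBoost-08.03.2025_10.17.json")

# Explicitly provided values are left untouched
override = MiniConfig(xgboost_model_path="my_model.pkl")
assert override.xgboost_model_path == "my_model.pkl"
```

This keeps the call sites simple: passing a config with only the overrides you care about still yields a fully populated set of paths.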
## Data flow summary
1. **Input**: SMILES molecules from a CSV file
2. **MolE processing**: molecule → 1000-dim feature vector
3. **Strain pairing**: 1 molecule × 40 strains = 40 pairs
4. **XGBoost prediction**: each pair → inhibition probability
5. **Aggregation**: statistics, grouped by Gram stain
6. **Output**: prediction results in a CSV file (8 metrics per compound)
## Reference files
1. `maier_screening_results.tsv.gz` - strain list and screening data
```python
self.maier_screen = pd.read_csv(
self.config.strain_categories_path, sep='\t', index_col=0
)
self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # one-hot encoding
```
Contains the names of all known strains (40 strains).
Used to build the Cartesian product with each molecule (molecule × strain), generating all "molecule-strain pairs".
XGBoost then predicts, for each pair, whether the molecule inhibits that strain's growth.
2. `strain_info_SF2.xlsx` - Gram stain information
```python
self.maier_strains = pd.read_excel(self.config.gram_info_path, ...)
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
```
Records each strain's Gram stain attribute: positive or negative.
Used to group the prediction statistics by Gram stain.
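The Cartesian "molecule × strain" pairing described for reference file 1 boils down to a pandas cross merge. A minimal sketch with toy data (the strain names and NT numbers here are illustrative):

```python
import pandas as pd

# Toy stand-ins: 2 molecules with 2 feature columns, 2 one-hot encoded strains
chemfe = pd.DataFrame(
    {"chem_id": ["mol1", "mol2"], "f0": [0.1, 0.4], "f1": [0.2, 0.5]}
)
sohe = pd.DataFrame(
    {"strain_name": ["E. coli (NT5002)", "B. fragilis (NT5003)"],
     "s0": [1.0, 0.0], "s1": [0.0, 1.0]}
)

# Cross merge: every molecule is paired with every strain
xpred = chemfe.merge(sohe, how="cross")
xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
xpred = xpred.set_index("pred_id").drop(columns=["chem_id", "strain_name"])

# 2 molecules x 2 strains = 4 molecule-strain pairs
assert len(xpred) == 4
```

In the real pipeline the molecule side carries the 1000-dim MolE features and the strain side the 40-strain one-hot columns, so each row is one complete XGBoost input.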
Example prediction result:
For a molecule mol1, the result includes:
```python
BroadSpectrumResult(
chem_id='mol1',
apscore_total=2.5, # antimicrobial score across all strains
apscore_gnegative=2.1, # score for Gram-negative strains only
apscore_gpositive=2.8, # score for Gram-positive strains only
ginhib_total=25, # total number of inhibited strains
ginhib_gnegative=12, # inhibited Gram-negative strains
ginhib_gpositive=13, # inhibited Gram-positive strains
broad_spectrum=1 # broad spectrum (≥10 strains inhibited)?
)
```
## BroadSpectrumResult field reference
| Field | Type | Computation | Meaning |
|-------|------|-------------|---------|
| `chem_id` | string | the input compound identifier | Unique compound ID, e.g. "mol1", "compound_001" |
| `apscore_total` | float | `log(gmean(predicted probabilities over all 40 strains))` | Overall antimicrobial potential: log of the geometric mean of the predicted probabilities across all strains. Higher means stronger activity; negative values indicate low overall inhibition probability |
| `apscore_gnegative` | float | `log(gmean(probabilities over Gram-negative strains))` | Antimicrobial potential against Gram-negative strains only; indicates specificity toward Gram-negatives |
| `apscore_gpositive` | float | `log(gmean(probabilities over Gram-positive strains))` | Antimicrobial potential against Gram-positive strains only; indicates specificity toward Gram-positives |
| `ginhib_total` | int | `sum(binarized predictions over all strains)` | Total number of strains predicted inhibited (probability ≥ 0.04374). Range 0-40 |
| `ginhib_gnegative` | int | `sum(binarized predictions over Gram-negative strains)` | Number of Gram-negative strains predicted inhibited. Range 0-20 |
| `ginhib_gpositive` | int | `sum(binarized predictions over Gram-positive strains)` | Number of Gram-positive strains predicted inhibited. Range 0-20 |
| `broad_spectrum` | int (0/1) | `1 if ginhib_total >= 10 else 0` | Broad-spectrum flag: 1 if at least 10 strains are predicted inhibited, else 0 (narrow spectrum) |
Notes:
- **apscore_*** fields: continuous scores based on predicted probabilities, reflecting activity strength
- **ginhib_*** fields: discrete counts based on binarized predictions, reflecting inhibition breadth
- **broad_spectrum**: boolean flag derived from ginhib_total, a quick marker of broad-spectrum activity
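Numerically, the apscore and ginhib fields reduce to a few lines. A toy sketch (the probabilities are made up; the pipeline uses `scipy.stats.mstats.gmean`, computed here directly, and the threshold is the truncated `app_threshold`):

```python
import numpy as np

# Made-up predicted inhibition probabilities for one compound vs. 5 strains
probs = np.array([0.9, 0.2, 0.05, 0.6, 0.01])
threshold = 0.04374  # app_threshold from PredictionConfig (truncated)

# apscore: log of the geometric mean of the probabilities
apscore = float(np.log(probs.prod() ** (1.0 / len(probs))))

# ginhib: number of strains whose probability clears the threshold
ginhib = int((probs >= threshold).sum())

# broad-spectrum flag (min_nkill is 10 in the real pipeline; 3 fits the toy data)
broad_spectrum = 1 if ginhib >= 3 else 0

assert ginhib == 4 and broad_spectrum == 1 and apscore < 0
```

Because the probabilities lie in [0, 1], the geometric mean is at most 1 and the log score is at most 0; less negative means stronger predicted activity.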


@@ -0,0 +1,28 @@
batch_size: 1000 # batch size
warm_up: 10 # warm-up epochs
epochs: 1000 # total number of epochs
load_model: None # resume training
eval_every_n_epochs: 1 # validation frequency
save_every_n_epochs: 5 # automatic model saving frequency
fp16_precision: False # float precision 16 (i.e. True/False)
init_lr: 0.0005 # initial learning rate for Adam
weight_decay: 1e-5 # weight decay for Adam
gpu: cuda:0 # training GPU
model_type: gin_concat # GNN backbone (i.e., gin/gcn)
model:
num_layer: 5 # number of graph conv layers
emb_dim: 200 # embedding dimension in graph conv layers
feat_dim: 8000 # output feature dimension
drop_ratio: 0.0 # dropout ratio
pool: add # readout pooling (i.e., mean/max/add)
dataset:
num_workers: 50 # dataloader number of workers
valid_size: 0.1 # ratio of validation data
data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data
loss:
l: 0.0001 # Lambda parameter

models/__init__.py Normal file

@@ -0,0 +1,26 @@
"""
SIME Models Package
This package contains models for antimicrobial activity prediction.
"""
from .broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult,
create_predictor,
predict_smiles,
predict_file
)
__all__ = [
'ParallelBroadSpectrumPredictor',
'PredictionConfig',
'MoleculeInput',
'BroadSpectrumResult',
'create_predictor',
'predict_smiles',
'predict_file'
]


@@ -0,0 +1,567 @@
"""
并行广谱抗菌预测器模块
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
基于MolE分子表示和XGBoost模型进行预测。
"""
import os
import re
import pickle
import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder
from .mole_representation import process_representation
@dataclass
class PredictionConfig:
"""预测配置参数"""
xgboost_model_path: str = None
mole_model_path: str = None
strain_categories_path: str = None
gram_info_path: str = None
app_threshold: float = 0.04374140128493309
min_nkill: int = 10
batch_size: int = 100
n_workers: Optional[int] = None
device: str = "auto"
def __post_init__(self):
"""设置默认路径"""
from pathlib import Path
# 获取当前文件所在目录
current_file = Path(__file__).resolve()
project_root = current_file.parent.parent # models -> 项目根
data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
# 设置所有路径
if self.mole_model_path is None:
self.mole_model_path = str(data_dir)
if self.xgboost_model_path is None:
self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")
if self.strain_categories_path is None:
self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")
if self.gram_info_path is None:
self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")
@dataclass
class MoleculeInput:
"""分子输入数据结构"""
smiles: str
chem_id: Optional[str] = None
@dataclass
class BroadSpectrumResult:
"""广谱抗菌预测结果"""
chem_id: str
apscore_total: float
apscore_gnegative: float
apscore_gpositive: float
ginhib_total: int
ginhib_gnegative: int
ginhib_gpositive: int
broad_spectrum: int
def to_dict(self) -> Dict[str, Union[str, float, int]]:
"""转换为字典格式"""
return {
'chem_id': self.chem_id,
'apscore_total': self.apscore_total,
'apscore_gnegative': self.apscore_gnegative,
'apscore_gpositive': self.apscore_gpositive,
'ginhib_total': self.ginhib_total,
'ginhib_gnegative': self.ginhib_gnegative,
'ginhib_gpositive': self.ginhib_gpositive,
'broad_spectrum': self.broad_spectrum
}
class BroadSpectrumPredictor:
"""
广谱抗菌预测器
基于MolE分子表示和XGBoost模型预测分子的广谱抗菌活性。
支持单分子和批量预测,提供详细的抗菌潜力分析。
"""
def __init__(self, config: Optional[PredictionConfig] = None) -> None:
"""
初始化预测器
Args:
config: 预测配置参数如果为None则使用默认配置
"""
self.config = config or PredictionConfig()
self.n_workers = self.config.n_workers or mp.cpu_count()
# Validate file paths
self._validate_paths()
# Preload shared data
self._load_shared_data()
def _validate_paths(self) -> None:
"""验证必要文件路径是否存在"""
required_files = {
"mole_model": self.config.mole_model_path,
"xgboost_model": self.config.xgboost_model_path,
"strain_categories": self.config.strain_categories_path,
"gram_info": self.config.gram_info_path,
}
for name, file_path in required_files.items():
if file_path is None:
raise ValueError(f"{name} is None! Check __post_init__ configuration")
if not Path(file_path).exists():
raise FileNotFoundError(f"Required {name} not found: {file_path}")
def _load_shared_data(self) -> None:
"""加载共享数据(菌株信息、革兰染色信息等)"""
try:
# 加载菌株筛选数据
self.maier_screen: pd.DataFrame = pd.read_csv(
self.config.strain_categories_path, sep='\t', index_col=0
)
# 准备菌株独热编码
self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
# 加载革兰染色信息
self.maier_strains: pd.DataFrame = pd.read_excel(
self.config.gram_info_path,
skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
index_col="NT data base"
)
except Exception as e:
raise RuntimeError(f"Failed to load shared data: {str(e)}")
def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
"""
准备菌株的独热编码
Args:
categories: 菌株类别索引
Returns:
独热编码后的DataFrame
"""
try:
# 新版本 sklearn 使用 sparse_output
ohe = OneHotEncoder(sparse_output=False)
except TypeError:
# 旧版本 sklearn 使用 sparse
ohe = OneHotEncoder(sparse=False)
ohe.fit(pd.DataFrame(categories))
cat_ohe = pd.DataFrame(
ohe.transform(pd.DataFrame(categories)),
columns=categories,
index=categories
)
return cat_ohe
def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame:
"""
获取分子的MolE表示
Args:
molecules: 分子输入列表
Returns:
MolE特征表示DataFrame
"""
# 准备输入数据
df_data = []
for i, mol in enumerate(molecules):
chem_id = mol.chem_id or f"mol{i+1}"
df_data.append({"smiles": mol.smiles, "chem_id": chem_id})
df = pd.DataFrame(df_data)
# 确定设备
device = self.config.device
if device == "auto":
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 获取MolE表示
return process_representation(
dataset_path=df,
smile_column_str="smiles",
id_column_str="chem_id",
pretrained_dir=self.config.mole_model_path,
device=device
)
def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
"""
添加菌株信息到化学特征(笛卡尔积)
Args:
chemfeats_df: 化学特征DataFrame
Returns:
包含菌株信息的特征DataFrame
"""
# 准备化学特征
chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
chemfe["chem_id"] = chemfe["chem_id"].astype(str)
# 准备独热编码
sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})
# 笛卡尔积合并
xpred = chemfe.merge(sohe, how="cross")
xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
xpred = xpred.set_index("pred_id")
xpred = xpred.drop(columns=["chem_id", "strain_name"])
return xpred
def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
"""
添加革兰染色信息
Args:
label_df: 包含菌株名称的DataFrame
Returns:
添加革兰染色信息后的DataFrame
"""
df_label = label_df.copy()
# 提取NT编号
df_label["nt_number"] = df_label["strain_name"].apply(
lambda x: re.search(r".*?\((NT\d+)\)", x).group(1) if re.search(r".*?\((NT\d+)\)", x) else None
)
# 创建革兰染色字典
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
# 添加染色信息
df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)
return df_label
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
"""
计算抗菌潜力分数
Args:
score_df: 预测分数DataFrame
Returns:
聚合后的抗菌潜力DataFrame
"""
# 分离化合物ID和菌株名
score_df["chem_id"] = score_df["pred_id"].str.split(":", expand=True)[0]
score_df["strain_name"] = score_df["pred_id"].str.split(":", expand=True)[1]
# 添加革兰染色信息
pred_df = self._gram_stain(score_df)
# 计算抗菌潜力分数(几何平均数的对数)
apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
columns={"1": "apscore_total"}
)
apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])
# 按革兰染色分组的抗菌分数
apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
)
apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])
# Count of inhibited strains
inhibited_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
columns={"growth_inhibition": "ginhib_total"}
)
# Inhibited strain counts grouped by Gram stain
inhibited_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
)
# Join all results
agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)
# Fill NaN values
agg_pred = agg_pred.fillna(0)
return agg_pred
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
model_path: str,
app_threshold: float) -> Tuple[int, pd.DataFrame]:
"""
批次预测工作函数(用于多进程)
Args:
batch_data: (特征数据, 批次ID)
model_path: XGBoost模型路径
app_threshold: 抑制阈值
Returns:
(批次ID, 预测结果DataFrame)
"""
import warnings
# 忽略所有XGBoost版本相关的警告
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
X_input, batch_id = batch_data
# 加载模型
with open(model_path, "rb") as file:
model = pickle.load(file)
# 修复特征名称兼容性问题
# 原因:模型使用旧版 XGBoost 保存时,特征列为元组格式(如 "('bacteria_name',)"
# 新版 XGBoost 严格检查特征名称匹配,导致预测失败。
# 解决:清除 XGBoost 内部的特征名称验证,直接使用输入特征进行预测
# 注意:此操作不改变模型权重和预测逻辑,只禁用格式检查,预测结果保持一致
if hasattr(model, 'get_booster'):
model.get_booster().feature_names = None
# 进行预测
y_pred = model.predict_proba(X_input)
pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
# 二值化预测结果
pred_df["growth_inhibition"] = pred_df["1"].apply(
lambda x: 1 if x >= app_threshold else 0
)
return batch_id, pred_df
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
"""
并行广谱抗菌预测器
继承自BroadSpectrumPredictor添加了多进程并行处理能力
适用于大规模分子批量预测。
"""
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
"""
预测单个分子的广谱抗菌活性
Args:
molecule: 分子输入数据
Returns:
广谱抗菌预测结果
"""
results = self.predict_batch([molecule])
return results[0]
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
"""
批量预测分子的广谱抗菌活性
Args:
molecules: 分子输入列表
Returns:
广谱抗菌预测结果列表
"""
if not molecules:
return []
# 获取MolE表示
print(f"Processing {len(molecules)} molecules...")
mole_representation = self._get_mole_representation(molecules)
# 添加菌株信息
print("Preparing strain-level features...")
X_input = self._add_strains(mole_representation)
# 分批处理
print(f"Starting parallel prediction with {self.n_workers} workers...")
batches = []
for i in range(0, len(X_input), self.config.batch_size):
batch = X_input.iloc[i:i+self.config.batch_size]
batches.append((batch, i // self.config.batch_size))
# 并行预测
results = {}
with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
futures = {
executor.submit(_predict_batch_worker, (batch_data, batch_id),
self.config.xgboost_model_path,
self.config.app_threshold): batch_id
for batch_data, batch_id in batches
}
for future in as_completed(futures):
batch_id, pred_df = future.result()
results[batch_id] = pred_df
print(f"Batch {batch_id} completed")
# Merge the results
print("Merging prediction results...")
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
# Compute antimicrobial potential
print("Calculating antimicrobial potential scores...")
all_pred_df = all_pred_df.reset_index()
agg_df = self._antimicrobial_potential(all_pred_df)
# Flag broad-spectrum compounds
agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
lambda x: 1 if x >= self.config.min_nkill else 0
)
# Convert to result objects
results_list = []
for _, row in agg_df.iterrows():
result = BroadSpectrumResult(
chem_id=row.name,
apscore_total=row["apscore_total"],
apscore_gnegative=row["apscore_gnegative"],
apscore_gpositive=row["apscore_gpositive"],
ginhib_total=int(row["ginhib_total"]),
ginhib_gnegative=int(row["ginhib_gnegative"]),
ginhib_gpositive=int(row["ginhib_gpositive"]),
broad_spectrum=int(row["broad_spectrum"])
)
results_list.append(result)
return results_list
def predict_from_smiles(self,
smiles_list: List[str],
chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
"""
从SMILES字符串列表预测广谱抗菌活性
Args:
smiles_list: SMILES字符串列表
chem_ids: 化合物ID列表如果为None则自动生成
Returns:
广谱抗菌预测结果列表
"""
if chem_ids is None:
chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
if len(smiles_list) != len(chem_ids):
raise ValueError("smiles_list and chem_ids must have the same length")
molecules = [
MoleculeInput(smiles=smiles, chem_id=chem_id)
for smiles, chem_id in zip(smiles_list, chem_ids)
]
return self.predict_batch(molecules)
def predict_from_file(self,
file_path: str,
smiles_column: str = "smiles",
id_column: str = "chem_id") -> List[BroadSpectrumResult]:
"""
从文件预测广谱抗菌活性
Args:
file_path: 输入文件路径支持CSV/TSV
smiles_column: SMILES列名
id_column: 化合物ID列名
Returns:
广谱抗菌预测结果列表
"""
# 读取文件
if file_path.endswith('.tsv'):
df = pd.read_csv(file_path, sep='\t')
else:
df = pd.read_csv(file_path)
# Validate that the column exists (case-insensitive)
columns_lower = {col.lower(): col for col in df.columns}
smiles_col_actual = columns_lower.get(smiles_column.lower())
if smiles_col_actual is None:
raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")
# Handle the ID column
id_col_actual = columns_lower.get(id_column.lower())
if id_col_actual is None:
df[id_column] = [f"mol{i+1}" for i in range(len(df))]
id_col_actual = id_column
# Build the molecule inputs
molecules = [
MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
for _, row in df.iterrows()
]
return self.predict_batch(molecules)
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
"""
创建并行广谱抗菌预测器实例
Args:
config: 预测配置参数
Returns:
预测器实例
"""
return ParallelBroadSpectrumPredictor(config)
# Convenience functions
def predict_smiles(smiles_list: List[str],
chem_ids: Optional[List[str]] = None,
config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
"""
便捷函数直接从SMILES列表预测广谱抗菌活性
Args:
smiles_list: SMILES字符串列表
chem_ids: 化合物ID列表
config: 预测配置
Returns:
预测结果列表
"""
predictor = create_predictor(config)
return predictor.predict_from_smiles(smiles_list, chem_ids)
def predict_file(file_path: str,
smiles_column: str = "smiles",
id_column: str = "chem_id",
config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
"""
便捷函数:从文件预测广谱抗菌活性
Args:
file_path: 输入文件路径
smiles_column: SMILES列名
id_column: ID列名
config: 预测配置
Returns:
预测结果列表
"""
predictor = create_predictor(config)
return predictor.predict_from_file(file_path, smiles_column, id_column)


@@ -0,0 +1,179 @@
import os
import yaml
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset, Batch
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
BT.SINGLE,
BT.DOUBLE,
BT.TRIPLE,
BT.AROMATIC
]
BONDDIR_LIST = [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT
]
class MoleculeDataset(Dataset):
"""
Dataset class for creating molecular graphs.
Attributes:
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
- smile_column (str): Name of the column containing SMILES strings.
- id_column (str): Name of the column containing molecule IDs.
"""
def __init__(self, smile_df, smile_column, id_column):
super(Dataset, self).__init__()
# Gather the SMILES and the corresponding IDs
self.smiles_data = smile_df[smile_column].tolist()
self.id_data = smile_df[id_column].tolist()
def __getitem__(self, index):
# Get the molecule
mol = Chem.MolFromSmiles(self.smiles_data[index])
mol = Chem.AddHs(mol)
#########################
# Get the molecule info #
#########################
type_idx = []
chirality_idx = []
atomic_number = []
# Roberto: Might want to add more features later on. Such as atomic spin
for atom in mol.GetAtoms():
if atom.GetAtomicNum() == 0:
print(self.id_data[index])
type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
atomic_number.append(atom.GetAtomicNum())
x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
x = torch.cat([x1, x2], dim=-1)
row, col, edge_feat = [], [], []
for bond in mol.GetBonds():
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
row += [start, end]
col += [end, start]
edge_feat.append([
BOND_LIST.index(bond.GetBondType()),
BONDDIR_LIST.index(bond.GetBondDir())
])
edge_feat.append([
BOND_LIST.index(bond.GetBondType()),
BONDDIR_LIST.index(bond.GetBondDir())
])
edge_index = torch.tensor([row, col], dtype=torch.long)
edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
chem_id=self.id_data[index])
return data
def __len__(self):
return len(self.smiles_data)
def get(self, index):
return self.__getitem__(index)
def len(self):
return self.__len__()
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
"""
Generate molecular representations using a Deep Learning model.
Parameters:
- smile_df (pandas.DataFrame): DataFrame containing SMILES data.
- dl_model: Deep Learning model for molecular representation.
- column_str (str): Name of the column containing SMILES strings.
- id_str (str): Name of the column containing molecule IDs.
- batch_size (int, optional): Batch size for processing (default is 10,000).
- id_is_str (bool, optional): Whether IDs are strings (default is True).
- device (str, optional): Device for computation (default is "cuda:0").
Returns:
- chem_representation (pandas.DataFrame): DataFrame containing molecular representations.
"""
# First we create a list of graphs
molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
graph_list = [g for g in molecular_graph_dataset]
# Determine number of loops to do given the batch size
n_batches = len(graph_list) // batch_size
# Are all molecules accounted for?
remaining_molecules = len(graph_list) % batch_size
# Starting indices
start, end = 0, batch_size
# Determine number of iterations
if remaining_molecules == 0:
n_iter = n_batches
elif remaining_molecules > 0:
n_iter = n_batches + 1
# A list to store the batch dataframes
batch_dataframes = []
# Iterate over the batches
for i in range(n_iter):
# Start batch object
batch_obj = Batch()
graph_batch = batch_obj.from_data_list(graph_list[start:end])
graph_batch = graph_batch.to(device)
# Gather the representation
with torch.no_grad():
dl_model.eval()
h_representation, _ = dl_model(graph_batch)
chem_ids = graph_batch.chem_id
batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids)
batch_dataframes.append(batch_df)
# Get the next batch
## In the final iteration we want to get all the remaining molecules
if i == n_iter - 2:
start = end
end = len(graph_list)
else:
start = end
end = end + batch_size
# Concatenate the dataframes
chem_representation = pd.concat(batch_dataframes)
return chem_representation

models/ginet_concat.py Normal file

@@ -0,0 +1,164 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
num_atom_type = 119 # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 5 # including aromatic and self-loop edge
num_bond_direction = 3
class GINEConv(MessagePassing):
def __init__(self, emb_dim):
super(GINEConv, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(emb_dim, 2*emb_dim),
nn.BatchNorm1d(2*emb_dim),
nn.ReLU(),
nn.Linear(2*emb_dim, emb_dim),
nn.ReLU()
)
self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
def forward(self, x, edge_index, edge_attr):
# add self loops in the edge space
edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
# add features corresponding to self-loop edges.
self_loop_attr = torch.zeros(x.size(0), 2)
self_loop_attr[:,0] = 4 #bond type for self-loop edge
self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
def message(self, x_j, edge_attr):
return x_j + edge_attr
def update(self, aggr_out):
return self.mlp(aggr_out)
class GINet(nn.Module):
"""
GIN encoder from MolE.
Args:
num_layer (int): Number of GNN layers.
emb_dim (int): Dimensionality of embeddings for each graph layer.
feat_dim (int): Dimensionality of embedding vector.
drop_ratio (float): Dropout rate.
pool (str): Pooling method for neighbor aggregation ('mean', 'max', or 'add').
Output:
h_global_embedding: Graph-level representation
out: Final embedding vector
"""
def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
super(GINet, self).__init__()
self.num_layer = num_layer
self.emb_dim = emb_dim
self.feat_dim = feat_dim
self.drop_ratio = drop_ratio
self.concat_dim = num_layer * emb_dim
if self.concat_dim != self.feat_dim:
print(f"Representation dimension ({self.concat_dim}) differs from embedding dimension ({self.feat_dim})")
self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
nn.init.xavier_uniform_(self.x_embedding1.weight.data)
nn.init.xavier_uniform_(self.x_embedding2.weight.data)
# List of MLPs
self.gnns = nn.ModuleList()
for layer in range(num_layer):
self.gnns.append(GINEConv(emb_dim))
# List of batchnorms
self.batch_norms = nn.ModuleList()
for layer in range(num_layer):
self.batch_norms.append(nn.BatchNorm1d(emb_dim))
if pool == 'mean':
self.pool = global_mean_pool
elif pool == 'max':
self.pool = global_max_pool
elif pool == 'add':
self.pool = global_add_pool
self.feat_lin = nn.Linear(self.concat_dim, self.feat_dim)
self.out_lin = nn.Sequential(
nn.Linear(self.feat_dim, self.feat_dim),
nn.BatchNorm1d(self.feat_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim, self.feat_dim), # Is not reduced to half size!
nn.BatchNorm1d(self.feat_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim, self.feat_dim)
)
def forward(self, data):
x = data.x
edge_index = data.edge_index
edge_attr = data.edge_attr
h_init = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])
# Perform the convolutions
h_dict = {}
for layer in range(self.num_layer):
if layer == self.num_layer - 1:
tmp_h = self.gnns[layer](h_dict[f"h_{layer - 1}"], edge_index, edge_attr)
tmp_h = self.batch_norms[layer](tmp_h)
h_dict[f"h_{layer}"] = F.dropout(tmp_h, self.drop_ratio, training=self.training)
else:
if layer == 0:
tmp_h = self.gnns[layer](h_init, edge_index, edge_attr)
tmp_h = self.batch_norms[layer](tmp_h)
h_dict[f"h_{layer}"] = F.dropout(F.relu(tmp_h), self.drop_ratio, training=self.training)
else:
tmp_h = self.gnns[layer](h_dict[f"h_{layer - 1}"], edge_index, edge_attr)
tmp_h = self.batch_norms[layer](tmp_h)
h_dict[f"h_{layer}"] = F.dropout(F.relu(tmp_h), self.drop_ratio, training=self.training)
# Graph representation
h_list_pooled = [self.pool(h_dict[f"h_{layer}"], data.batch) for layer in range(self.num_layer)]
h_global_embedding = torch.cat(h_list_pooled, dim=1)
assert h_global_embedding.shape[1] == self.concat_dim
# Projection
h_expansion = self.feat_lin(h_global_embedding)
out = self.out_lin(h_expansion)
return h_global_embedding, out
def load_my_state_dict(self, state_dict):
own_state = self.state_dict()
for name, param in state_dict.items():
if name not in own_state:
continue
if isinstance(param, nn.parameter.Parameter):
# backwards compatibility for serialized parameters
param = param.data
print(name)
own_state[name].copy_(param)

models/mole.yaml Normal file

@@ -0,0 +1,26 @@
name: mole
channels:
- pytorch
- nvidia
- rmg
- conda-forge
- rdkit
- defaults
dependencies:
- python=3.8
- pytorch=2.2.1
- pytorch-cuda=11.8
- rdkit=2022.3.3
- pip
- pip:
- xgboost==1.6.2
- pandas==2.0.3
- PyYAML==6.0.1
- torch_geometric==2.5.0
- openpyxl
- pubchempy==1.0.4
- matplotlib==3.7.5
- seaborn==0.13.2
- tqdm
- scikit-learn==1.0.2
- umap-learn==0.5.5


@@ -0,0 +1,128 @@
"""
MolE Representation Module
This module provides functions to generate MolE molecular representations.
"""
import os
import yaml
import torch
import pandas as pd
from rdkit import Chem
from rdkit import RDLogger
from .dataset_representation import batch_representation
from .ginet_concat import GINet
RDLogger.DisableLog('rdApp.*')
def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
"""
Read SMILES data from a file or DataFrame and remove invalid SMILES.
Parameters:
- data_path (str or pd.DataFrame): Path to the file or a DataFrame containing SMILES data.
- smile_col (str, optional): Name of the column containing SMILES strings.
- id_col (str, optional): Name of the column containing molecule IDs.
Returns:
- smile_df (pandas.DataFrame): DataFrame containing SMILES data with specified columns.
"""
# Read the data
if isinstance(data_path, pd.DataFrame):
smile_df = data_path.copy()
else:
# Try to read with different separators
try:
smile_df = pd.read_csv(data_path, sep='\t')
except Exception:
smile_df = pd.read_csv(data_path)
# Check if columns exist, handle case-insensitive matching
columns_lower = {col.lower(): col for col in smile_df.columns}
smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
id_col_actual = columns_lower.get(id_col.lower(), id_col)
if smile_col_actual not in smile_df.columns:
raise ValueError(f"Column '{smile_col}' not found in data. Available columns: {list(smile_df.columns)}")
# Select columns
if id_col_actual in smile_df.columns:
smile_df = smile_df[[smile_col_actual, id_col_actual]]
smile_df.columns = [smile_col, id_col]
else:
# Create ID column if not exists
smile_df = smile_df[[smile_col_actual]]
smile_df.columns = [smile_col]
smile_df[id_col] = [f"mol{i+1}" for i in range(len(smile_df))]
# Make sure ID column is interpreted as str
smile_df[id_col] = smile_df[id_col].astype(str)
# Remove NaN
smile_df = smile_df.dropna()
# Remove invalid smiles
smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)]
return smile_df
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
"""
Load a pre-trained MolE model.
Parameters:
- pretrained_model_dir (str): Path to the pre-trained MolE model directory.
- device (str, optional): Device for computation (default is "cuda:0").
Returns:
- model: Loaded pre-trained model.
"""
# Read model configuration
config = yaml.load(open(os.path.join(pretrained_model_dir, "config.yaml"), "r"), Loader=yaml.FullLoader)
model_config = config["model"]
# Instantiate model
model = GINet(**model_config).to(device)
# Load pre-trained weights
model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
print(f"Loading model from: {model_pth_path}")
state_dict = torch.load(model_pth_path, map_location=device)
model.load_my_state_dict(state_dict)
return model
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
"""
Process the dataset to generate molecular representations.
Parameters:
- dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame.
- pretrained_dir (str): Path to the pre-trained model directory.
- smile_column_str (str): Name of the column containing SMILES strings.
- id_column_str (str): Name of the column containing molecule IDs.
- device (str): Device to use for computation. Can be "cpu", "cuda:0", etc.
Returns:
- udl_representation (pandas.DataFrame): DataFrame containing molecular representations.
"""
# First we read the SMILES dataframe
smiles_df = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)
# Load the pre-trained model
pmodel = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)
# Gather pre-trained representation
udl_representation = batch_representation(smiles_df, pmodel, smile_column_str, id_column_str, device=device)
return udl_representation

utils/mole_predictor.py (new file)
@@ -0,0 +1,278 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE antibacterial-activity prediction tool

This script predicts the antibacterial activity of small-molecule SMILES
using the MolE model. It can be invoked from the command line or through
the Python API.

Command-line example:
    python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id

Python API example:
    from utils.mole_predictor import predict_csv_file
    predict_csv_file(
        input_path="input.csv",
        output_path="output.csv",
        smiles_column="smiles",
        id_column="chem_id"
    )
"""
import sys
import os
from pathlib import Path
# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import click
import pandas as pd
from typing import Optional, List
from datetime import datetime
from models.broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult
)
def predict_csv_file(
input_path: str,
output_path: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
) -> pd.DataFrame:
    """
    Predict antibacterial activity for the molecules in a CSV file.

    Args:
        input_path: Path to the input CSV file.
        output_path: Path to the output CSV file; auto-generated if None.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound-ID column.
        batch_size: Batch size.
        n_workers: Number of worker processes.
        device: Compute device ("auto", "cpu", "cuda:0", ...).
        add_suffix: Whether to append a prediction suffix to the output filename.

    Returns:
        DataFrame containing the prediction results.
    """
    print(f"Processing file: {input_path}")
    # Read the input file
    input_path_obj = Path(input_path)
    if not input_path_obj.exists():
        raise FileNotFoundError(f"Input file does not exist: {input_path}")
    # Read the CSV
    try:
        df_input = pd.read_csv(input_path)
    except Exception as e:
        raise RuntimeError(f"Failed to read CSV file: {e}")
    print(f"Read {len(df_input)} rows")
    # Check that the columns exist (case-insensitive)
    columns_lower = {col.lower(): col for col in df_input.columns}
    smiles_col_actual = columns_lower.get(smiles_column.lower())
    if smiles_col_actual is None:
        raise ValueError(
            f"SMILES column '{smiles_column}' not found. Available columns: {list(df_input.columns)}"
        )
    # Handle the ID column
    id_col_actual = columns_lower.get(id_column.lower())
    if id_col_actual is None:
        print(f"ID column '{id_column}' not found; IDs will be generated automatically")
        df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
        id_col_actual = id_column
    # Build the predictor configuration
    config = PredictionConfig(
        batch_size=batch_size,
        n_workers=n_workers,
        device=device
    )
    # Initialize the predictor
    print("Initializing predictor...")
    predictor = ParallelBroadSpectrumPredictor(config)
    # Prepare the molecule inputs
    molecules = [
        MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
        for _, row in df_input.iterrows()
    ]
    # Run the prediction
    print("Running prediction...")
    results = predictor.predict_batch(molecules)
    # Convert the results to a DataFrame
    results_dicts = [r.to_dict() for r in results]
    df_results = pd.DataFrame(results_dicts)
    # Merge the original data with the prediction results,
    # using chem_id as the join key
    df_input['_merge_id'] = df_input[id_col_actual].astype(str)
    df_results['_merge_id'] = df_results['chem_id'].astype(str)
    df_output = df_input.merge(
        df_results.drop(columns=['chem_id']),
        on='_merge_id',
        how='left'
    )
    df_output = df_output.drop(columns=['_merge_id'])
    # Build the output path
    if output_path is None:
        if add_suffix:
            output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
        else:
            output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}")
    elif add_suffix:
        output_path_obj = Path(output_path)
        output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}")
    # Save the results
    print(f"Saving results to: {output_path}")
    df_output.to_csv(output_path, index=False)
    print(f"Done! Predicted {len(results)} molecules")
    print(f"{sum(r.broad_spectrum for r in results)} molecules were predicted as broad-spectrum antibacterial")
    return df_output
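The join in `predict_csv_file` keys original rows to predictions through a temporary `_merge_id` column, normalized to `str` on both sides so mixed integer/string IDs cannot silently miss. The pattern in isolation, with a made-up results frame standing in for the predictor output:

```python
import pandas as pd

df_input = pd.DataFrame({"chem_id": [1, 2], "smiles": ["CCO", "CCN"]})
# Hypothetical predictor output: chem_id comes back as a string
df_results = pd.DataFrame({"chem_id": ["1", "2"], "broad_spectrum": [True, False]})

# Normalize both keys to str so the left join matches every input row
df_input["_merge_id"] = df_input["chem_id"].astype(str)
df_results["_merge_id"] = df_results["chem_id"].astype(str)
df_output = df_input.merge(
    df_results.drop(columns=["chem_id"]),
    on="_merge_id",
    how="left",
).drop(columns=["_merge_id"])
print(df_output)
```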
def predict_multiple_files(
input_paths: List[str],
output_dir: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
) -> List[pd.DataFrame]:
    """
    Predict multiple CSV files in a batch.

    Args:
        input_paths: List of input CSV file paths.
        output_dir: Output directory; if None, outputs are written next to each input file.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound-ID column.
        batch_size: Batch size.
        n_workers: Number of worker processes.
        device: Compute device.
        add_suffix: Whether to append a prediction suffix to the output filenames.

    Returns:
        List of DataFrames containing the prediction results.
    """
"""
results = []
for input_path in input_paths:
input_path_obj = Path(input_path)
        # Determine the output path
if output_dir is not None:
output_dir_obj = Path(output_dir)
output_dir_obj.mkdir(parents=True, exist_ok=True)
if add_suffix:
output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(output_dir_obj / input_path_obj.name)
else:
output_path = None
        # Predict a single file
try:
df_result = predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
)
results.append(df_result)
except Exception as e:
            print(f"Error while processing file {input_path}: {e}")
continue
return results
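Both functions derive output names with the same pathlib rule: keep the stem, optionally append `_predicted`, and redirect to `output_dir` when one is given. A minimal sketch of that rule (the `predicted_path` helper is hypothetical, extracted for illustration):

```python
from pathlib import Path

def predicted_path(input_path, output_dir=None, add_suffix=True):
    # Keep the stem and extension; optionally append "_predicted"
    p = Path(input_path)
    base = Path(output_dir) if output_dir is not None else p.parent
    name = f"{p.stem}_predicted{p.suffix}" if add_suffix else p.name
    return str(base / name)

print(predicted_path("data/input.csv"))  # data/input_predicted.csv on POSIX
print(predicted_path("data/input.csv", output_dir="out", add_suffix=False))
```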
# ============================================================================
# Command-line interface
# ============================================================================
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
              help='SMILES column name (default: smiles)')
@click.option('--id-column', '-i', default='chem_id',
              help='Compound ID column name (default: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
              help='Batch size (default: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
              help='Number of worker processes (default: number of CPU cores)')
@click.option('--device', '-d', default='auto',
              type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
              help='Compute device (default: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
              help='Whether to append a "_predicted" suffix to the output filename (default: add)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
    """
    Predict the antibacterial activity of small-molecule SMILES with the MolE model.

    INPUT_PATH: path to the input CSV file.

    OUTPUT_PATH: path to the output CSV file (optional; defaults to the input file's directory).

    Examples:
        python mole_predictor.py input.csv output.csv
        python mole_predictor.py input.csv -s SMILES -i ID
        python mole_predictor.py input.csv --device cuda:0 --batch-size 200
    """
try:
predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
)
except Exception as e:
        click.echo(f"Error: {e}", err=True)
sys.exit(1)
if __name__ == '__main__':
cli()