add mole predcit module
This commit is contained in:
163
Data/mole/README.md
Normal file
163
Data/mole/README.md
Normal file
@@ -0,0 +1,163 @@
|
||||
## Convert old XGBoost pickle format
|
||||
|
||||
```bash
|
||||
cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
|
||||
ipython
|
||||
```
|
||||
|
||||
```python
|
||||
import xgboost as xgb
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl')
|
||||
out_ckpt = Path('./')
|
||||
|
||||
# 加载旧模型
|
||||
with open(ckpt, 'rb') as f:
|
||||
model = pickle.load(f)
|
||||
|
||||
# 用新格式保存(推荐)
|
||||
model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json'))
|
||||
|
||||
# 或者继续用pickle但清晰格式
|
||||
booster = model.get_booster()
|
||||
booster.feature_names = None
|
||||
with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f:
|
||||
pickle.dump(model, f)
|
||||
```
|
||||
|
||||
## 完整预测流程
|
||||
|
||||
```text
|
||||
SMILES 分子(输入CSV文件)
|
||||
↓
|
||||
[MolE 模型]
|
||||
├── config.yaml(模型配置)
|
||||
└── model.pth(模型权重)
|
||||
↓
|
||||
分子特征表示(1000维向量)
|
||||
↓
|
||||
构建"分子-菌株对"(笛卡尔积)
|
||||
└── maier_screening_results.tsv.gz(菌株列表)
|
||||
↓
|
||||
[XGBoost 模型]
|
||||
└── MolE-XGBoost-08.03.2025_10.17.json(或.pkl)
|
||||
↓
|
||||
对每一对预测:是否抑制生长
|
||||
↓
|
||||
获得原始预测结果(对每个菌株的预测)
|
||||
↓
|
||||
[聚合分析]
|
||||
├── maier_screening_results.tsv.gz(菌株列表)
|
||||
└── strain_info_SF2.xlsx(革兰染色信息)
|
||||
↓
|
||||
最终预测结果
|
||||
↓
|
||||
输出CSV文件
|
||||
```
|
||||
|
||||
## 所需文件清单
|
||||
|
||||
| 步骤 | 文件名 | 用途 | 备注 |
|
||||
|------|--------|------|------|
|
||||
| **MolE 模型** | `config.yaml` | 定义MolE网络结构 | YAML配置文件 |
|
||||
| | `model.pth` | MolE模型权重 | PyTorch格式 |
|
||||
| **构建菌株对** | `maier_screening_results.tsv.gz` | 提供40个菌株列表 | 压缩的TSV文件 |
|
||||
| **XGBoost 预测** | `MolE-XGBoost-08.03.2025_10.17.json` | 预测分子-菌株对 | JSON格式(新)或PKL格式(旧) |
|
||||
| **聚合分析** | `maier_screening_results.tsv.gz` | 菌株名称和统计 | 复用(与构建菌株对同一文件) |
|
||||
| | `strain_info_SF2.xlsx` | 革兰染色分类信息 | Excel格式 |
|
||||
|
||||
## 文件存放位置
|
||||
|
||||
所有文件应位于:
|
||||
```
|
||||
Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
|
||||
├── config.yaml
|
||||
├── model.pth
|
||||
├── MolE-XGBoost-08.03.2025_10.17.json
|
||||
├── maier_screening_results.tsv.gz
|
||||
└── strain_info_SF2.xlsx
|
||||
```
|
||||
|
||||
## 代码中的对应关系
|
||||
|
||||
```python
|
||||
# PredictionConfig 中的配置
|
||||
@dataclass
|
||||
class PredictionConfig:
|
||||
xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json"
|
||||
mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # 目录(包含config.yaml + model.pth)
|
||||
strain_categories_path = "maier_screening_results.tsv.gz"
|
||||
gram_info_path = "strain_info_SF2.xlsx"
|
||||
```
|
||||
|
||||
## 数据流向总结
|
||||
|
||||
1. **输入**:CSV文件中的SMILES分子
|
||||
2. **MolE处理**:分子 → 1000维特征向量
|
||||
3. **菌株配对**:1个分子 × 40个菌株 = 40对
|
||||
4. **XGBoost预测**:每对 → 抑制概率
|
||||
5. **聚合分析**:统计和分类(按革兰染色)
|
||||
6. **输出**:CSV文件中的预测结果(包含8个指标)
|
||||
|
||||
## 参考文件
|
||||
|
||||
1. `maier_screening_results.tsv.gz` - 菌株列表和筛选数据
|
||||
|
||||
```python
|
||||
self.maier_screen = pd.read_csv(
|
||||
self.config.strain_categories_path, sep='\t', index_col=0
|
||||
)
|
||||
self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # 独热编码
|
||||
```
|
||||
|
||||
包含所有已知菌株的名称(40个菌株)
|
||||
用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对"
|
||||
XGBoost为每一对预测:是否能抑制该菌株的生长
|
||||
|
||||
2. `strain_info_SF2.xlsx` - 革兰染色信息
|
||||
|
||||
```python
|
||||
self.maier_strains = pd.read_excel(self.config.gram_info_path, ...)
|
||||
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
|
||||
```
|
||||
|
||||
记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative)
|
||||
用于将预测结果按革兰染色分类统计
|
||||
|
||||
预测结果示例:
|
||||
某分子 mol1 的预测结果会包括:
|
||||
|
||||
```python
|
||||
BroadSpectrumResult(
|
||||
chem_id='mol1',
|
||||
apscore_total=2.5, # 对所有菌株的抗菌分数
|
||||
apscore_gnegative=2.1, # 仅对革兰阴性菌的分数
|
||||
apscore_gpositive=2.8, # 仅对革兰阳性菌的分数
|
||||
ginhib_total=25, # 抑制的菌株总数
|
||||
ginhib_gnegative=12, # 抑制的革兰阴性菌数
|
||||
ginhib_gpositive=13, # 抑制的革兰阳性菌数
|
||||
broad_spectrum=1 # 是否广谱(≥10个菌株)
|
||||
)
|
||||
```
|
||||
|
||||
结果解读:
|
||||
|
||||
## BroadSpectrumResult 字段说明表
|
||||
|
||||
| 字段名 | 数据类型 | 计算方法 | 含义说明 |
|
||||
|--------|----------|----------|---------|
|
||||
| `chem_id` | 字符串 | 输入的化合物标识符 | 化合物的唯一标识,如 "mol1"、"compound_001" 等 |
|
||||
| `apscore_total` | 浮点数 | `log(gmean(所有40个菌株的预测概率))` | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低 |
|
||||
| `apscore_gnegative` | 浮点数 | `log(gmean(革兰阴性菌株的预测概率))` | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性 |
|
||||
| `apscore_gpositive` | 浮点数 | `log(gmean(革兰阳性菌株的预测概率))` | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性 |
|
||||
| `ginhib_total` | 整数 | `sum(所有菌株的二值化预测)` | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40 |
|
||||
| `ginhib_gnegative` | 整数 | `sum(革兰阴性菌株的二值化预测)` | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20 |
|
||||
| `ginhib_gpositive` | 整数 | `sum(革兰阳性菌株的二值化预测)` | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20 |
|
||||
| `broad_spectrum` | 整数 (0/1) | `1 if ginhib_total >= 10 else 0` | 广谱抗菌标志:如果抑制菌株数 ≥ 10,判定为广谱抗菌药物(1),否则为窄谱(0) |
|
||||
|
||||
说明
|
||||
|
||||
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
|
||||
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,28 @@
|
||||
batch_size: 1000 # batch size
|
||||
warm_up: 10 # warm-up epochs
|
||||
epochs: 1000 # total number of epochs
|
||||
|
||||
load_model: None # resume training
|
||||
eval_every_n_epochs: 1 # validation frequency
|
||||
save_every_n_epochs: 5          # automatic model saving frequency
|
||||
|
||||
fp16_precision: False # float precision 16 (i.e. True/False)
|
||||
init_lr: 0.0005 # initial learning rate for Adam
|
||||
weight_decay: 1e-5 # weight decay for Adam
|
||||
gpu: cuda:0 # training GPU
|
||||
|
||||
model_type: gin_concat # GNN backbone (i.e., gin/gcn)
|
||||
model:
|
||||
num_layer: 5 # number of graph conv layers
|
||||
emb_dim: 200 # embedding dimension in graph conv layers
|
||||
  feat_dim: 8000                # output feature dimension
|
||||
drop_ratio: 0.0 # dropout ratio
|
||||
pool: add # readout pooling (i.e., mean/max/add)
|
||||
|
||||
dataset:
|
||||
num_workers: 50 # dataloader number of workers
|
||||
valid_size: 0.1 # ratio of validation data
|
||||
data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data
|
||||
|
||||
loss:
|
||||
l: 0.0001 # Lambda parameter
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
26
models/__init__.py
Normal file
26
models/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
SIME Models Package
|
||||
|
||||
This package contains models for antimicrobial activity prediction.
|
||||
"""
|
||||
|
||||
from .broad_spectrum_predictor import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput,
|
||||
BroadSpectrumResult,
|
||||
create_predictor,
|
||||
predict_smiles,
|
||||
predict_file
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'ParallelBroadSpectrumPredictor',
|
||||
'PredictionConfig',
|
||||
'MoleculeInput',
|
||||
'BroadSpectrumResult',
|
||||
'create_predictor',
|
||||
'predict_smiles',
|
||||
'predict_file'
|
||||
]
|
||||
|
||||
567
models/broad_spectrum_predictor.py
Normal file
567
models/broad_spectrum_predictor.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""
|
||||
并行广谱抗菌预测器模块
|
||||
|
||||
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
|
||||
基于MolE分子表示和XGBoost模型进行预测。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import pickle
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import multiprocessing as mp
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import List, Dict, Union, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from scipy.stats.mstats import gmean
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
|
||||
from .mole_representation import process_representation
|
||||
|
||||
|
||||
@dataclass
class PredictionConfig:
    """Configuration for broad-spectrum antimicrobial prediction.

    Any path left as ``None`` is resolved in ``__post_init__`` relative to
    the project's bundled pretrained-model directory.
    """
    # Path to the XGBoost classifier (.json or .pkl).
    xgboost_model_path: Optional[str] = None
    # Directory holding the MolE model (config.yaml + model.pth).
    mole_model_path: Optional[str] = None
    # Strain screening matrix (maier_screening_results.tsv.gz).
    strain_categories_path: Optional[str] = None
    # Gram-stain annotation workbook (strain_info_SF2.xlsx).
    gram_info_path: Optional[str] = None
    # Probability cutoff above which a strain counts as growth-inhibited.
    app_threshold: float = 0.04374140128493309
    # Minimum inhibited-strain count to call a compound "broad spectrum".
    min_nkill: int = 10
    # Number of (molecule, strain) feature rows per prediction batch.
    batch_size: int = 100
    # Worker processes for parallel prediction; None -> cpu_count().
    n_workers: Optional[int] = None
    # Torch device; "auto" picks cuda:0 when available, else cpu.
    device: str = "auto"

    def __post_init__(self) -> None:
        """Fill in default file locations for any path not supplied."""
        # Resolve the bundled data directory relative to this source file:
        # models/ -> project root -> Data/mole/...
        # (Path is already imported at module level; no local import needed.)
        project_root = Path(__file__).resolve().parent.parent
        data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"

        if self.mole_model_path is None:
            self.mole_model_path = str(data_dir)

        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")

        if self.strain_categories_path is None:
            self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")

        if self.gram_info_path is None:
            self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")
|
||||
@dataclass
class MoleculeInput:
    """A single molecule to be scored.

    Attributes:
        smiles: SMILES string of the molecule.
        chem_id: optional caller-supplied identifier; when omitted, a
            sequential id ("mol1", "mol2", ...) is assigned downstream.
    """
    smiles: str
    chem_id: Optional[str] = None
|
||||
|
||||
@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum antimicrobial prediction for one compound."""
    chem_id: str               # compound identifier
    apscore_total: float       # log(gmean(prob)) over the whole strain panel
    apscore_gnegative: float   # log(gmean(prob)) over Gram-negative strains
    apscore_gpositive: float   # log(gmean(prob)) over Gram-positive strains
    ginhib_total: int          # strains predicted growth-inhibited
    ginhib_gnegative: int      # inhibited Gram-negative strains
    ginhib_gpositive: int      # inhibited Gram-positive strains
    broad_spectrum: int        # 1 when ginhib_total passes the threshold, else 0

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the result as a plain dict whose keys match the field names."""
        return dict(
            chem_id=self.chem_id,
            apscore_total=self.apscore_total,
            apscore_gnegative=self.apscore_gnegative,
            apscore_gpositive=self.apscore_gpositive,
            ginhib_total=self.ginhib_total,
            ginhib_gnegative=self.ginhib_gnegative,
            ginhib_gpositive=self.ginhib_gpositive,
            broad_spectrum=self.broad_spectrum,
        )
||||
|
||||
|
||||
class BroadSpectrumPredictor:
    """Broad-spectrum antimicrobial activity predictor.

    Combines MolE molecular representations with an XGBoost classifier to
    score every (molecule, strain) pair and aggregate the pair-level scores
    into compound-level antimicrobial metrics. Supports single-molecule and
    batch prediction.
    """

    def __init__(self, config: Optional["PredictionConfig"] = None) -> None:
        """Initialize the predictor.

        Args:
            config: prediction configuration; a default ``PredictionConfig``
                is built when None.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()

        # Fail fast when a required model/data file is missing.
        self._validate_paths()

        # Preload reference data shared by every prediction call.
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Raise if any required model/data path is unset or missing on disk."""
        required_files = {
            "mole_model": self.config.mole_model_path,
            "xgboost_model": self.config.xgboost_model_path,
            "strain_categories": self.config.strain_categories_path,
            "gram_info": self.config.gram_info_path,
        }

        for name, file_path in required_files.items():
            if file_path is None:
                raise ValueError(f"{name} is None! Check __post_init__ configuration")
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required {name} not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load shared data: strain panel, one-hot codes and Gram-stain table."""
        try:
            # Strain screening matrix; its columns are the strain names.
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )

            # One-hot encoding of the strain panel (per-strain feature block).
            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)

            # Gram-stain annotations; skiprows trims the Excel sheet's header
            # and footer rows around the strain table.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """One-hot encode the strain names.

        Args:
            categories: index of strain names.

        Returns:
            Square DataFrame with one row and one column per strain.
        """
        try:
            # sklearn >= 1.2 renamed `sparse` to `sparse_output`.
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older sklearn still uses `sparse`.
            ohe = OneHotEncoder(sparse=False)

        ohe.fit(pd.DataFrame(categories))
        cat_ohe = pd.DataFrame(
            ohe.transform(pd.DataFrame(categories)),
            columns=categories,
            index=categories
        )
        return cat_ohe

    def _get_mole_representation(self, molecules: List["MoleculeInput"]) -> pd.DataFrame:
        """Compute MolE feature representations for the given molecules.

        Args:
            molecules: molecule inputs; a missing chem_id becomes "mol<i>".

        Returns:
            DataFrame of MolE features indexed by chem_id.
        """
        df_data = []
        for i, mol in enumerate(molecules):
            chem_id = mol.chem_id or f"mol{i+1}"
            df_data.append({"smiles": mol.smiles, "chem_id": chem_id})

        df = pd.DataFrame(df_data)

        # Resolve "auto" to a concrete torch device.
        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """Cross-join chemical features with the strain one-hot block.

        Args:
            chemfeats_df: chemical features indexed by chem_id.

        Returns:
            Features for every (molecule, strain) pair, indexed by
            "chem_id:strain_name".
        """
        chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)

        sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})

        # Cartesian product: every molecule paired with every strain.
        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")

        xpred = xpred.set_index("pred_id")
        xpred = xpred.drop(columns=["chem_id", "strain_name"])

        return xpred

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """Attach Gram-stain information to each row.

        Args:
            label_df: DataFrame with a "strain_name" column whose values embed
                an NT accession like "... (NT5001)".

        Returns:
            Copy of label_df with "nt_number" and "gram_stain" columns added;
            both are None when the accession is absent or unknown.
        """
        df_label = label_df.copy()

        # Extract the NT accession from the strain name, e.g. "(NT5001)".
        df_label["nt_number"] = df_label["strain_name"].apply(
            lambda x: re.search(r".*?\((NT\d+)\)", x).group(1) if re.search(r".*?\((NT\d+)\)", x) else None
        )

        # NT accession -> Gram-stain value from the annotation workbook.
        gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]

        # dict.get returns None for unknown/missing accessions.
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)

        return df_label

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate pair-level scores into compound-level metrics.

        Args:
            score_df: frame with columns "pred_id" ("chem_id:strain_name"),
                "1" (inhibition probability) and "growth_inhibition" (0/1).
                The caller's frame is not modified.

        Returns:
            Per-compound frame with apscore_* (log of the geometric mean of
            probabilities) and ginhib_* (inhibited-strain count) columns;
            values missing for a Gram group are filled with 0.
        """
        # Work on a copy (the previous implementation mutated the caller's
        # frame) and split the pair id only once.
        score_df = score_df.copy()
        id_parts = score_df["pred_id"].str.split(":", expand=True)
        score_df["chem_id"] = id_parts[0]
        score_df["strain_name"] = id_parts[1]

        # Annotate each pair with its Gram-stain group.
        pred_df = self._gram_stain(score_df)

        # Antimicrobial potential: log of the geometric mean of probabilities.
        apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
            columns={"1": "apscore_total"}
        )
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])

        # Same score, split by Gram stain.
        apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])

        # Counts of strains predicted inhibited.
        inhibited_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
            columns={"growth_inhibition": "ginhib_total"}
        )

        # Same counts, split by Gram stain.
        inhibited_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )

        # Join all metrics; groups with no members become 0.
        agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)
        agg_pred = agg_pred.fillna(0)

        return agg_pred
||||
|
||||
|
||||
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
                          model_path: str,
                          app_threshold: float) -> Tuple[int, pd.DataFrame]:
    """Process-pool worker: score one batch of (molecule, strain) features.

    Args:
        batch_data: (feature frame, batch id) pair.
        model_path: path to the pickled XGBoost model.
        app_threshold: probability cutoff at or above which growth counts
            as inhibited.

    Returns:
        (batch id, frame with probability columns "0"/"1" and the binary
        "growth_inhibition" column).
    """
    import warnings
    # Silence XGBoost version-compatibility warnings in every worker process.
    warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")

    X_input, batch_id = batch_data

    # Each worker loads its own model copy (cheaper than shipping the model
    # across the process boundary for every task).
    # NOTE(review): pickle.load is only safe on trusted model files.
    with open(model_path, "rb") as file:
        model = pickle.load(file)

    # Compatibility shim: models saved with old XGBoost carry tuple-style
    # feature names that new XGBoost rejects at predict time. Clearing the
    # names disables the name check only — weights and predictions are
    # unchanged.
    if hasattr(model, 'get_booster'):
        model.get_booster().feature_names = None

    # Predict class probabilities for every pair in the batch.
    y_pred = model.predict_proba(X_input)
    pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)

    # Vectorized binarization (replaces a per-row .apply lambda).
    pred_df["growth_inhibition"] = (pred_df["1"] >= app_threshold).astype(int)

    return batch_id, pred_df
||||
|
||||
|
||||
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """Parallel broad-spectrum antimicrobial predictor.

    Extends BroadSpectrumPredictor with multiprocess batch scoring, suited
    to large-scale molecule screening.
    """

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """Predict broad-spectrum activity for a single molecule.

        Args:
            molecule: the molecule to score.

        Returns:
            The prediction result for that molecule.
        """
        results = self.predict_batch([molecule])
        return results[0]

    def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for a batch of molecules.

        Args:
            molecules: molecule inputs; an empty list yields an empty result.

        Returns:
            One BroadSpectrumResult per compound.
        """
        if not molecules:
            return []

        # 1) MolE features per molecule.
        print(f"Processing {len(molecules)} molecules...")
        mole_representation = self._get_mole_representation(molecules)

        # 2) Cross-join with the strain panel.
        print("Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)

        # 3) Slice the pair-level features into batches.
        print(f"Starting parallel prediction with {self.n_workers} workers...")
        batches = []
        for i in range(0, len(X_input), self.config.batch_size):
            batch = X_input.iloc[i:i+self.config.batch_size]
            batches.append((batch, i // self.config.batch_size))

        # 4) Score each batch in a worker process.
        results = {}
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = {
                executor.submit(_predict_batch_worker, (batch_data, batch_id),
                                self.config.xgboost_model_path,
                                self.config.app_threshold): batch_id
                for batch_data, batch_id in batches
            }

            for future in as_completed(futures):
                batch_id, pred_df = future.result()
                results[batch_id] = pred_df
                print(f"Batch {batch_id} completed")

        # 5) Re-assemble batches in their original order.
        print("Merging prediction results...")
        all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])

        # 6) Aggregate pair-level scores into per-compound metrics.
        print("Calculating antimicrobial potential scores...")
        all_pred_df = all_pred_df.reset_index()
        agg_df = self._antimicrobial_potential(all_pred_df)

        # 7) Broad-spectrum flag: enough strains inhibited.
        agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
            lambda x: 1 if x >= self.config.min_nkill else 0
        )

        # 8) Convert rows to result objects (the index holds the chem_id).
        results_list = []
        for _, row in agg_df.iterrows():
            result = BroadSpectrumResult(
                chem_id=row.name,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"])
            )
            results_list.append(result)

        return results_list

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity from a list of SMILES strings.

        Args:
            smiles_list: SMILES strings.
            chem_ids: compound ids; auto-generated ("mol1", ...) when None.

        Returns:
            One BroadSpectrumResult per compound.

        Raises:
            ValueError: when smiles_list and chem_ids differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]

        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]

        return self.predict_batch(molecules)

    def predict_from_file(self,
                          file_path: str,
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity from a CSV/TSV file.

        Args:
            file_path: input file path; ".tsv" and ".tsv.gz" are read as
                tab-separated (pandas decompresses .gz transparently),
                anything else as comma-separated.
            smiles_column: SMILES column name (case-insensitive).
            id_column: compound-id column name (case-insensitive); generated
                when absent.

        Returns:
            One BroadSpectrumResult per compound.

        Raises:
            ValueError: when the SMILES column is missing.
        """
        # Read the file; previously only bare ".tsv" was recognized, so
        # compressed ".tsv.gz" inputs were mis-parsed as CSV.
        if file_path.endswith(('.tsv', '.tsv.gz')):
            df = pd.read_csv(file_path, sep='\t')
        else:
            df = pd.read_csv(file_path)

        # Case-insensitive column lookup.
        columns_lower = {col.lower(): col for col in df.columns}

        smiles_col_actual = columns_lower.get(smiles_column.lower())
        if smiles_col_actual is None:
            raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")

        # Generate ids when the id column is absent.
        id_col_actual = columns_lower.get(id_column.lower())
        if id_col_actual is None:
            df[id_column] = [f"mol{i+1}" for i in range(len(df))]
            id_col_actual = id_column

        molecules = [
            MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
            for _, row in df.iterrows()
        ]

        return self.predict_batch(molecules)
||||
|
||||
|
||||
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """Build a ready-to-use parallel broad-spectrum predictor.

    Args:
        config: optional prediction configuration; defaults apply when None.

    Returns:
        A ParallelBroadSpectrumPredictor instance.
    """
    return ParallelBroadSpectrumPredictor(config)
||||
|
||||
|
||||
# 便捷函数
|
||||
# Convenience wrappers around the predictor class.
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Predict broad-spectrum activity straight from SMILES strings.

    Args:
        smiles_list: SMILES strings.
        chem_ids: optional compound ids aligned with smiles_list.
        config: optional prediction configuration.

    Returns:
        One BroadSpectrumResult per compound.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)
||||
|
||||
|
||||
def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Predict broad-spectrum activity straight from a CSV/TSV file.

    Args:
        file_path: input file path.
        smiles_column: SMILES column name.
        id_column: compound-id column name.
        config: optional prediction configuration.

    Returns:
        One BroadSpectrumResult per compound.
    """
    return create_predictor(config).predict_from_file(file_path, smiles_column, id_column)
||||
|
||||
179
models/dataset_representation.py
Normal file
179
models/dataset_representation.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import yaml
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data, Dataset, Batch
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from rdkit import RDLogger
|
||||
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
ATOM_LIST = list(range(1,119))
|
||||
CHIRALITY_LIST = [
|
||||
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
|
||||
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
|
||||
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
|
||||
Chem.rdchem.ChiralType.CHI_OTHER
|
||||
]
|
||||
BOND_LIST = [
|
||||
BT.SINGLE,
|
||||
BT.DOUBLE,
|
||||
BT.TRIPLE,
|
||||
BT.AROMATIC
|
||||
]
|
||||
BONDDIR_LIST = [
|
||||
Chem.rdchem.BondDir.NONE,
|
||||
Chem.rdchem.BondDir.ENDUPRIGHT,
|
||||
Chem.rdchem.BondDir.ENDDOWNRIGHT
|
||||
]
|
||||
|
||||
|
||||
class MoleculeDataset(Dataset):
    """
    Dataset that converts SMILES strings into PyTorch Geometric graphs.

    Attributes:
        - smiles_data (list): SMILES strings.
        - id_data (list): molecule IDs aligned with smiles_data.
    """

    def __init__(self, smile_df, smile_column, id_column):
        # NOTE(review): deliberately skips Dataset.__init__ (no root/transform
        # setup is wanted here) by calling the grandparent initializer.
        super(Dataset, self).__init__()

        # Gather the SMILES and the corresponding IDs.
        self.smiles_data = smile_df[smile_column].tolist()
        self.id_data = smile_df[id_column].tolist()

    def __getitem__(self, index):
        """Build the molecular graph for one molecule.

        Node features: [atom-type index, chirality index].
        Edge features: [bond-type index, bond-direction index]; each bond is
        added in both directions (undirected graph).
        """
        mol = Chem.MolFromSmiles(self.smiles_data[index])
        mol = Chem.AddHs(mol)  # add explicit hydrogens

        #########################
        # Get the molecule info #
        #########################
        type_idx = []
        chirality_idx = []
        atomic_number = []

        # Roberto: Might want to add more features later on. Such as atomic spin
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() == 0:
                # Dummy atom (*): the ATOM_LIST lookup below will raise,
                # so log which molecule triggered it first.
                print(self.id_data[index])

            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        x = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Build the bond's feature vector once and use it for both
            # directions (was duplicated inline).
            bond_feature = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            row += [start, end]
            col += [end, start]
            edge_feat.append(bond_feature)
            edge_feat.append(bond_feature)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    chem_id=self.id_data[index])

        return data

    def __len__(self):
        return len(self.smiles_data)

    def get(self, index):
        # torch_geometric Dataset API entry point; delegate to __getitem__.
        return self.__getitem__(index)

    def len(self):
        # torch_geometric Dataset API entry point; delegate to __len__.
        return self.__len__()
||||
|
||||
|
||||
def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"):
    """
    Generate molecular representations using a Deep Learning model.

    Parameters:
    - smile_df (pandas.DataFrame): DataFrame containing SMILES data.
    - dl_model: Deep Learning model for molecular representation.
    - column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - batch_size (int, optional): Batch size for processing (default is 10,000).
    - id_is_str (bool, optional): Whether IDs are strings (default is True).
      NOTE(review): currently unused; kept for interface compatibility.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - chem_representation (pandas.DataFrame): DataFrame containing molecular
      representations, indexed by molecule ID.
    """

    # Build one graph per molecule.
    molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str)
    graph_list = [g for g in molecular_graph_dataset]

    # Empty input: the previous implementation crashed in pd.concat([]).
    if not graph_list:
        return pd.DataFrame()

    # Inference mode once, up front (eval() is idempotent).
    dl_model.eval()

    # A list to store the per-batch dataframes. Slicing with range() replaces
    # the previous manual start/end bookkeeping; the final slice naturally
    # absorbs the remainder.
    batch_dataframes = []

    for start in range(0, len(graph_list), batch_size):
        graph_batch = Batch().from_data_list(graph_list[start:start + batch_size])
        graph_batch = graph_batch.to(device)

        # Gather the representation without tracking gradients.
        with torch.no_grad():
            h_representation, _ = dl_model(graph_batch)
            chem_ids = graph_batch.chem_id

        batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids)
        batch_dataframes.append(batch_df)

    # Concatenate the per-batch dataframes.
    chem_representation = pd.concat(batch_dataframes)

    return chem_representation
||||
|
||||
164
models/ginet_concat.py
Normal file
164
models/ginet_concat.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch_geometric.nn import MessagePassing
|
||||
from torch_geometric.utils import add_self_loops
|
||||
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
|
||||
|
||||
num_atom_type = 119 # including the extra mask tokens
|
||||
num_chirality_tag = 3
|
||||
|
||||
num_bond_type = 5 # including aromatic and self-loop edge
|
||||
num_bond_direction = 3
|
||||
|
||||
|
||||
class GINEConv(MessagePassing):
    """GIN convolution with edge features (GINE), as used by MolE.

    Messages are x_j + edge_embedding; the aggregated messages pass through
    a two-layer MLP in `update`.
    """

    def __init__(self, emb_dim):
        super(GINEConv, self).__init__()
        # Update MLP applied to the aggregated neighborhood messages.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.BatchNorm1d(2 * emb_dim),
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU()
        )
        # Learned embeddings for bond type and bond direction.
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        # Add self loops in the edge space.
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]

        # Self-loop edges get bond type 4 (the dedicated self-loop type)
        # and direction 0, matching the original node count.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        # Combine bond-type and bond-direction embeddings per edge.
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        # Neighbor feature plus its edge embedding.
        return x_j + edge_attr

    def update(self, aggr_out):
        # Nonlinear transform of the aggregated messages.
        return self.mlp(aggr_out)
||||
|
||||
|
||||
class GINet(nn.Module):
    """
    GIN encoder from MolE.

    Args:
        num_layer (int): Number of GNN layers.
        emb_dim (int): Dimensionality of embeddings for each graph layer.
        feat_dim (int): Dimensionality of embedding vector.
        drop_ratio (float): Dropout rate.
        pool (str): Pooling method for neighbor aggregation ('mean', 'max', or 'add').

    Raises:
        ValueError: If `pool` is not one of 'mean', 'max' or 'add'.

    Output:
        h_global_embedding: Graph-level representation — concatenation of the
            pooled node embeddings of every layer (num_layer * emb_dim wide).
        out: Final embedding vector from the projection head.
    """

    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
        super(GINet, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        # The graph representation concatenates the pooled output of every layer.
        self.concat_dim = num_layer * emb_dim

        if self.concat_dim != self.feat_dim:
            print(f"Representation dimension ({self.concat_dim}) - Embedding dimension ({self.feat_dim})")

        # Embedding tables for the two categorical atom features (type, chirality).
        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # One GINE convolution per layer.
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # One batch norm per layer.
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        else:
            # Fail fast: the original code left self.pool undefined for an
            # unknown value, which only surfaced as an AttributeError in forward().
            raise ValueError(f"Unknown pooling method: {pool!r}; expected 'mean', 'max' or 'add'")

        self.feat_lin = nn.Linear(self.concat_dim, self.feat_dim)

        # Projection head: three linear layers, no dimensionality reduction.
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),

            nn.Linear(self.feat_dim, self.feat_dim),  # Is not reduced to half size!
            nn.BatchNorm1d(self.feat_dim),
            nn.ReLU(inplace=True),

            nn.Linear(self.feat_dim, self.feat_dim)
        )

    def forward(self, data):
        """Encode a (batched) molecular graph.

        Parameters:
        - data: PyG batch with integer `x` [N, 2], `edge_index`, `edge_attr`
          and `batch` attributes.

        Returns:
        - (h_global_embedding, out): graph representation and projection.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        h_init = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        # Perform the convolutions, keeping each layer's output for the
        # concatenated graph representation below.
        h_dict = {}

        for layer in range(self.num_layer):
            # BUGFIX: the original code indexed h_dict[f"h_{layer - 1}"] in the
            # last-layer branch even when layer == 0 (num_layer == 1), raising
            # a KeyError. Select the layer input first, then branch on the
            # activation: every layer but the last applies ReLU.
            h_in = h_init if layer == 0 else h_dict[f"h_{layer - 1}"]
            tmp_h = self.gnns[layer](h_in, edge_index, edge_attr)
            tmp_h = self.batch_norms[layer](tmp_h)
            if layer == self.num_layer - 1:
                # Last layer: no ReLU before dropout.
                h_dict[f"h_{layer}"] = F.dropout(tmp_h, self.drop_ratio, training=self.training)
            else:
                h_dict[f"h_{layer}"] = F.dropout(F.relu(tmp_h), self.drop_ratio, training=self.training)

        # Graph representation: pool every layer's node embeddings and concatenate.
        h_list_pooled = [self.pool(h_dict[f"h_{layer}"], data.batch) for layer in range(self.num_layer)]
        h_global_embedding = torch.cat(h_list_pooled, dim=1)

        assert h_global_embedding.shape[1] == self.concat_dim

        # Projection
        h_expansion = self.feat_lin(h_global_embedding)
        out = self.out_lin(h_expansion)

        return h_global_embedding, out

    def load_my_state_dict(self, state_dict):
        """Copy matching parameters from `state_dict`, skipping unknown keys.

        Unlike `load_state_dict(strict=True)`, silently ignores parameters
        that do not exist on this module (used for partial checkpoint loads).
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            # NOTE: prints every loaded parameter name (kept for parity with
            # the original; looks like debug output).
            print(name)
            own_state[name].copy_(param)
|
||||
|
||||
26
models/mole.yaml
Normal file
26
models/mole.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
# Conda environment specification for the MolE prediction pipeline.
# Create with: conda env create -f models/mole.yaml
name: mole
channels:
  - pytorch
  - nvidia
  - rmg
  - conda-forge
  - rdkit
  - defaults
dependencies:
  - python=3.8
  - pytorch=2.2.1
  - pytorch-cuda=11.8   # CUDA runtime matching the pytorch build
  - rdkit=2022.3.3
  - pip
  # Pure-Python dependencies installed via pip inside the env.
  - pip:
    - xgboost==1.6.2
    - pandas==2.0.3
    - PyYAML==6.0.1
    - torch_geometric==2.5.0
    - openpyxl
    - pubchempy==1.0.4
    - matplotlib==3.7.5
    - seaborn==0.13.2
    - tqdm
    - scikit-learn==1.0.2
    - umap-learn==0.5.5
|
||||
128
models/mole_representation.py
Normal file
128
models/mole_representation.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
MolE Representation Module
|
||||
|
||||
This module provides functions to generate MolE molecular representations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
from rdkit import RDLogger
|
||||
|
||||
from .dataset_representation import batch_representation
|
||||
from .ginet_concat import GINet
|
||||
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
def read_smiles(data_path, smile_col="smiles", id_col="chem_id"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to the file or a DataFrame containing SMILES data.
    - smile_col (str, optional): Name of the column containing SMILES strings
      (matched case-insensitively).
    - id_col (str, optional): Name of the column containing molecule IDs;
      synthesized as "mol1", "mol2", ... when absent.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with the SMILES column and a
      string-typed ID column; rows with NaN or RDKit-unparsable SMILES dropped.

    Raises:
    - ValueError: If the SMILES column cannot be found.
    """

    # Read the data
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        # Try tab-separated first, then fall back to the default (comma).
        # NOTE(review): reading a comma-separated file with sep='\t' usually
        # *succeeds* with a single column rather than raising, so the fallback
        # only fires on genuine parse errors — verify input files are TSV/CSV.
        try:
            smile_df = pd.read_csv(data_path, sep='\t')
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed.
            smile_df = pd.read_csv(data_path)

    # Check if columns exist, handle case-insensitive matching
    columns_lower = {col.lower(): col for col in smile_df.columns}

    smile_col_actual = columns_lower.get(smile_col.lower(), smile_col)
    id_col_actual = columns_lower.get(id_col.lower(), id_col)

    if smile_col_actual not in smile_df.columns:
        raise ValueError(f"Column '{smile_col}' not found in data. Available columns: {list(smile_df.columns)}")

    # Select columns; .copy() so later assignments write to an owned frame
    # instead of a view (avoids SettingWithCopy issues).
    if id_col_actual in smile_df.columns:
        smile_df = smile_df[[smile_col_actual, id_col_actual]].copy()
        smile_df.columns = [smile_col, id_col]
    else:
        # Create ID column if not exists
        smile_df = smile_df[[smile_col_actual]].copy()
        smile_df.columns = [smile_col]
        smile_df[id_col] = [f"mol{i+1}" for i in range(len(smile_df))]

    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)

    # Remove NaN
    smile_df = smile_df.dropna()

    # Remove invalid smiles (RDKit fails to parse -> row dropped).
    smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)]

    return smile_df
|
||||
|
||||
|
||||
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Parameters:
    - pretrained_model_dir (str): Path to the pre-trained MolE model directory,
      expected to contain `config.yaml` (with a top-level "model" mapping of
      GINet constructor kwargs) and `model.pth` (state dict).
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: Loaded pre-trained GINet model, moved to `device`.
    """

    # Read model configuration. Use a context manager so the file handle is
    # closed (the original `yaml.load(open(...))` leaked it), and safe_load
    # since the config is plain data — no arbitrary Python objects.
    config_path = os.path.join(pretrained_model_dir, "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    model_config = config["model"]

    # Instantiate model
    model = GINet(**model_config).to(device)

    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(f"Loading model from: {model_pth_path}")

    state_dict = torch.load(model_pth_path, map_location=device)
    # Partial, name-matched load (skips keys the model does not have).
    model.load_my_state_dict(state_dict)

    return model
|
||||
|
||||
|
||||
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Generate MolE molecular representations for a dataset of SMILES.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame.
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Path to the pre-trained model directory.
    - device (str): Device to use for computation. Can be "cpu", "cuda:0", etc.

    Returns:
    - udl_representation (pandas.DataFrame): DataFrame containing molecular representations.
    """

    # Parse and sanitize the SMILES input (drops NaN / unparsable rows).
    molecules = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)

    # Restore the pre-trained MolE encoder on the requested device.
    encoder = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)

    # Run the encoder over the molecules in batches and return the features.
    return batch_representation(molecules, encoder, smile_column_str, id_column_str, device=device)
|
||||
|
||||
278
utils/mole_predictor.py
Normal file
278
utils/mole_predictor.py
Normal file
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE antibacterial activity prediction tool.

Predicts the antibacterial activity of small-molecule SMILES with the MolE
model. Usable both from the command line and as a Python API.

Command-line example:
    python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id

Python API example:
    from utils.mole_predictor import predict_csv_file

    predict_csv_file(
        input_path="input.csv",
        output_path="output.csv",
        smiles_column="smiles",
        id_column="chem_id"
    )
"""

import sys
import os
from pathlib import Path

# Make the project root importable so `models.*` resolves when this file is
# executed directly as a script.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import click
import pandas as pd
from typing import Optional, List
from datetime import datetime

from models.broad_spectrum_predictor import (
    ParallelBroadSpectrumPredictor,
    PredictionConfig,
    MoleculeInput,
    BroadSpectrumResult
)
|
||||
|
||||
|
||||
def predict_csv_file(
    input_path: str,
    output_path: Optional[str] = None,
    smiles_column: str = "smiles",
    id_column: str = "chem_id",
    batch_size: int = 100,
    n_workers: Optional[int] = None,
    device: str = "auto",
    add_suffix: bool = True
) -> pd.DataFrame:
    """
    Predict the antibacterial activity of the molecules in a CSV file.

    Args:
        input_path: Path of the input CSV file.
        output_path: Path of the output CSV file; auto-generated next to the
            input when None.
        smiles_column: Name of the SMILES column (matched case-insensitively).
        id_column: Name of the compound-ID column; IDs are generated when the
            column is missing.
        batch_size: Batch size for prediction.
        n_workers: Number of worker processes (None = predictor default).
        device: Compute device ("auto", "cpu", "cuda:0", ...).
        add_suffix: Whether to append "_predicted" to the output file name.

    Returns:
        DataFrame with the original columns plus the prediction results.

    Raises:
        FileNotFoundError: If `input_path` does not exist.
        RuntimeError: If the CSV file cannot be read (original error chained).
        ValueError: If the SMILES column cannot be found.
    """

    print(f"开始处理文件: {input_path}")

    # Validate the input file before doing any heavy work.
    input_path_obj = Path(input_path)
    if not input_path_obj.exists():
        raise FileNotFoundError(f"输入文件不存在: {input_path}")

    # Read the CSV. Chain the original exception (`from e`) so parse errors
    # stay debuggable — the original `raise RuntimeError(...)` dropped the
    # underlying traceback.
    try:
        df_input = pd.read_csv(input_path)
    except Exception as e:
        raise RuntimeError(f"读取 CSV 文件失败: {e}") from e

    print(f"读取了 {len(df_input)} 条数据")

    # Resolve the column names case-insensitively.
    columns_lower = {col.lower(): col for col in df_input.columns}

    smiles_col_actual = columns_lower.get(smiles_column.lower())
    if smiles_col_actual is None:
        raise ValueError(
            f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}"
        )

    # Fall back to synthesized IDs when no ID column is present.
    id_col_actual = columns_lower.get(id_column.lower())
    if id_col_actual is None:
        print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
        df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
        id_col_actual = id_column

    # Build the predictor configuration.
    config = PredictionConfig(
        batch_size=batch_size,
        n_workers=n_workers,
        device=device
    )

    # Initialize the predictor.
    print("初始化预测器...")
    predictor = ParallelBroadSpectrumPredictor(config)

    # Wrap every input row as a MoleculeInput for the predictor.
    molecules = [
        MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
        for _, row in df_input.iterrows()
    ]

    # Run the prediction.
    print("开始预测...")
    results = predictor.predict_batch(molecules)

    # Convert the results into a DataFrame.
    results_dicts = [r.to_dict() for r in results]
    df_results = pd.DataFrame(results_dicts)

    # Join the predictions back onto the input rows, keyed on the compound ID.
    # Both sides are stringified so a dtype mismatch cannot prevent the match.
    df_input['_merge_id'] = df_input[id_col_actual].astype(str)
    df_results['_merge_id'] = df_results['chem_id'].astype(str)

    df_output = df_input.merge(
        df_results.drop(columns=['chem_id']),
        on='_merge_id',
        how='left'
    )
    df_output = df_output.drop(columns=['_merge_id'])

    # Derive the output path when not supplied.
    if output_path is None:
        if add_suffix:
            output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
        else:
            output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}")
    elif add_suffix:
        output_path_obj = Path(output_path)
        output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}")

    # Save the results.
    print(f"保存结果到: {output_path}")
    df_output.to_csv(output_path, index=False)

    print(f"完成! 预测了 {len(results)} 个分子")
    print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌")

    return df_output
|
||||
|
||||
|
||||
def predict_multiple_files(
    input_paths: List[str],
    output_dir: Optional[str] = None,
    smiles_column: str = "smiles",
    id_column: str = "chem_id",
    batch_size: int = 100,
    n_workers: Optional[int] = None,
    device: str = "auto",
    add_suffix: bool = True
) -> List[pd.DataFrame]:
    """
    Batch-predict the antibacterial activity of several CSV files.

    Args:
        input_paths: List of input CSV file paths.
        output_dir: Output directory; when None, each output is written next
            to its input file.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound-ID column.
        batch_size: Batch size for prediction.
        n_workers: Number of worker processes.
        device: Compute device.
        add_suffix: Whether to append the "_predicted" suffix to output names.

    Returns:
        List of result DataFrames, one per successfully processed file
        (files that fail are logged and skipped).
    """

    frames: List[pd.DataFrame] = []

    for path in input_paths:
        src = Path(path)

        # Resolve the destination for this file.
        if output_dir is None:
            dest = None
        else:
            target_dir = Path(output_dir)
            target_dir.mkdir(parents=True, exist_ok=True)
            out_name = f"{src.stem}_predicted{src.suffix}" if add_suffix else src.name
            dest = str(target_dir / out_name)

        # Predict one file; a failure skips that file but keeps the batch going.
        try:
            frames.append(
                predict_csv_file(
                    input_path=path,
                    output_path=dest,
                    smiles_column=smiles_column,
                    id_column=id_column,
                    batch_size=batch_size,
                    n_workers=n_workers,
                    device=device,
                    add_suffix=add_suffix
                )
            )
        except Exception as e:
            print(f"处理文件 {path} 时出错: {e}")
            continue

    return frames
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 命令行接口
|
||||
# ============================================================================
|
||||
|
||||
# NOTE: the option help strings below and the function docstring are what
# click renders for `--help`; the help strings are runtime literals and are
# left unchanged.
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
              help='SMILES 列名 (默认: smiles)')
@click.option('--id-column', '-i', default='chem_id',
              help='化合物 ID 列名 (默认: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
              help='批处理大小 (默认: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
              help='工作进程数 (默认: CPU 核心数)')
@click.option('--device', '-d', default='auto',
              type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
              help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
              help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
    """
    Predict the antibacterial activity of small-molecule SMILES with MolE.

    INPUT_PATH: path of the input CSV file.

    OUTPUT_PATH: path of the output CSV file (optional; defaults to a file
    next to the input).

    Examples:

        python mole_predictor.py input.csv output.csv

        python mole_predictor.py input.csv -s SMILES -i ID

        python mole_predictor.py input.csv --device cuda:0 --batch-size 200
    """

    try:
        predict_csv_file(
            input_path=input_path,
            output_path=output_path,
            smiles_column=smiles_column,
            id_column=id_column,
            batch_size=batch_size,
            n_workers=n_workers,
            device=device,
            add_suffix=add_suffix
        )
    except Exception as e:
        # Surface any failure as a message on stderr and a non-zero exit
        # status so shell pipelines can detect it.
        click.echo(f"错误: {e}", err=True)
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: delegate to the click command-line interface.
    cli()
|
||||
|
||||
Reference in New Issue
Block a user