add mole predcit module
This commit is contained in:
163
Data/mole/README.md
Normal file
163
Data/mole/README.md
Normal file
@@ -0,0 +1,163 @@
|
||||
## convert old xgboots pickle format
|
||||
|
||||
```bash
|
||||
cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
|
||||
ipython
|
||||
```
|
||||
|
||||
```python
|
||||
import xgboost as xgb
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl')
|
||||
out_ckpt = Path('./')
|
||||
|
||||
# 加载旧模型
|
||||
with open(ckpt, 'rb') as f:
|
||||
model = pickle.load(f)
|
||||
|
||||
# 用新格式保存(推荐)
|
||||
model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json'))
|
||||
|
||||
# 或者继续用pickle但清晰格式
|
||||
booster = model.get_booster()
|
||||
booster.feature_names = None
|
||||
with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f:
|
||||
pickle.dump(model, f)
|
||||
```
|
||||
|
||||
## 完整预测流程
|
||||
|
||||
```mermaid
|
||||
SMILES 分子(输入CSV文件)
|
||||
↓
|
||||
[MolE 模型]
|
||||
├── config.yaml(模型配置)
|
||||
└── model.pth(模型权重)
|
||||
↓
|
||||
分子特征表示(1000维向量)
|
||||
↓
|
||||
构建"分子-菌株对"(笛卡尔积)
|
||||
└── maier_screening_results.tsv.gz(菌株列表)
|
||||
↓
|
||||
[XGBoost 模型]
|
||||
└── MolE-XGBoost-08.03.2025_10.17.json(或.pkl)
|
||||
↓
|
||||
对每一对预测:是否抑制生长
|
||||
↓
|
||||
获得原始预测结果(对每个菌株的预测)
|
||||
↓
|
||||
[聚合分析]
|
||||
├── maier_screening_results.tsv.gz(菌株列表)
|
||||
└── strain_info_SF2.xlsx(革兰染色信息)
|
||||
↓
|
||||
最终预测结果
|
||||
↓
|
||||
输出CSV文件
|
||||
```
|
||||
|
||||
## 所需文件清单
|
||||
|
||||
| 步骤 | 文件名 | 用途 | 备注 |
|
||||
|------|--------|------|------|
|
||||
| **MolE 模型** | `config.yaml` | 定义MolE网络结构 | YAML配置文件 |
|
||||
| | `model.pth` | MolE模型权重 | PyTorch格式 |
|
||||
| **构建菌株对** | `maier_screening_results.tsv.gz` | 提供40个菌株列表 | 压缩的TSV文件 |
|
||||
| **XGBoost 预测** | `MolE-XGBoost-08.03.2025_10.17.json` | 预测分子-菌株对 | JSON格式(新)或PKL格式(旧) |
|
||||
| **聚合分析** | `maier_screening_results.tsv.gz` | 菌株名称和统计 | 复用(与构建菌株对同一文件) |
|
||||
| | `strain_info_SF2.xlsx` | 革兰染色分类信息 | Excel格式 |
|
||||
|
||||
## 文件存放位置
|
||||
|
||||
所有文件应位于:
|
||||
```
|
||||
Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
|
||||
├── config.yaml
|
||||
├── model.pth
|
||||
├── MolE-XGBoost-08.03.2025_10.17.json
|
||||
├── maier_screening_results.tsv.gz
|
||||
└── strain_info_SF2.xlsx
|
||||
```
|
||||
|
||||
## 代码中的对应关系
|
||||
|
||||
```python
|
||||
# PredictionConfig 中的配置
|
||||
@dataclass
|
||||
class PredictionConfig:
|
||||
xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json"
|
||||
mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # 目录(包含config.yaml + model.pth)
|
||||
strain_categories_path = "maier_screening_results.tsv.gz"
|
||||
gram_info_path = "strain_info_SF2.xlsx"
|
||||
```
|
||||
|
||||
## 数据流向总结
|
||||
|
||||
1. **输入**:CSV文件中的SMILES分子
|
||||
2. **MolE处理**:分子 → 1000维特征向量
|
||||
3. **菌株配对**:1个分子 × 40个菌株 = 40对
|
||||
4. **XGBoost预测**:每对 → 抑制概率
|
||||
5. **聚合分析**:统计和分类(按革兰染色)
|
||||
6. **输出**:CSV文件中的预测结果(包含8个指标)
|
||||
|
||||
## 参考文件
|
||||
|
||||
1. `maier_screening_results.tsv.gz` - 菌株列表和筛选数据
|
||||
|
||||
```python
|
||||
self.maier_screen = pd.read_csv(
|
||||
self.config.strain_categories_path, sep='\t', index_col=0
|
||||
)
|
||||
self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # 独热编码
|
||||
```
|
||||
|
||||
包含所有已知菌株的名称(40个菌株)
|
||||
用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对"
|
||||
XGBoost为每一对预测:是否能抑制该菌株的生长
|
||||
|
||||
2. `strain_info_SF2.xlsx` - 革兰染色信息
|
||||
|
||||
```python
|
||||
self.maier_strains = pd.read_excel(self.config.gram_info_path, ...)
|
||||
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
|
||||
```
|
||||
|
||||
记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative)
|
||||
用于将预测结果按革兰染色分类统计
|
||||
|
||||
预测结果示例:
|
||||
某分子 mol1 的预测结果会包括:
|
||||
|
||||
```python
|
||||
BroadSpectrumResult(
|
||||
chem_id='mol1',
|
||||
apscore_total=2.5, # 对所有菌株的抗菌分数
|
||||
apscore_gnegative=2.1, # 仅对革兰阴性菌的分数
|
||||
apscore_gpositive=2.8, # 仅对革兰阳性菌的分数
|
||||
ginhib_total=25, # 抑制的菌株总数
|
||||
ginhib_gnegative=12, # 抑制的革兰阴性菌数
|
||||
ginhib_gpositive=13, # 抑制的革兰阳性菌数
|
||||
broad_spectrum=1 # 是否广谱(≥10个菌株)
|
||||
)
|
||||
```
|
||||
|
||||
结果解读:
|
||||
|
||||
## BroadSpectrumResult 字段说明表
|
||||
|
||||
| 字段名 | 数据类型 | 计算方法 | 含义说明 |
|
||||
|--------|----------|----------|---------|
|
||||
| `chem_id` | 字符串 | 输入的化合物标识符 | 化合物的唯一标识,如 "mol1"、"compound_001" 等 |
|
||||
| `apscore_total` | 浮点数 | `log(gmean(所有40个菌株的预测概率))` | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低 |
|
||||
| `apscore_gnegative` | 浮点数 | `log(gmean(革兰阴性菌株的预测概率))` | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性 |
|
||||
| `apscore_gpositive` | 浮点数 | `log(gmean(革兰阳性菌株的预测概率))` | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性 |
|
||||
| `ginhib_total` | 整数 | `sum(所有菌株的二值化预测)` | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40 |
|
||||
| `ginhib_gnegative` | 整数 | `sum(革兰阴性菌株的二值化预测)` | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20 |
|
||||
| `ginhib_gpositive` | 整数 | `sum(革兰阳性菌株的二值化预测)` | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20 |
|
||||
| `broad_spectrum` | 整数 (0/1) | `1 if ginhib_total >= 10 else 0` | 广谱抗菌标志:如果抑制菌株数 ≥ 10,判定为广谱抗菌药物(1),否则为窄谱(0) |
|
||||
|
||||
说明
|
||||
|
||||
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
|
||||
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,28 @@
|
||||
batch_size: 1000 # batch size
|
||||
warm_up: 10 # warm-up epochs
|
||||
epochs: 1000 # total number of epochs
|
||||
|
||||
load_model: None # resume training
|
||||
eval_every_n_epochs: 1 # validation frequency
|
||||
save_every_n_epochs: 5 # automatic model saving frequecy
|
||||
|
||||
fp16_precision: False # float precision 16 (i.e. True/False)
|
||||
init_lr: 0.0005 # initial learning rate for Adam
|
||||
weight_decay: 1e-5 # weight decay for Adam
|
||||
gpu: cuda:0 # training GPU
|
||||
|
||||
model_type: gin_concat # GNN backbone (i.e., gin/gcn)
|
||||
model:
|
||||
num_layer: 5 # number of graph conv layers
|
||||
emb_dim: 200 # embedding dimension in graph conv layers
|
||||
feat_dim: 8000 # output feature dimention
|
||||
drop_ratio: 0.0 # dropout ratio
|
||||
pool: add # readout pooling (i.e., mean/max/add)
|
||||
|
||||
dataset:
|
||||
num_workers: 50 # dataloader number of workers
|
||||
valid_size: 0.1 # ratio of validation data
|
||||
data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data
|
||||
|
||||
loss:
|
||||
l: 0.0001 # Lambda parameter
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user