新增功能: - 新增统一批量预测工具 utils/batch_predictor.py * 支持单进程/多进程并行模式 * 灵活的 GPU 配置和显存自动计算 * 自动临时文件管理和断点续传 * 完整的 CLI 参数支持(Click 框架) - 新增 Shell 脚本集合 scripts/ * run_parallel_predict.sh - 并行预测脚本 * run_single_predict.sh - 单进程预测脚本 * merge_results.sh - 结果合并脚本 性能优化: - 解决 CUDA + multiprocessing fork 死锁问题 * 使用 spawn 模式替代 fork * 文件描述符级别的输出重定向 - 优化预测性能 * XGBoost OpenMP 多线程(利用所有 CPU 核心) * 预加载模型减少重复加载 * 大批量处理降低函数调用开销 * 实际加速比:2-3x(12进程 vs 单进程) - 优化输出显示 * 抑制模型加载时的权重信息 * 只显示进度条和关键统计 * 临时文件自动保存到专门目录 文档更新: - README.md 新增"大规模并行预测"章节 - README.md 新增"性能优化说明"章节 - 添加详细的使用示例和参数说明 - 更新项目结构和版本信息 技术细节: - 每个模型实例约占用 2.5GB GPU 显存 - 显存计算公式:建议进程数 = GPU显存(GB) / 2.5 - GPU 瓶颈占比:MolE 表示生成 94% - 非 GIL 问题:计算密集任务在 C/CUDA 层 Breaking Changes: - 废弃旧的独立预测脚本,统一使用新工具 相关 Issue: 解决 #并行预测卡死问题 测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
1000 lines
29 KiB
Markdown
Executable File
1000 lines
29 KiB
Markdown
Executable File
# SIME - Structure-Informed Macrolide Expansion
|
||
|
||
SIME 是一个用于大环内酯类化合物结构扩展和抗菌活性预测的工具。
|
||
|
||
## 目录
|
||
|
||
- [原有功能](#原有功能)
|
||
- [MolE 抗菌活性预测](#mole-抗菌活性预测)
|
||
- [快速开始](#快速开始)
|
||
- [🚀 大规模并行预测(推荐)](#大规模并行预测推荐)
|
||
- [安装依赖](#安装依赖)
|
||
- [使用方法](#使用方法)
|
||
- [输出说明](#输出说明)
|
||
- [项目结构](#项目结构)
|
||
- [性能优化说明](#性能优化说明)
|
||
- [常见问题](#常见问题)
|
||
|
||
---
|
||
|
||
## 原有功能
|
||
|
||
SIME 提供大环内酯类化合物的结构设计和合成路径分析功能。
|
||
|
||
---
|
||
|
||
## MolE 抗菌活性预测
|
||
|
||
本工具集成了 MolE(Molecular Embeddings)模型,可以预测小分子的广谱抗菌活性。
|
||
|
||
### 快速开始
|
||
|
||
#### 使用 uv(推荐)
|
||
|
||
```bash
|
||
# 1. 创建虚拟环境(Python 3.12)
|
||
uv venv --python 3.12 --seed .venv
|
||
|
||
# 2. 激活环境
|
||
source .venv/bin/activate # Linux/Mac
|
||
# 或
|
||
.venv\Scripts\activate # Windows
|
||
|
||
# 3. 使用 uv 安装依赖
|
||
uv pip install -r requirements-mole.txt
|
||
|
||
# 4. 验证安装
|
||
python verify_setup.py
|
||
|
||
# 5. 运行预测
|
||
python utils/mole_predictor.py Data/fragment/Frags-Enamine-18M.csv
|
||
```
|
||
|
||
---
|
||
|
||
## 🚀 大规模并行预测(推荐)
|
||
|
||
对于大型数据集(如 18M 分子),我们提供了优化的并行预测工具。
|
||
|
||
### 核心工具:`utils/batch_predictor.py`
|
||
|
||
这是一个统一的批量预测脚本,支持:
|
||
- ✅ 单进程或多进程并行
|
||
- ✅ 灵活的 GPU 配置
|
||
- ✅ 自动临时文件管理
|
||
- ✅ 断点续传功能
|
||
|
||
### 快速使用
|
||
|
||
#### 1. 单进程预测(最稳定)
|
||
|
||
```bash
|
||
# 基本用法
|
||
pixi run python utils/batch_predictor.py -i input.csv -o output.csv
|
||
|
||
# 使用 shell 脚本
|
||
bash scripts/run_single_predict.sh input.csv output.csv cuda:0
|
||
```
|
||
|
||
#### 2. 多进程并行(推荐)
|
||
|
||
```bash
|
||
# 基本用法(4 个进程)
|
||
pixi run python utils/batch_predictor.py \
|
||
-i Data/fragment/Frags-Enamine-18M.csv \
|
||
-o output.csv \
|
||
-n 4
|
||
|
||
# 使用 shell 脚本(更方便)
|
||
bash scripts/run_parallel_predict.sh \
|
||
Data/fragment/Frags-Enamine-18M.csv \
|
||
output.csv \
|
||
4 \
|
||
cuda:0
|
||
```
|
||
|
||
### 显存计算
|
||
|
||
每个模型实例约占用 **2.5GB GPU 显存**,建议并行进程数计算公式:
|
||
|
||
```
|
||
建议进程数 = GPU显存(GB) / 2.5
|
||
```
|
||
|
||
示例:
|
||
- **12GB 显存** → 4 个进程
|
||
- **24GB 显存** → 9 个进程
|
||
- **32GB 显存** → 12 个进程
|
||
- **48GB 显存** → 19 个进程
|
||
|
||
### 完整参数说明
|
||
|
||
```bash
|
||
pixi run python utils/batch_predictor.py --help
|
||
```
|
||
|
||
主要参数:
|
||
- `-i, --input`: 输入 CSV 文件路径
|
||
- `-o, --output`: 输出 CSV 文件路径
|
||
- `-s, --smiles-column`: SMILES 列名(默认: smiles)
|
||
- `-d, --id-column`: ID 列名(默认: chem_id)
|
||
- `-g, --device`: GPU 设备(默认: cuda:0)
|
||
- `-n, --n-processes`: 并行进程数(默认: 1)
|
||
- `-b, --batch-size`: 批处理大小(默认: 1000)
|
||
- `--start-from`: 断点续传起始行
|
||
- `-m, --max-molecules`: 限制处理分子数
|
||
- `--temp-dir`: 临时文件目录
|
||
- `--verbose/--quiet`: 详细/安静模式
|
||
|
||
### 使用示例
|
||
|
||
```bash
|
||
# 1. 测试小数据集(前 1000 个分子)
|
||
pixi run python utils/batch_predictor.py \
|
||
-i Data/fragment/Frags-Enamine-18M.csv \
|
||
-o test_1k.csv \
|
||
-m 1000 \
|
||
--verbose
|
||
|
||
# 2. 并行预测(32GB 显存,12 个进程)
|
||
pixi run python utils/batch_predictor.py \
|
||
-i Data/fragment/Frags-Enamine-18M.csv \
|
||
-o predicted.csv \
|
||
-g cuda:0 \
|
||
-n 12 \
|
||
--verbose
|
||
|
||
# 3. 指定自定义列名和 GPU
|
||
pixi run python utils/batch_predictor.py \
|
||
-i data.csv \
|
||
-o output.csv \
|
||
-s SMILES \
|
||
-d compound_id \
|
||
-g cuda:1 \
|
||
-n 8
|
||
|
||
# 4. 断点续传(从第 100000 行开始)
|
||
pixi run python utils/batch_predictor.py \
|
||
-i Data/fragment/Frags-Enamine-18M.csv \
|
||
-o output.csv \
|
||
--start-from 100000 \
|
||
-n 4
|
||
|
||
# 5. 后台运行(使用 nohup)
|
||
nohup pixi run python utils/batch_predictor.py \
|
||
-i Data/fragment/Frags-Enamine-18M.csv \
|
||
-o output.csv \
|
||
-n 4 \
|
||
> prediction.log 2>&1 &
|
||
|
||
# 查看日志
|
||
tail -f prediction.log
|
||
```
|
||
|
||
### 性能预估
|
||
|
||
基于 256 CPU 核心 + RTX 5090 32GB 的测试:
|
||
|
||
| 分子数 | 单进程 | 4 进程并行 | 12 进程并行 |
|
||
|--------|--------|-----------|------------|
|
||
| 1,000 | ~30秒 | ~15秒 | ~12秒 |
|
||
| 10,000 | ~5分钟 | ~2分钟 | ~1.5分钟 |
|
||
| 100,000 | ~50分钟 | ~20分钟 | ~15分钟 |
|
||
| 18M | ~7天 | ~3天 | ~2天 |
|
||
|
||
**说明**:
|
||
- GPU 是瓶颈(MolE 表示生成占 94%)
|
||
- 多进程在单 GPU 上串行使用
|
||
- 实际加速比:2-3x(而非线性)
|
||
- XGBoost 已经使用 OpenMP 并行(256 CPU 核心)
|
||
|
||
### 临时文件管理
|
||
|
||
临时文件自动保存到:`{输入文件名}_temp/` 目录
|
||
|
||
```
|
||
Data/fragment/Frags-Enamine-18M_temp/
|
||
├── part_0.csv # 进程 0 的结果
|
||
├── part_1.csv # 进程 1 的结果
|
||
├── part_2.csv # 进程 2 的结果
|
||
├── part_3.csv # 进程 3 的结果
|
||
├── batch_10.csv # 每 10 批的临时保存
|
||
└── chunk_* # 其他临时文件
|
||
```
|
||
|
||
临时文件默认保留,可用于断点续传。删除临时文件:
|
||
|
||
```bash
|
||
rm -rf Data/fragment/Frags-Enamine-18M_temp/
|
||
```
|
||
|
||
---
|
||
|
||
#### 使用 pyproject.toml 配置(uv 推荐)
|
||
|
||
项目提供了两个环境配置:
|
||
|
||
1. **SIME 原始环境** - 用于大环内酯结构设计
|
||
|
||
```bash
|
||
# 使用 uv 创建默认环境
|
||
uv sync
|
||
```
|
||
|
||
2. **MolE 预测环境** - 用于抗菌活性预测
|
||
|
||
```bash
|
||
# 使用 uv 创建 MolE 环境
|
||
uv sync --extra mole
|
||
```
|
||
|
||
#### 使用 pixi 配置(conda 用户推荐)
|
||
|
||
如果你使用 conda 或需要更好的包管理,可以使用 pixi:
|
||
|
||
```bash
|
||
# 安装 pixi(如果还没有)
|
||
curl -fsSL https://pixi.sh/install.sh | bash
|
||
|
||
# 创建 SIME 原始环境
|
||
pixi install
|
||
|
||
# 创建 MolE 预测环境
|
||
pixi install -e mole
|
||
|
||
# 激活 MolE 环境
|
||
pixi shell -e mole
|
||
|
||
# 在 pixi 环境中运行预测
|
||
pixi run -e mole predict Data/fragment/test_100.csv
|
||
```
|
||
|
||
### 安装依赖
|
||
|
||
#### 方法 1: 使用 uv(推荐)
|
||
|
||
```bash
|
||
# 创建虚拟环境
|
||
uv venv --python 3.12 .venv
|
||
source .venv/bin/activate
|
||
|
||
# 安装依赖
|
||
uv pip install -r requirements-mole.txt
|
||
```
|
||
|
||
#### 方法 2: 使用 pixi
|
||
|
||
```bash
|
||
# 创建虚拟环境
|
||
pixi init
|
||
|
||
# 基础环境
|
||
pixi add python=3.12
|
||
|
||
# nvidia cuda工具链
|
||
pixi workspace channel add nvidia
|
||
pixi add nvidia::cuda-toolkit=12.8
|
||
|
||
# 科学计算 安装 pandas 会自动安装上 numpy
|
||
pixi add
|
||
|
||
# torch-geometric
|
||
pixi add conda-forge::pandas conda-forge::torch-geometric conda-forge::xgboost conda-forge::pyyaml conda-forge::rdkit conda-forge::pip conda-forge::click conda-forge::openpyxl
|
||
|
||
# PyTorch相关(指定通道)
|
||
# 1. 添加 pytorch 频道 conda 太旧改为使用 pypi
|
||
# pixi workspace channel add pytorch
|
||
# pixi add pytorch::pytorch=2.6 pytorch::pytorch-cuda=12.4
|
||
pixi add --pypi torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0
|
||
|
||
# 然后在 pixi.toml 中手动编辑为:
|
||
[pypi-dependencies]
|
||
torch = { version = "==2.8.0", index = "https://download.pytorch.org/whl/cu128" }
|
||
torchvision = { version = "==0.23.0", index = "https://download.pytorch.org/whl/cu128" }
|
||
torchaudio = { version = "==2.8.0", index = "https://download.pytorch.org/whl/cu128" }
|
||
|
||
# 安装依赖
|
||
pixi install
|
||
|
||
# 激活
|
||
pixi shell
|
||
```
|
||
|
||
不同机器使用配置方式:
|
||
|
||
```bash
|
||
# 在 Linux GPU 机器上安装和运行
|
||
pixi install --environment gpu
|
||
pixi run --environment gpu <your-task>
|
||
|
||
# 在 macOS 或 CPU 机器上安装和运行
|
||
pixi install --environment cpu
|
||
pixi run --environment cpu <your-task>
|
||
|
||
# 或使用默认环境(CPU)
|
||
pixi install
|
||
pixi run <your-task>
|
||
```
|
||
|
||
#### RDKit 安装建议
|
||
|
||
RDKit 推荐使用 conda 安装:
|
||
|
||
```bash
|
||
conda install -c conda-forge rdkit
|
||
```
|
||
|
||
### 使用方法
|
||
|
||
#### 1. 命令行使用
|
||
|
||
**基本用法:**
|
||
|
||
```bash
|
||
# 预测 CSV 文件(仅聚合结果)
|
||
python utils/mole_predictor.py input.csv
|
||
|
||
# 包含 40 种菌株的详细预测数据
|
||
python utils/mole_predictor.py input.csv --include-strain-predictions
|
||
|
||
# 指定输出路径
|
||
python utils/mole_predictor.py input.csv output.csv
|
||
|
||
# 自定义列名
|
||
python utils/mole_predictor.py input.csv output.csv \
|
||
--smiles-column SMILES \
|
||
--id-column compound_id
|
||
|
||
# 使用 GPU 加速
|
||
python utils/mole_predictor.py input.csv --device cuda:0
|
||
|
||
# 调整批次大小和工作进程
|
||
python utils/mole_predictor.py input.csv \
|
||
--batch-size 200 \
|
||
--n-workers 8
|
||
|
||
# 完整示例:包含菌株预测 + GPU 加速
|
||
python utils/mole_predictor.py input.csv output.csv \
|
||
--include-strain-predictions \
|
||
--device cuda:0 \
|
||
--batch-size 200
|
||
```
|
||
|
||
**查看所有选项:**
|
||
|
||
```bash
|
||
python utils/mole_predictor.py --help
|
||
```
|
||
|
||
**预测项目数据:**
|
||
|
||
```bash
|
||
# 预测 Frags-Enamine-18M.csv
|
||
# 创建测试文件(前 1001 行,包含表头)
|
||
head -1001 Data/fragment/Frags-Enamine-18M.csv > Data/fragment/test_1000.csv
|
||
|
||
# 测试命令 - 保守参数
|
||
nohup pixi run python utils/mole_predictor.py \
|
||
Data/fragment/test_1000.csv \
|
||
Data/fragment/test_1000_predicted.csv \
|
||
--device cuda:0 \
|
||
--batch-size 100 \
|
||
--n-workers 8 \
|
||
> Data/fragment/mole_test_1000.log 2>&1 &
|
||
|
||
# 查看日志
|
||
tail -f Data/fragment/mole_test_1000.log
|
||
|
||
# 预测 GDB11-27M.csv
|
||
python utils/mole_predictor.py Data/fragment/GDB11-27M.csv
|
||
```
|
||
|
||
#### 2. Python API 使用
|
||
|
||
**预测单个文件:**
|
||
|
||
```python
|
||
from utils.mole_predictor import predict_csv_file
|
||
|
||
# 基本使用(仅聚合结果)
|
||
df_result = predict_csv_file(
|
||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||
output_path="results/predictions.csv",
|
||
smiles_column="smiles",
|
||
batch_size=100,
|
||
device="auto"
|
||
)
|
||
|
||
# 包含 40 种菌株的详细预测数据
|
||
df_result_with_strains = predict_csv_file(
|
||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||
output_path="results/predictions_with_strains.csv",
|
||
smiles_column="smiles",
|
||
batch_size=100,
|
||
device="auto",
|
||
include_strain_predictions=True # 启用菌株级别预测
|
||
)
|
||
|
||
# 查看结果
|
||
print(f"总分子数: {len(df_result)}")
|
||
print(f"广谱分子数: {df_result['broad_spectrum'].sum()}")
|
||
|
||
# 如果包含菌株预测,数据行数会增加(每个分子 40 行)
|
||
if 'strain_name' in df_result_with_strains.columns:
|
||
print(f"包含菌株预测的总行数: {len(df_result_with_strains)}")
|
||
print(f"菌株数: {df_result_with_strains['strain_name'].nunique()}")
|
||
```
|
||
|
||
**批量预测多个文件:**
|
||
|
||
```python
|
||
from utils.mole_predictor import predict_multiple_files
|
||
|
||
input_files = [
|
||
"Data/fragment/Frags-Enamine-18M.csv",
|
||
"Data/fragment/GDB11-27M.csv"
|
||
]
|
||
|
||
results = predict_multiple_files(
|
||
input_paths=input_files,
|
||
output_dir="results/",
|
||
smiles_column="smiles",
|
||
batch_size=100,
|
||
device="auto"
|
||
)
|
||
```
|
||
|
||
**直接使用预测器:**
|
||
|
||
```python
|
||
from models import (
|
||
ParallelBroadSpectrumPredictor,
|
||
PredictionConfig,
|
||
MoleculeInput
|
||
)
|
||
|
||
# 创建配置
|
||
config = PredictionConfig(
|
||
batch_size=100,
|
||
device="auto" # 或 "cpu", "cuda:0"
|
||
)
|
||
|
||
# 创建预测器
|
||
predictor = ParallelBroadSpectrumPredictor(config)
|
||
|
||
# 预测单个分子(仅聚合结果)
|
||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||
result = predictor.predict_single(molecule)
|
||
|
||
print(f"化合物ID: {result.chem_id}")
|
||
print(f"广谱抗菌: {result.broad_spectrum}")
|
||
print(f"抗菌得分: {result.apscore_total:.3f}")
|
||
print(f"抑制菌株数: {result.ginhib_total}")
|
||
|
||
# 批量预测(包含 40 种菌株的详细预测)
|
||
smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
|
||
chem_ids = ["ethanol", "benzene", "acetic_acid"]
|
||
|
||
results = predictor.predict_batch(
|
||
[MoleculeInput(smiles=s, chem_id=c) for s, c in zip(smiles_list, chem_ids)],
|
||
include_strain_predictions=True # 启用菌株级别预测
|
||
)
|
||
|
||
for r in results:
|
||
print(f"\n{r.chem_id}:")
|
||
print(f" 广谱抗菌: {r.broad_spectrum}")
|
||
print(f" 抗菌得分: {r.apscore_total:.3f}")
|
||
print(f" 抑制菌株数: {r.ginhib_total}")
|
||
|
||
# 访问菌株级别预测数据(DataFrame 格式)
|
||
if r.strain_predictions is not None:
|
||
print(f" 菌株预测数据形状: {r.strain_predictions.shape}")
|
||
print(f" 示例菌株预测:")
|
||
print(r.strain_predictions.head(3))
|
||
|
||
# 强化学习场景:提取特定菌株的预测概率
|
||
strain_probs = r.strain_predictions['antimicrobial_predictive_probability'].values
|
||
print(f" 预测概率向量形状: {strain_probs.shape}") # (40,)
|
||
|
||
# 或转换为类型安全的列表
|
||
strain_list = r.to_strain_predictions_list()
|
||
print(f" 第一个菌株: {strain_list[0].strain_name}")
|
||
print(f" 预测概率: {strain_list[0].antimicrobial_predictive_probability:.6f}")
|
||
```
|
||
|
||
### 输出说明
|
||
|
||
#### 1. 聚合结果(默认输出)
|
||
|
||
预测结果会添加以下 7 个新列:
|
||
|
||
| 列名 | 类型 | 说明 |
|
||
|------|------|------|
|
||
| `apscore_total` | float | 总体抗菌潜力分数(对数尺度,值越大抗菌活性越强) |
|
||
| `apscore_gnegative` | float | 革兰阴性菌抗菌潜力分数 |
|
||
| `apscore_gpositive` | float | 革兰阳性菌抗菌潜力分数 |
|
||
| `ginhib_total` | int | 被抑制的菌株总数(0-40) |
|
||
| `ginhib_gnegative` | int | 被抑制的革兰阴性菌株数 |
|
||
| `ginhib_gpositive` | int | 被抑制的革兰阳性菌株数 |
|
||
| `broad_spectrum` | int | 是否为广谱抗菌(1=是,0=否) |
|
||
|
||
**输出示例**(每个分子一行):
|
||
|
||
```csv
|
||
smiles,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum
|
||
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0
|
||
```
|
||
|
||
#### 2. 菌株级别预测(使用 `--include-strain-predictions` 时)
|
||
|
||
启用菌株级别预测后,输出会包含以下额外列(每个分子 40 行):
|
||
|
||
| 列名 | 类型 | 说明 |
|
||
|------|------|------|
|
||
| `pred_id` | str | 预测ID,格式为 `chem_id:strain_name` |
|
||
| `strain_name` | str | 菌株名称(40 种菌株之一) |
|
||
| `antimicrobial_predictive_probability` | float | XGBoost 预测的抗菌概率(0-1) |
|
||
| `no_growth_probability` | float | 预测不抑制的概率(1 - antimicrobial_predictive_probability) |
|
||
| `growth_inhibition` | int | 二值化抑制结果(0=不抑制,1=抑制) |
|
||
| `gram_stain` | str | 革兰染色类型(negative 或 positive) |
|
||
|
||
**输出示例**(每个分子 40 行,对应 40 个菌株):
|
||
|
||
```csv
|
||
smiles,chem_id,apscore_total,...,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain
|
||
CCO,mol1,-9.93,...,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative
|
||
CCO,mol1,-9.93,...,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative
|
||
...(共 40 行,对应 40 个菌株)
|
||
```
|
||
|
||
**数据使用场景**:
|
||
|
||
1. **强化学习状态表示**:
|
||
```python
|
||
# 提取预测概率作为状态向量
|
||
state_vector = result.strain_predictions['antimicrobial_predictive_probability'].values
|
||
# 形状: (40,) - 可直接用于 RL 环境
|
||
```
|
||
|
||
2. **筛选特定菌株**:
|
||
```python
|
||
# 筛选革兰阴性菌
|
||
gram_negative = result.strain_predictions[
|
||
result.strain_predictions['gram_stain'] == 'negative'
|
||
]
|
||
```
|
||
|
||
3. **可视化分析**:
|
||
```python
|
||
import matplotlib.pyplot as plt
|
||
df = result.strain_predictions
|
||
df.plot(x='strain_name', y='antimicrobial_predictive_probability', kind='bar')
|
||
```
|
||
|
||
#### 40 种测试菌株列表
|
||
|
||
预测涵盖以下 40 种人类肠道菌株:
|
||
|
||
**革兰阴性菌(23 种)**:
|
||
- Akkermansia muciniphila (NT5021)
|
||
- Bacteroides 属: caccae, fragilis (ET/NT), ovatus, thetaiotaomicron, uniformis, vulgatus, xylanisolvens (8 种)
|
||
- Escherichia coli 各亚型 (4 种)
|
||
- Klebsiella pneumoniae (NT5049)
|
||
- 其他肠道革兰阴性菌
|
||
|
||
**革兰阳性菌(17 种)**:
|
||
- Bifidobacterium 属 (3 种)
|
||
- Clostridium 属 (5 种)
|
||
- Enterococcus 属 (2 种)
|
||
- Lactobacillus 属 (3 种)
|
||
- Streptococcus 属 (2 种)
|
||
- 其他肠道革兰阳性菌
|
||
|
||
完整菌株列表详见 `Data/mole/README.md`。
|
||
|
||
#### 广谱抗菌判断标准
|
||
|
||
默认情况下,如果一个分子能抑制 **10 个或更多菌株** (`ginhib_total >= 10`),则被认为是广谱抗菌分子。
|
||
|
||
#### 输出文件位置
|
||
|
||
默认情况下,输出文件会添加 `_predicted` 后缀:
|
||
|
||
- 输入: `Data/fragment/Frags-Enamine-18M.csv`
|
||
- 输出: `Data/fragment/Frags-Enamine-18M_predicted.csv`
|
||
- 输出(含菌株): `Data/fragment/Frags-Enamine-18M_predicted.csv`(每个分子 40 行)
|
||
|
||
#### 数据量说明
|
||
|
||
- **仅聚合结果**:输出行数 = 输入分子数
|
||
- **包含菌株预测**:输出行数 = 输入分子数 × 40
|
||
|
||
## 抗菌预测模型输出格式字段解释
|
||
|
||
### 完整输出字段解释表
|
||
|
||
#### 基础信息
|
||
|
||
| 字段名 | 数据类型 | 来源 | 计算方法 | 含义说明 |
|
||
|--------|--------|------|--------|---------|
|
||
| SMILES | 字符串 | 输入数据 | 直接复制 | 分子的 SMILES 结构表示 |
|
||
| chem_id | 字符串 | 输入数据 | 直接复制或自动生成 | 化合物的唯一标识符(如 "mol1") |
|
||
|
||
#### 聚合预测结果(每个分子一组值)
|
||
|
||
| 字段名 | 数据类型 | 来源 | 计算方法 | 含义说明 |
|
||
|--------|--------|------|--------|---------|
|
||
| apscore_total | 浮点数 | 聚合计算 | log(gmean(所有40个菌株的预测概率)) | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强,负值表示整体抑制概率较低 |
|
||
| apscore_gnegative | 浮点数 | 聚合计算 | log(gmean(革兰阴性菌株的预测概率)) | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株(23种)计算的抗菌分数 |
|
||
| apscore_gpositive | 浮点数 | 聚合计算 | log(gmean(革兰阳性菌株的预测概率)) | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株(17种)计算的抗菌分数 |
|
||
| ginhib_total | 整数 | 聚合计算 | sum(所有菌株的二值化预测) | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量),范围 0-40 |
|
||
| ginhib_gnegative | 整数 | 聚合计算 | sum(革兰阴性菌株的二值化预测) | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量,范围 0-23 |
|
||
| ginhib_gpositive | 整数 | 聚合计算 | sum(革兰阳性菌株的二值化预测) | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量,范围 0-17 |
|
||
| broad_spectrum | 整数 (0/1) | 聚合计算 | 1 if ginhib_total >= 10 else 0 | 广谱抗菌标志:如果抑制菌株数 ≥ 10,则判定为广谱抗菌药物(1),否则为 0 |
|
||
|
||
#### 菌株级别预测结果(每个分子 40 行,每行对应一个菌株)
|
||
|
||
| 字段名 | 数据类型 | 来源 | 计算方法 | 含义说明 |
|
||
|--------|--------|------|--------|---------|
|
||
| pred_id | 字符串 | 组合生成 | chem_id + ":" + strain_name | 预测组合ID:格式为 "化合物ID:菌株名称",如 "mol1:Akkermansia muciniphila (NT5021)" |
|
||
| strain_name | 字符串 | 菌株元数据 | 从 40 个菌株列表中提取 | 菌株名称:包含菌株学名和 NT 编号,如 "Akkermansia muciniphila (NT5021)" |
|
||
| antimicrobial_predictive_probability | 浮点数 | XGBoost 预测 | model.predict_proba(X)[:, 1] | 抗菌预测概率:XGBoost 模型预测该化合物抑制该菌株生长的概率,范围 0-1。这是模型的原始输出概率 |
|
||
| no_growth_probability | 浮点数 | XGBoost 预测 | model.predict_proba(X)[:, 0] | 不抑制概率:预测该化合物不抑制该菌株生长的概率,等于 1 - antimicrobial_predictive_probability |
|
||
| growth_inhibition | 整数 (0/1) | 阈值二值化 | 1 if antimicrobial_predictive_probability >= 0.04374 else 0 | 生长抑制标签:二值化的抑制结果。1 表示预测抑制,0 表示预测不抑制。阈值 0.04374 是通过验证集优化得到的 |
|
||
| gram_stain | 字符串 | 菌株元数据 | 从 strain_info_SF2.xlsx 中查找 | 革兰染色类型:该菌株的革兰染色分类,值为 "negative"(革兰阴性)或 "positive"(革兰阳性) |
|
||
|
||
---
|
||
|
||
### 数据结构说明
|
||
|
||
#### 输出格式特点
|
||
|
||
- **前 8 列**(SMILES 到 broad_spectrum):每个分子的聚合结果,这些值在该分子的 40 行中保持不变
|
||
- **后 6 列**(pred_id 到 gram_stain):每个分子-菌株对的具体预测,每行对应不同的菌株
|
||
|
||
#### 示例数据
|
||
|
||
```csv
|
||
SMILES,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain
|
||
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative
|
||
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative
|
||
...(共 40 行,前 8 列相同,后 6 列不同)
|
||
```
|
||
|
||
---
|
||
|
||
### 关键说明
|
||
|
||
| 项目 | 说明 |
|
||
|------|------|
|
||
| 数据量 | 每个输入分子会生成 40 行输出(对应 40 个菌株),因此总行数 = 输入分子数 × 40 |
|
||
| 阈值优化 | 默认阈值 0.04374 是通过最大化验证集 F1 分数得到的最优值 |
|
||
| 革兰染色分布 | 40 个菌株中,23 个为革兰阴性菌,17 个为革兰阳性菌 |
|
||
| 概率解释 | antimicrobial_predictive_probability 越接近 1,表示模型越确信该化合物会抑制该菌株 |
|
||
| 应用场景 | 这种格式特别适合强化学习场景,可以直接提取 40 维的预测概率向量作为状态表示 |
|
||
|
||
---
|
||
|
||
## 项目结构
|
||
|
||
```
|
||
SIME/
|
||
├── models/ # MolE 预测模型
|
||
│ ├── __init__.py
|
||
│ ├── broad_spectrum_predictor.py # 核心预测器
|
||
│ ├── dataset_representation.py # 数据集表示
|
||
│ ├── ginet_concat.py # GIN 神经网络
|
||
│ └── mole_representation.py # MolE 表示生成
|
||
│
|
||
├── utils/
|
||
│ ├── batch_predictor.py # 🆕 统一批量预测工具
|
||
│ ├── mole_predictor.py # 原预测工具(保留兼容性)
|
||
│ └── ... (其他工具)
|
||
│
|
||
├── scripts/ # 🆕 Shell 脚本
|
||
│ ├── run_parallel_predict.sh # 并行预测脚本
|
||
│ ├── run_single_predict.sh # 单进程预测脚本
|
||
│ └── merge_results.sh # 结果合并脚本
|
||
│
|
||
├── Data/
|
||
│ └── fragment/ # 待预测数据
|
||
│ ├── Frags-Enamine-18M.csv
|
||
│ └── GDB11-27M.csv
|
||
│
|
||
├── pyproject.toml # uv 项目配置
|
||
├── requirements.txt # SIME 原始依赖
|
||
├── requirements-mole.txt # MolE 预测依赖
|
||
│
|
||
├── verify_setup.py # 设置验证工具
|
||
├── check_mole_dependencies.py # 依赖检查工具
|
||
└── test_mole_predictor.py # 功能测试
|
||
```
|
||
|
||
---
|
||
|
||
## 依赖说明
|
||
|
||
### SIME 原始依赖 (requirements.txt)
|
||
|
||
用于大环内酯结构设计功能。
|
||
|
||
### MolE 预测依赖 (requirements-mole.txt)
|
||
|
||
用于抗菌活性预测,主要包括:
|
||
|
||
- **深度学习**: torch, torch-geometric
|
||
- **科学计算**: numpy, pandas, scipy
|
||
- **机器学习**: scikit-learn, xgboost
|
||
- **化学信息**: rdkit
|
||
- **其他**: openpyxl, pyyaml, click
|
||
|
||
---
|
||
|
||
## 验证和测试
|
||
|
||
### 验证安装
|
||
|
||
```bash
|
||
# 检查 Python 依赖
|
||
python verify_setup.py
|
||
|
||
# 检查模型文件
|
||
python check_mole_dependencies.py
|
||
```
|
||
|
||
### 运行测试
|
||
|
||
```bash
|
||
# 功能测试(使用小规模测试数据)
|
||
python test_mole_predictor.py
|
||
```
|
||
|
||
---
|
||
|
||
## 性能优化说明
|
||
|
||
### 为什么并行速度没有线性提升?
|
||
|
||
**问题根源**:
|
||
- ✅ **已解决 CUDA fork 死锁**:使用 `spawn` 模式而非 `fork`
|
||
- ✅ **单进程 + OpenMP**:XGBoost 已使用所有 CPU 核心(256核)
|
||
- ⚠️ **GPU 是瓶颈**:MolE 表示生成占 94% 时间
|
||
|
||
**技术细节**:
|
||
|
||
1. **单 GPU 串行化**
|
||
- 多进程共享一个 GPU
|
||
- GPU 只能串行处理
|
||
- 预期加速:2-3x(非线性)
|
||
|
||
2. **性能分布**(100个分子)
|
||
```
|
||
MolE representation: 93.6% ← GPU瓶颈
|
||
Feature preparation: 0.7%
|
||
XGBoost prediction: 1.9% ← 已经很快
|
||
Post-processing: 3.8%
|
||
```
|
||
|
||
3. **不是 GIL 问题**
|
||
- XGBoost: C++ OpenMP(不受 GIL 影响)
|
||
- PyTorch CUDA: 不受 GIL 影响
|
||
- NumPy/Pandas: C 实现(不受 GIL 影响)
|
||
|
||
### 如何进一步提速?
|
||
|
||
**方案 1:多 GPU 并行**(最有效)
|
||
|
||
如果有多个 GPU,可以真正并行:
|
||
|
||
```bash
|
||
# GPU 0: 处理前 1/4
|
||
pixi run python utils/batch_predictor.py \
|
||
-i data.csv -o output_0.csv \
|
||
-g cuda:0 -n 4 --start-from 0 --max-molecules 4500000 &
|
||
|
||
# GPU 1: 处理第 2/4
|
||
pixi run python utils/batch_predictor.py \
|
||
-i data.csv -o output_1.csv \
|
||
-g cuda:1 -n 4 --start-from 4500000 --max-molecules 4500000 &
|
||
|
||
# GPU 2, 3 类似...
|
||
```
|
||
|
||
**方案 2:增加并行进程数**
|
||
|
||
在显存允许的情况下:
|
||
|
||
```bash
|
||
# 32GB 显存 → 12 个进程
|
||
pixi run python utils/batch_predictor.py \
|
||
-i data.csv -o output.csv -n 12
|
||
```
|
||
|
||
**方案 3:优化 MolE 推理**(需要修改代码)
|
||
|
||
- 使用 TorchScript JIT 编译
|
||
- 使用 FP16 混合精度
|
||
- 缓存常见分子的 MolE 表示
|
||
|
||
### 系统要求
|
||
|
||
**推荐配置**:
|
||
- CPU: 多核心(推荐 64+ 核心)
|
||
- GPU: NVIDIA GPU with CUDA(32GB+ 显存)
|
||
- 内存: 64GB+ RAM
|
||
- 磁盘: 100GB+ 可用空间(用于临时文件)
|
||
|
||
**最低配置**:
|
||
- CPU: 8 核心
|
||
- GPU: 可选(CPU 模式会很慢)
|
||
- 内存: 16GB RAM
|
||
- 磁盘: 20GB 可用空间
|
||
|
||
---
|
||
|
||
## 常见问题
|
||
|
||
### Q1: 如何处理大文件?
|
||
|
||
**方案 1:** 增加批次大小和工作进程数
|
||
|
||
```bash
|
||
python utils/mole_predictor.py large_file.csv \
|
||
--batch-size 500 \
|
||
--n-workers 8
|
||
```
|
||
|
||
**方案 2:** 先提取部分数据测试
|
||
|
||
```bash
|
||
# 提取前 1000 行
|
||
head -1001 large_file.csv > test_1000.csv
|
||
python utils/mole_predictor.py test_1000.csv
|
||
```
|
||
|
||
### Q2: 如何只使用 CPU?
|
||
|
||
```bash
|
||
python utils/mole_predictor.py input.csv --device cpu
|
||
```
|
||
|
||
### Q3: 列名大小写问题?
|
||
|
||
工具会自动进行大小写不敏感的列名匹配,所以 `SMILES`、`smiles`、`Smiles` 都可以识别。
|
||
|
||
### Q4: ModuleNotFoundError 错误?
|
||
|
||
确保已安装依赖:
|
||
|
||
```bash
|
||
uv pip install -r requirements-mole.txt
|
||
```
|
||
|
||
对于 RDKit,推荐使用 conda:
|
||
|
||
```bash
|
||
conda install -c conda-forge rdkit
|
||
```
|
||
|
||
### Q5: 如何自定义模型路径?
|
||
|
||
```python
|
||
from models import PredictionConfig, ParallelBroadSpectrumPredictor
|
||
|
||
config = PredictionConfig(
|
||
xgboost_model_path="/path/to/model.pkl",
|
||
mole_model_path="/path/to/mole_model",
|
||
strain_categories_path="/path/to/strain_data.tsv.gz",
|
||
gram_info_path="/path/to/gram_info.xlsx",
|
||
app_threshold=0.044,
|
||
min_nkill=10,
|
||
batch_size=100,
|
||
device="auto"
|
||
)
|
||
|
||
predictor = ParallelBroadSpectrumPredictor(config)
|
||
```
|
||
|
||
### Q6: GPU 内存不足?
|
||
|
||
减小批次大小:
|
||
|
||
```bash
|
||
python utils/mole_predictor.py input.csv --batch-size 50
|
||
```
|
||
|
||
### Q7: 模型文件在哪里?
|
||
|
||
模型文件位于相邻的 `mole_broad_spectrum_parallel` 项目中:
|
||
|
||
```
|
||
../mole_broad_spectrum_parallel/
|
||
├── pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
|
||
│ ├── config.yaml
|
||
│ └── model.pth
|
||
├── data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl
|
||
└── ...
|
||
```
|
||
|
||
运行 `python check_mole_dependencies.py` 检查文件是否存在。
|
||
|
||
---
|
||
|
||
## 性能建议
|
||
|
||
- **使用 GPU**: 设置 `--device cuda:0` 可大幅加速(需要 CUDA)
|
||
- **调整批次**: 较大的批次(100-500)通常更快
|
||
- **多进程**: 使用 `--n-workers` 指定工作进程数
|
||
- **首次加载**: 首次运行需要加载模型(~30秒),后续会更快
|
||
|
||
### 性能参考
|
||
|
||
| 分子数量 | CPU (8核) | GPU (CUDA) |
|
||
|---------|----------|------------|
|
||
| 100 | ~30秒 | ~10秒 |
|
||
| 1,000 | ~5分钟 | ~1分钟 |
|
||
| 10,000 | ~50分钟 | ~8分钟 |
|
||
|
||
---
|
||
|
||
## 系统要求
|
||
|
||
- **Python**: 3.7 或更高版本(推荐 3.12)
|
||
- **内存**: 最低 8 GB RAM
|
||
- **存储**: 至少 2 GB 可用空间
|
||
- **GPU**: 可选,但强烈推荐(需要 CUDA 支持)
|
||
|
||
---
|
||
|
||
## 技术支持
|
||
|
||
如有问题:
|
||
|
||
1. 查看验证结果: `python verify_setup.py`
|
||
2. 检查模型文件: `python check_mole_dependencies.py`
|
||
3. 运行功能测试: `python test_mole_predictor.py`
|
||
|
||
---
|
||
|
||
## 许可
|
||
|
||
详见 LICENSE 文件。
|
||
|
||
## 引用
|
||
|
||
如果使用本工具,请引用相关论文。
|
||
|
||
---
|
||
|
||
## 更新日志
|
||
|
||
### v2.0.0 (2025-10-17) - 并行预测重大更新
|
||
|
||
**新增功能**:
|
||
- ✅ 统一的批量预测工具 `utils/batch_predictor.py`(支持单进程/多进程)
|
||
- ✅ Shell 脚本集合(`scripts/` 目录)
|
||
- ✅ 自动临时文件管理和断点续传
|
||
- ✅ 显存自动计算(GPU显存(GB) / 2.5)
|
||
- ✅ 抑制模型加载输出,只显示关键信息
|
||
|
||
**性能改进**:
|
||
- ✅ 解决 CUDA + multiprocessing fork 死锁问题
|
||
- ✅ XGBoost OpenMP 多线程(利用所有 CPU 核心)
|
||
- ✅ 实际加速比:2-3x(12进程 vs 单进程)
|
||
|
||
**使用方式**:
|
||
```bash
|
||
# 并行预测(推荐)
|
||
pixi run python utils/batch_predictor.py -i input.csv -o output.csv -n 12
|
||
|
||
# 或使用 Shell 脚本
|
||
bash scripts/run_parallel_predict.sh input.csv output.csv 12 cuda:0
|
||
```
|
||
|
||
查看完整帮助:`pixi run python utils/batch_predictor.py --help`
|
||
|
||
---
|
||
|
||
**更新日期**: 2025-10-17
|
||
**版本**: 2.0.0
|