1. 代码修改
models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
This commit is contained in:
142
README.md
142
README.md
@@ -153,9 +153,12 @@ conda install -c conda-forge rdkit
|
||||
**基本用法:**
|
||||
|
||||
```bash
|
||||
# 预测 CSV 文件
|
||||
# 预测 CSV 文件(仅聚合结果)
|
||||
python utils/mole_predictor.py input.csv
|
||||
|
||||
# 包含 40 种菌株的详细预测数据
|
||||
python utils/mole_predictor.py input.csv --include-strain-predictions
|
||||
|
||||
# 指定输出路径
|
||||
python utils/mole_predictor.py input.csv output.csv
|
||||
|
||||
@@ -171,6 +174,12 @@ python utils/mole_predictor.py input.csv --device cuda:0
|
||||
python utils/mole_predictor.py input.csv \
|
||||
--batch-size 200 \
|
||||
--n-workers 8
|
||||
|
||||
# 完整示例:包含菌株预测 + GPU 加速
|
||||
python utils/mole_predictor.py input.csv output.csv \
|
||||
--include-strain-predictions \
|
||||
--device cuda:0 \
|
||||
--batch-size 200
|
||||
```
|
||||
|
||||
**查看所有选项:**
|
||||
@@ -196,7 +205,7 @@ python utils/mole_predictor.py Data/fragment/GDB11-27M.csv
|
||||
```python
|
||||
from utils.mole_predictor import predict_csv_file
|
||||
|
||||
# 基本使用
|
||||
# 基本使用(仅聚合结果)
|
||||
df_result = predict_csv_file(
|
||||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||||
output_path="results/predictions.csv",
|
||||
@@ -205,9 +214,24 @@ df_result = predict_csv_file(
|
||||
device="auto"
|
||||
)
|
||||
|
||||
# 包含 40 种菌株的详细预测数据
|
||||
df_result_with_strains = predict_csv_file(
|
||||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||||
output_path="results/predictions_with_strains.csv",
|
||||
smiles_column="smiles",
|
||||
batch_size=100,
|
||||
device="auto",
|
||||
include_strain_predictions=True # 启用菌株级别预测
|
||||
)
|
||||
|
||||
# 查看结果
|
||||
print(f"总分子数: {len(df_result)}")
|
||||
print(f"广谱分子数: {df_result['broad_spectrum'].sum()}")
|
||||
|
||||
# 如果包含菌株预测,数据行数会增加(每个分子 40 行)
|
||||
if 'strain_name' in df_result_with_strains.columns:
|
||||
print(f"包含菌株预测的总行数: {len(df_result_with_strains)}")
|
||||
print(f"菌株数: {df_result_with_strains['strain_name'].nunique()}")
|
||||
```
|
||||
|
||||
**批量预测多个文件:**
|
||||
@@ -247,7 +271,7 @@ config = PredictionConfig(
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 预测单个分子
|
||||
# 预测单个分子(仅聚合结果)
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
result = predictor.predict_single(molecule)
|
||||
|
||||
@@ -256,19 +280,41 @@ print(f"广谱抗菌: {result.broad_spectrum}")
|
||||
print(f"抗菌得分: {result.apscore_total:.3f}")
|
||||
print(f"抑制菌株数: {result.ginhib_total}")
|
||||
|
||||
# 批量预测
|
||||
# 批量预测(包含 40 种菌株的详细预测)
|
||||
smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
|
||||
chem_ids = ["ethanol", "benzene", "acetic_acid"]
|
||||
|
||||
results = predictor.predict_from_smiles(smiles_list, chem_ids)
|
||||
results = predictor.predict_batch(
|
||||
[MoleculeInput(smiles=s, chem_id=c) for s, c in zip(smiles_list, chem_ids)],
|
||||
include_strain_predictions=True # 启用菌株级别预测
|
||||
)
|
||||
|
||||
for r in results:
|
||||
print(f"{r.chem_id}: broad_spectrum={r.broad_spectrum}, "
|
||||
f"apscore={r.apscore_total:.3f}")
|
||||
print(f"\n{r.chem_id}:")
|
||||
print(f" 广谱抗菌: {r.broad_spectrum}")
|
||||
print(f" 抗菌得分: {r.apscore_total:.3f}")
|
||||
print(f" 抑制菌株数: {r.ginhib_total}")
|
||||
|
||||
# 访问菌株级别预测数据(DataFrame 格式)
|
||||
if r.strain_predictions is not None:
|
||||
print(f" 菌株预测数据形状: {r.strain_predictions.shape}")
|
||||
print(f" 示例菌株预测:")
|
||||
print(r.strain_predictions.head(3))
|
||||
|
||||
# 强化学习场景:提取特定菌株的预测概率
|
||||
strain_probs = r.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
print(f" 预测概率向量形状: {strain_probs.shape}") # (40,)
|
||||
|
||||
# 或转换为类型安全的列表
|
||||
strain_list = r.to_strain_predictions_list()
|
||||
print(f" 第一个菌株: {strain_list[0].strain_name}")
|
||||
print(f" 预测概率: {strain_list[0].antimicrobial_predictive_probability:.6f}")
|
||||
```
|
||||
|
||||
### 输出说明
|
||||
|
||||
#### 1. 聚合结果(默认输出)
|
||||
|
||||
预测结果会添加以下 7 个新列:
|
||||
|
||||
| 列名 | 类型 | 说明 |
|
||||
@@ -276,11 +322,85 @@ for r in results:
|
||||
| `apscore_total` | float | 总体抗菌潜力分数(对数尺度,值越大抗菌活性越强) |
|
||||
| `apscore_gnegative` | float | 革兰阴性菌抗菌潜力分数 |
|
||||
| `apscore_gpositive` | float | 革兰阳性菌抗菌潜力分数 |
|
||||
| `ginhib_total` | int | 被抑制的菌株总数 |
|
||||
| `ginhib_total` | int | 被抑制的菌株总数(0-40) |
|
||||
| `ginhib_gnegative` | int | 被抑制的革兰阴性菌株数 |
|
||||
| `ginhib_gpositive` | int | 被抑制的革兰阳性菌株数 |
|
||||
| `broad_spectrum` | int | 是否为广谱抗菌(1=是,0=否) |
|
||||
|
||||
**输出示例**(每个分子一行):
|
||||
|
||||
```csv
|
||||
smiles,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum
|
||||
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0
|
||||
```
|
||||
|
||||
#### 2. 菌株级别预测(使用 `--include-strain-predictions` 时)
|
||||
|
||||
启用菌株级别预测后,输出会包含以下额外列(每个分子 40 行):
|
||||
|
||||
| 列名 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `pred_id` | str | 预测ID,格式为 `chem_id:strain_name` |
|
||||
| `strain_name` | str | 菌株名称(40 种菌株之一) |
|
||||
| `antimicrobial_predictive_probability` | float | XGBoost 预测的抗菌概率(0-1) |
|
||||
| `no_growth_probability` | float | 预测不抑制的概率(1 - antimicrobial_predictive_probability) |
|
||||
| `growth_inhibition` | int | 二值化抑制结果(0=不抑制,1=抑制) |
|
||||
| `gram_stain` | str | 革兰染色类型(negative 或 positive) |
|
||||
|
||||
**输出示例**(每个分子 40 行,对应 40 个菌株):
|
||||
|
||||
```csv
|
||||
smiles,chem_id,apscore_total,...,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain
|
||||
CCO,mol1,-9.93,...,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative
|
||||
CCO,mol1,-9.93,...,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative
|
||||
...(共 40 行,对应 40 个菌株)
|
||||
```
|
||||
|
||||
**数据使用场景**:
|
||||
|
||||
1. **强化学习状态表示**:
|
||||
```python
|
||||
# 提取预测概率作为状态向量
|
||||
state_vector = result.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
# 形状: (40,) - 可直接用于 RL 环境
|
||||
```
|
||||
|
||||
2. **筛选特定菌株**:
|
||||
```python
|
||||
# 筛选革兰阴性菌
|
||||
gram_negative = result.strain_predictions[
|
||||
result.strain_predictions['gram_stain'] == 'negative'
|
||||
]
|
||||
```
|
||||
|
||||
3. **可视化分析**:
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
df = result.strain_predictions
|
||||
df.plot(x='strain_name', y='antimicrobial_predictive_probability', kind='bar')
|
||||
```
|
||||
|
||||
#### 40 种测试菌株列表
|
||||
|
||||
预测涵盖以下 40 种人类肠道菌株:
|
||||
|
||||
**革兰阴性菌(23 种)**:
|
||||
- Akkermansia muciniphila (NT5021)
|
||||
- Bacteroides 属: caccae, fragilis (ET/NT), ovatus, thetaiotaomicron, uniformis, vulgatus, xylanisolvens (8 种)
|
||||
- Escherichia coli 各亚型 (4 种)
|
||||
- Klebsiella pneumoniae (NT5049)
|
||||
- 其他肠道革兰阴性菌
|
||||
|
||||
**革兰阳性菌(17 种)**:
|
||||
- Bifidobacterium 属 (3 种)
|
||||
- Clostridium 属 (5 种)
|
||||
- Enterococcus 属 (2 种)
|
||||
- Lactobacillus 属 (3 种)
|
||||
- Streptococcus 属 (2 种)
|
||||
- 其他肠道革兰阳性菌
|
||||
|
||||
完整菌株列表详见 `Data/mole/README.md`。
|
||||
|
||||
#### 广谱抗菌判断标准
|
||||
|
||||
默认情况下,如果一个分子能抑制 **10 个或更多菌株** (`ginhib_total >= 10`),则被认为是广谱抗菌分子。
|
||||
@@ -291,6 +411,12 @@ for r in results:
|
||||
|
||||
- 输入: `Data/fragment/Frags-Enamine-18M.csv`
|
||||
- 输出: `Data/fragment/Frags-Enamine-18M_predicted.csv`
|
||||
- 输出(含菌株): `Data/fragment/Frags-Enamine-18M_predicted.csv`(每个分子 40 行)
|
||||
|
||||
#### 数据量说明
|
||||
|
||||
- **仅聚合结果**:输出行数 = 输入分子数
|
||||
- **包含菌株预测**:输出行数 = 输入分子数 × 40
|
||||
|
||||
---
|
||||
|
||||
|
||||
Reference in New Issue
Block a user