Files
SIME/Data/mole/README.md
hotwa 34102cf459 1. 代码修改
models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
2025-10-17 16:46:04 +08:00

384 lines
13 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
## convert old xgboots pickle format
```bash
cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
ipython
```
```python
import xgboost as xgb
import pickle
from pathlib import Path
ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl')
out_ckpt = Path('./')
# 加载旧模型
with open(ckpt, 'rb') as f:
model = pickle.load(f)
# 用新格式保存(推荐)
model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json'))
# 或者继续用pickle但清晰格式
booster = model.get_booster()
booster.feature_names = None
with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f:
pickle.dump(model, f)
```
## 完整预测流程
```mermaid
SMILES 分子输入CSV文件
[MolE 模型]
├── config.yaml模型配置
└── model.pth模型权重
分子特征表示1000维向量
构建"分子-菌株对"(笛卡尔积)
└── maier_screening_results.tsv.gz菌株列表
[XGBoost 模型]
└── MolE-XGBoost-08.03.2025_10.17.json或.pkl
对每一对预测:是否抑制生长
获得原始预测结果(对每个菌株的预测)
[聚合分析]
├── maier_screening_results.tsv.gz菌株列表
└── strain_info_SF2.xlsx革兰染色信息
最终预测结果
输出CSV文件
```
## 所需文件清单
| 步骤 | 文件名 | 用途 | 备注 |
|------|--------|------|------|
| **MolE 模型** | `config.yaml` | 定义MolE网络结构 | YAML配置文件 |
| | `model.pth` | MolE模型权重 | PyTorch格式 |
| **构建菌株对** | `maier_screening_results.tsv.gz` | 提供40个菌株列表 | 压缩的TSV文件 |
| **XGBoost 预测** | `MolE-XGBoost-08.03.2025_10.17.json` | 预测分子-菌株对 | JSON格式或PKL格式 |
| **聚合分析** | `maier_screening_results.tsv.gz` | 菌株名称和统计 | 复用(与构建菌株对同一文件) |
| | `strain_info_SF2.xlsx` | 革兰染色分类信息 | Excel格式 |
## 文件存放位置
所有文件应位于:
```
Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
├── config.yaml
├── model.pth
├── MolE-XGBoost-08.03.2025_10.17.json
├── maier_screening_results.tsv.gz
└── strain_info_SF2.xlsx
```
## 代码中的对应关系
```python
# PredictionConfig 中的配置
@dataclass
class PredictionConfig:
xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json"
mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # 目录包含config.yaml + model.pth
strain_categories_path = "maier_screening_results.tsv.gz"
gram_info_path = "strain_info_SF2.xlsx"
```
## 数据流向总结
1. **输入**CSV文件中的SMILES分子
2. **MolE处理**:分子 → 1000维特征向量
3. **菌株配对**1个分子 × 40个菌株 = 40对
4. **XGBoost预测**:每对 → 抑制概率
5. **聚合分析**:统计和分类(按革兰染色)
6. **输出**CSV文件中的预测结果包含8个指标
## 参考文件
1. `maier_screening_results.tsv.gz` - 菌株列表和筛选数据
```python
self.maier_screen = pd.read_csv(
self.config.strain_categories_path, sep='\t', index_col=0
)
self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # 独热编码
```
包含所有已知菌株的名称40个菌株
用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对"
XGBoost为每一对预测是否能抑制该菌株的生长
2. `strain_info_SF2.xlsx` - 革兰染色信息
```python
self.maier_strains = pd.read_excel(self.config.gram_info_path, ...)
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
```
记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative)
用于将预测结果按革兰染色分类统计
预测结果示例:
某分子 mol1 的预测结果会包括:
```python
BroadSpectrumResult(
chem_id='mol1',
apscore_total=2.5, # 对所有菌株的抗菌分数
apscore_gnegative=2.1, # 仅对革兰阴性菌的分数
apscore_gpositive=2.8, # 仅对革兰阳性菌的分数
ginhib_total=25, # 抑制的菌株总数
ginhib_gnegative=12, # 抑制的革兰阴性菌数
ginhib_gpositive=13, # 抑制的革兰阳性菌数
broad_spectrum=1 # 是否广谱≥10个菌株
)
```
结果解读:
## BroadSpectrumResult 字段说明表
| 字段名 | 数据类型 | 计算方法 | 含义说明 |
|--------|----------|----------|---------|
| `chem_id` | 字符串 | 输入的化合物标识符 | 化合物的唯一标识,如 "mol1"、"compound_001" 等 |
| `apscore_total` | 浮点数 | `log(gmean(所有40个菌株的预测概率))` | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低 |
| `apscore_gnegative` | 浮点数 | `log(gmean(革兰阴性菌株的预测概率))` | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性 |
| `apscore_gpositive` | 浮点数 | `log(gmean(革兰阳性菌株的预测概率))` | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性 |
| `ginhib_total` | 整数 | `sum(所有菌株的二值化预测)` | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40 |
| `ginhib_gnegative` | 整数 | `sum(革兰阴性菌株的二值化预测)` | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20 |
| `ginhib_gpositive` | 整数 | `sum(革兰阳性菌株的二值化预测)` | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20 |
| `broad_spectrum` | 整数 (0/1) | `1 if ginhib_total >= 10 else 0` | 广谱抗菌标志:如果抑制菌株数 ≥ 10判定为广谱抗菌药物1否则为窄谱0 |
说明
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
---
## 菌株级别预测详情
### 使用 `--include-strain-predictions` 参数
启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。
**命令示例**
```bash
python utils/mole_predictor.py input.csv output.csv --include-strain-predictions
```
### 菌株级别输出格式
每个分子会产生 40 行数据,每行对应一个菌株的预测结果:
| 列名 | 数据类型 | 说明 | 示例值 |
|------|----------|------|--------|
| `pred_id` | 字符串 | 预测ID格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` |
| `chem_id` | 字符串 | 化合物标识符 | `mol1` |
| `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` |
| `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率0-1 | `0.000102` |
| `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability | `0.999898` |
| `growth_inhibition` | 整数 (0/1) | 二值化抑制结果1=抑制0=不抑制) | `0` |
| `gram_stain` | 字符串 | 革兰染色类型 | `negative``positive` |
### 完整的 40 种测试菌株列表
#### 革兰阴性菌23 种)
| 序号 | 菌株名称 | NT编号 |
|------|----------|--------|
| 1 | Akkermansia muciniphila | NT5021 |
| 2 | Bacteroides caccae | NT5050 |
| 3 | Bacteroides fragilis (ET) | NT5033 |
| 4 | Bacteroides fragilis (NT) | NT5003 |
| 5 | Bacteroides ovatus | NT5054 |
| 6 | Bacteroides thetaiotaomicron | NT5004 |
| 7 | Bacteroides uniformis | NT5002 |
| 8 | Bacteroides vulgatus | NT5001 |
| 9 | Bacteroides xylanisolvens | NT5064 |
| 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 |
| 11 | Klebsiella pneumoniae | NT5049 |
| 12 | Parabacteroides distasonis | NT5023 |
| 13 | Phocaeicola vulgatus | NT5001 |
| 14 | 其他肠道革兰阴性菌 | ... |
#### 革兰阳性菌17 种)
| 序号 | 菌株名称 | NT编号 |
|------|----------|--------|
| 1 | Bifidobacterium adolescentis | NT5022 |
| 2 | Bifidobacterium longum | NT5067 |
| 3 | Bifidobacterium pseudocatenulatum | NT5058 |
| 4 | Clostridium bolteae | NT5005 |
| 5 | Clostridium innocuum | NT5026 |
| 6 | Clostridium ramosum | NT5027 |
| 7 | Clostridium scindens | NT5029 |
| 8 | Clostridium symbiosum | NT5006 |
| 9 | Enterococcus faecalis | NT5034 |
| 10 | Enterococcus faecium | NT5043 |
| 11 | Lactobacillus plantarum | NT5035 |
| 12 | Lactobacillus reuteri | NT5032 |
| 13 | Lactobacillus rhamnosus | NT5037 |
| 14 | Streptococcus parasanguinis | NT5041 |
| 15 | Streptococcus salivarius | NT5040 |
| 16 | 其他肠道革兰阳性菌 | ... |
**注**: 完整列表可从 `maier_screening_results.tsv.gz``strain_info_SF2.xlsx` 文件中查看。
### 数据访问方式
#### 1. CSV 文件读取
```python
import pandas as pd
# 读取包含菌株预测的结果
df = pd.read_csv('output_with_strains.csv')
# 查看某个分子的所有菌株预测
mol1_strains = df[df['chem_id'] == 'mol1']
print(f"分子 mol1 的预测:")
print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']])
# 筛选被抑制的菌株
inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1]
print(f"\n被抑制的菌株数: {len(inhibited)}")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability']])
```
#### 2. Python API 访问
```python
from models import ParallelBroadSpectrumPredictor, MoleculeInput
predictor = ParallelBroadSpectrumPredictor()
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
# 包含菌株级别预测
result = predictor.predict_batch([molecule], include_strain_predictions=True)[0]
# 访问菌株预测 DataFrame
strain_df = result.strain_predictions
print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7)
print(f"列名: {strain_df.columns.tolist()}")
# 提取预测概率向量(用于强化学习)
probabilities = strain_df['antimicrobial_predictive_probability'].values
print(f"预测概率向量形状: {probabilities.shape}") # (40,)
# 筛选特定革兰染色类型
gram_negative = strain_df[strain_df['gram_stain'] == 'negative']
print(f"革兰阴性菌预测数: {len(gram_negative)}")
# 转换为类型安全的列表(可选)
strain_list = result.to_strain_predictions_list()
for strain_pred in strain_list[:3]:
print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}")
```
### 强化学习场景应用
#### 状态表示
```python
# 将 40 个菌株的预测概率作为状态向量
state = result.strain_predictions['antimicrobial_predictive_probability'].values
# state.shape = (40,)
# 或者包含更多特征
state_features = result.strain_predictions[[
'antimicrobial_predictive_probability',
'growth_inhibition'
]].values
# state_features.shape = (40, 2)
```
#### 奖励函数设计
```python
def calculate_reward(strain_predictions_df):
"""
基于菌株级别预测计算奖励
Args:
strain_predictions_df: 包含 40 个菌株预测的 DataFrame
Returns:
reward: 标量奖励值
"""
# 方案1: 基于抑制菌株数
reward = strain_predictions_df['growth_inhibition'].sum() / 40.0
# 方案2: 基于预测概率
reward = strain_predictions_df['antimicrobial_predictive_probability'].mean()
# 方案3: 加权奖励(考虑革兰染色)
gram_negative_score = strain_predictions_df[
strain_predictions_df['gram_stain'] == 'negative'
]['antimicrobial_predictive_probability'].mean()
gram_positive_score = strain_predictions_df[
strain_predictions_df['gram_stain'] == 'positive'
]['antimicrobial_predictive_probability'].mean()
reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score
return reward
```
### 数据可视化
```python
import matplotlib.pyplot as plt
import seaborn as sns
# 读取菌株预测数据
strain_df = result.strain_predictions
# 按预测概率排序
strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False)
# 绘制柱状图
plt.figure(figsize=(15, 6))
plt.bar(range(len(strain_df_sorted)),
strain_df_sorted['antimicrobial_predictive_probability'],
color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']])
plt.xlabel('菌株索引')
plt.ylabel('抗菌预测概率')
plt.title('分子对 40 种菌株的抗菌活性预测')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()
# 按革兰染色分组可视化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for idx, gram_type in enumerate(['negative', 'positive']):
data = strain_df[strain_df['gram_stain'] == gram_type]
axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability'])
axes[idx].set_xlabel('预测概率')
axes[idx].set_title(f'革兰{gram_type}')
axes[idx].tick_params(axis='y', labelsize=8)
plt.tight_layout()
plt.show()
```
---
## 性能和存储建议
- **聚合结果**: 每个分子 1 行,适合快速筛选
- **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习
- **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小
- **推荐做法**:
- 初筛时使用聚合结果
- 对候选分子使用菌株级别预测进行深入分析