Files
SIME/Data/mole/README.md
hotwa 34102cf459 1. 代码修改
models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
2025-10-17 16:46:04 +08:00

13 KiB
Raw Permalink Blame History

convert old xgboots pickle format

cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
ipython
import xgboost as xgb
import pickle
from pathlib import Path
ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl')
out_ckpt = Path('./')

# 加载旧模型
with open(ckpt, 'rb') as f:
    model = pickle.load(f)

# 用新格式保存(推荐)
model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json'))

# 或者继续用pickle但清晰格式
booster = model.get_booster()
booster.feature_names = None
with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f:
    pickle.dump(model, f)

完整预测流程

SMILES 分子输入CSV文件
    ↓
[MolE 模型]
    ├── config.yaml模型配置
    └── model.pth模型权重
    ↓
分子特征表示1000维向量
    ↓
构建"分子-菌株对"(笛卡尔积)
    └── maier_screening_results.tsv.gz菌株列表
    ↓
[XGBoost 模型]
    └── MolE-XGBoost-08.03.2025_10.17.json或.pkl
    ↓
对每一对预测:是否抑制生长
    ↓
获得原始预测结果(对每个菌株的预测)
    ↓
[聚合分析]
    ├── maier_screening_results.tsv.gz菌株列表
    └── strain_info_SF2.xlsx革兰染色信息
    ↓
最终预测结果
    ↓
输出CSV文件

所需文件清单

步骤 文件名 用途 备注
MolE 模型 config.yaml 定义MolE网络结构 YAML配置文件
model.pth MolE模型权重 PyTorch格式
构建菌株对 maier_screening_results.tsv.gz 提供40个菌株列表 压缩的TSV文件
XGBoost 预测 MolE-XGBoost-08.03.2025_10.17.json 预测分子-菌株对 JSON格式或PKL格式
聚合分析 maier_screening_results.tsv.gz 菌株名称和统计 复用(与构建菌株对同一文件)
strain_info_SF2.xlsx 革兰染色分类信息 Excel格式

文件存放位置

所有文件应位于:

Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/
├── config.yaml
├── model.pth
├── MolE-XGBoost-08.03.2025_10.17.json
├── maier_screening_results.tsv.gz
└── strain_info_SF2.xlsx

代码中的对应关系

# PredictionConfig 中的配置
@dataclass
class PredictionConfig:
    xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json"
    mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001"  # 目录包含config.yaml + model.pth
    strain_categories_path = "maier_screening_results.tsv.gz"
    gram_info_path = "strain_info_SF2.xlsx"

数据流向总结

  1. 输入CSV文件中的SMILES分子
  2. MolE处理:分子 → 1000维特征向量
  3. 菌株配对1个分子 × 40个菌株 = 40对
  4. XGBoost预测:每对 → 抑制概率
  5. 聚合分析:统计和分类(按革兰染色)
  6. 输出CSV文件中的预测结果包含8个指标

参考文件

  1. maier_screening_results.tsv.gz - 菌株列表和筛选数据
self.maier_screen = pd.read_csv(
    self.config.strain_categories_path, sep='\t', index_col=0
)
self.strain_ohe = self._prep_ohe(self.maier_screen.columns)  # 独热编码

包含所有已知菌株的名称40个菌株 用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对" XGBoost为每一对预测是否能抑制该菌株的生长

  1. strain_info_SF2.xlsx - 革兰染色信息
self.maier_strains = pd.read_excel(self.config.gram_info_path, ...)
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]

记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative) 用于将预测结果按革兰染色分类统计

预测结果示例: 某分子 mol1 的预测结果会包括:

BroadSpectrumResult(
    chem_id='mol1',
    apscore_total=2.5,              # 对所有菌株的抗菌分数
    apscore_gnegative=2.1,          # 仅对革兰阴性菌的分数
    apscore_gpositive=2.8,          # 仅对革兰阳性菌的分数
    ginhib_total=25,                # 抑制的菌株总数
    ginhib_gnegative=12,            # 抑制的革兰阴性菌数
    ginhib_gpositive=13,            # 抑制的革兰阳性菌数
    broad_spectrum=1                # 是否广谱≥10个菌株
)

结果解读:

BroadSpectrumResult 字段说明表

字段名 数据类型 计算方法 含义说明
chem_id 字符串 输入的化合物标识符 化合物的唯一标识,如 "mol1"、"compound_001" 等
apscore_total 浮点数 log(gmean(所有40个菌株的预测概率)) 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低
apscore_gnegative 浮点数 log(gmean(革兰阴性菌株的预测概率)) 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性
apscore_gpositive 浮点数 log(gmean(革兰阳性菌株的预测概率)) 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性
ginhib_total 整数 sum(所有菌株的二值化预测) 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40
ginhib_gnegative 整数 sum(革兰阴性菌株的二值化预测) 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20
ginhib_gpositive 整数 sum(革兰阳性菌株的二值化预测) 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20
broad_spectrum 整数 (0/1) 1 if ginhib_total >= 10 else 0 广谱抗菌标志:如果抑制菌株数 ≥ 10判定为广谱抗菌药物1否则为窄谱0

说明

  • apscore_ 类字段*:基于预测概率的连续评分,反映抗菌活性强度
  • ginhib_ 类字段*:基于二值化预测的离散计数,反映抑制范围
  • broad_spectrum:基于 ginhib_total 的布尔判定,快速标识广谱特性

菌株级别预测详情

使用 --include-strain-predictions 参数

启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。

命令示例

python utils/mole_predictor.py input.csv output.csv --include-strain-predictions

菌株级别输出格式

每个分子会产生 40 行数据,每行对应一个菌株的预测结果:

列名 数据类型 说明 示例值
pred_id 字符串 预测ID格式为 chem_id:strain_name mol1:Akkermansia muciniphila (NT5021)
chem_id 字符串 化合物标识符 mol1
strain_name 字符串 菌株名称 Akkermansia muciniphila (NT5021)
antimicrobial_predictive_probability 浮点数 XGBoost 预测的抗菌概率0-1 0.000102
no_growth_probability 浮点数 不抑制的概率(= 1 - antimicrobial_predictive_probability 0.999898
growth_inhibition 整数 (0/1) 二值化抑制结果1=抑制0=不抑制) 0
gram_stain 字符串 革兰染色类型 negativepositive

完整的 40 种测试菌株列表

革兰阴性菌23 种)

序号 菌株名称 NT编号
1 Akkermansia muciniphila NT5021
2 Bacteroides caccae NT5050
3 Bacteroides fragilis (ET) NT5033
4 Bacteroides fragilis (NT) NT5003
5 Bacteroides ovatus NT5054
6 Bacteroides thetaiotaomicron NT5004
7 Bacteroides uniformis NT5002
8 Bacteroides vulgatus NT5001
9 Bacteroides xylanisolvens NT5064
10 Escherichia coli (3 isolates + Nissle) NT5028, NT5024, NT5030, NT5011
11 Klebsiella pneumoniae NT5049
12 Parabacteroides distasonis NT5023
13 Phocaeicola vulgatus NT5001
14 其他肠道革兰阴性菌 ...

革兰阳性菌17 种)

序号 菌株名称 NT编号
1 Bifidobacterium adolescentis NT5022
2 Bifidobacterium longum NT5067
3 Bifidobacterium pseudocatenulatum NT5058
4 Clostridium bolteae NT5005
5 Clostridium innocuum NT5026
6 Clostridium ramosum NT5027
7 Clostridium scindens NT5029
8 Clostridium symbiosum NT5006
9 Enterococcus faecalis NT5034
10 Enterococcus faecium NT5043
11 Lactobacillus plantarum NT5035
12 Lactobacillus reuteri NT5032
13 Lactobacillus rhamnosus NT5037
14 Streptococcus parasanguinis NT5041
15 Streptococcus salivarius NT5040
16 其他肠道革兰阳性菌 ...

: 完整列表可从 maier_screening_results.tsv.gzstrain_info_SF2.xlsx 文件中查看。

数据访问方式

1. CSV 文件读取

import pandas as pd

# 读取包含菌株预测的结果
df = pd.read_csv('output_with_strains.csv')

# 查看某个分子的所有菌株预测
mol1_strains = df[df['chem_id'] == 'mol1']
print(f"分子 mol1 的预测:")
print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']])

# 筛选被抑制的菌株
inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1]
print(f"\n被抑制的菌株数: {len(inhibited)}")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability']])

2. Python API 访问

from models import ParallelBroadSpectrumPredictor, MoleculeInput

predictor = ParallelBroadSpectrumPredictor()
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")

# 包含菌株级别预测
result = predictor.predict_batch([molecule], include_strain_predictions=True)[0]

# 访问菌株预测 DataFrame
strain_df = result.strain_predictions
print(f"菌株预测数据形状: {strain_df.shape}")  # (40, 7)
print(f"列名: {strain_df.columns.tolist()}")

# 提取预测概率向量(用于强化学习)
probabilities = strain_df['antimicrobial_predictive_probability'].values
print(f"预测概率向量形状: {probabilities.shape}")  # (40,)

# 筛选特定革兰染色类型
gram_negative = strain_df[strain_df['gram_stain'] == 'negative']
print(f"革兰阴性菌预测数: {len(gram_negative)}")

# 转换为类型安全的列表(可选)
strain_list = result.to_strain_predictions_list()
for strain_pred in strain_list[:3]:
    print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}")

强化学习场景应用

状态表示

# 将 40 个菌株的预测概率作为状态向量
state = result.strain_predictions['antimicrobial_predictive_probability'].values
# state.shape = (40,)

# 或者包含更多特征
state_features = result.strain_predictions[[
    'antimicrobial_predictive_probability',
    'growth_inhibition'
]].values
# state_features.shape = (40, 2)

奖励函数设计

def calculate_reward(strain_predictions_df):
    """
    基于菌株级别预测计算奖励
    
    Args:
        strain_predictions_df: 包含 40 个菌株预测的 DataFrame
        
    Returns:
        reward: 标量奖励值
    """
    # 方案1: 基于抑制菌株数
    reward = strain_predictions_df['growth_inhibition'].sum() / 40.0
    
    # 方案2: 基于预测概率
    reward = strain_predictions_df['antimicrobial_predictive_probability'].mean()
    
    # 方案3: 加权奖励(考虑革兰染色)
    gram_negative_score = strain_predictions_df[
        strain_predictions_df['gram_stain'] == 'negative'
    ]['antimicrobial_predictive_probability'].mean()
    
    gram_positive_score = strain_predictions_df[
        strain_predictions_df['gram_stain'] == 'positive'
    ]['antimicrobial_predictive_probability'].mean()
    
    reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score
    
    return reward

数据可视化

import matplotlib.pyplot as plt
import seaborn as sns

# 读取菌株预测数据
strain_df = result.strain_predictions

# 按预测概率排序
strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False)

# 绘制柱状图
plt.figure(figsize=(15, 6))
plt.bar(range(len(strain_df_sorted)), 
        strain_df_sorted['antimicrobial_predictive_probability'],
        color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']])
plt.xlabel('菌株索引')
plt.ylabel('抗菌预测概率')
plt.title('分子对 40 种菌株的抗菌活性预测')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

# 按革兰染色分组可视化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for idx, gram_type in enumerate(['negative', 'positive']):
    data = strain_df[strain_df['gram_stain'] == gram_type]
    axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability'])
    axes[idx].set_xlabel('预测概率')
    axes[idx].set_title(f'革兰{gram_type}菌')
    axes[idx].tick_params(axis='y', labelsize=8)

plt.tight_layout()
plt.show()

性能和存储建议

  • 聚合结果: 每个分子 1 行,适合快速筛选
  • 菌株级别预测: 每个分子 40 行,适合详细分析和强化学习
  • 存储空间: 包含菌株预测的文件约为仅聚合结果的 40 倍大小
  • 推荐做法:
    • 初筛时使用聚合结果
    • 对候选分子使用菌株级别预测进行深入分析