## convert old xgboots pickle format ```bash cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001 ipython ``` ```python import xgboost as xgb import pickle from pathlib import Path ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl') out_ckpt = Path('./') # 加载旧模型 with open(ckpt, 'rb') as f: model = pickle.load(f) # 用新格式保存(推荐) model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json')) # 或者继续用pickle但清晰格式 booster = model.get_booster() booster.feature_names = None with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f: pickle.dump(model, f) ``` ## 完整预测流程 ```mermaid SMILES 分子(输入CSV文件) ↓ [MolE 模型] ├── config.yaml(模型配置) └── model.pth(模型权重) ↓ 分子特征表示(1000维向量) ↓ 构建"分子-菌株对"(笛卡尔积) └── maier_screening_results.tsv.gz(菌株列表) ↓ [XGBoost 模型] └── MolE-XGBoost-08.03.2025_10.17.json(或.pkl) ↓ 对每一对预测:是否抑制生长 ↓ 获得原始预测结果(对每个菌株的预测) ↓ [聚合分析] ├── maier_screening_results.tsv.gz(菌株列表) └── strain_info_SF2.xlsx(革兰染色信息) ↓ 最终预测结果 ↓ 输出CSV文件 ``` ## 所需文件清单 | 步骤 | 文件名 | 用途 | 备注 | |------|--------|------|------| | **MolE 模型** | `config.yaml` | 定义MolE网络结构 | YAML配置文件 | | | `model.pth` | MolE模型权重 | PyTorch格式 | | **构建菌株对** | `maier_screening_results.tsv.gz` | 提供40个菌株列表 | 压缩的TSV文件 | | **XGBoost 预测** | `MolE-XGBoost-08.03.2025_10.17.json` | 预测分子-菌株对 | JSON格式(新)或PKL格式(旧) | | **聚合分析** | `maier_screening_results.tsv.gz` | 菌株名称和统计 | 复用(与构建菌株对同一文件) | | | `strain_info_SF2.xlsx` | 革兰染色分类信息 | Excel格式 | ## 文件存放位置 所有文件应位于: ``` Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/ ├── config.yaml ├── model.pth ├── MolE-XGBoost-08.03.2025_10.17.json ├── maier_screening_results.tsv.gz └── strain_info_SF2.xlsx ``` ## 代码中的对应关系 ```python # PredictionConfig 中的配置 @dataclass class PredictionConfig: xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json" mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # 目录(包含config.yaml + model.pth) strain_categories_path = "maier_screening_results.tsv.gz" gram_info_path = "strain_info_SF2.xlsx" ``` ## 数据流向总结 1. **输入**:CSV文件中的SMILES分子 2. **MolE处理**:分子 → 1000维特征向量 3. **菌株配对**:1个分子 × 40个菌株 = 40对 4. **XGBoost预测**:每对 → 抑制概率 5. **聚合分析**:统计和分类(按革兰染色) 6. **输出**:CSV文件中的预测结果(包含8个指标) ## 参考文件 1. `maier_screening_results.tsv.gz` - 菌株列表和筛选数据 ```python self.maier_screen = pd.read_csv( self.config.strain_categories_path, sep='\t', index_col=0 ) self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # 独热编码 ``` 包含所有已知菌株的名称(40个菌株) 用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对" XGBoost为每一对预测:是否能抑制该菌株的生长 2. `strain_info_SF2.xlsx` - 革兰染色信息 ```python self.maier_strains = pd.read_excel(self.config.gram_info_path, ...) gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"] ``` 记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative) 用于将预测结果按革兰染色分类统计 预测结果示例: 某分子 mol1 的预测结果会包括: ```python BroadSpectrumResult( chem_id='mol1', apscore_total=2.5, # 对所有菌株的抗菌分数 apscore_gnegative=2.1, # 仅对革兰阴性菌的分数 apscore_gpositive=2.8, # 仅对革兰阳性菌的分数 ginhib_total=25, # 抑制的菌株总数 ginhib_gnegative=12, # 抑制的革兰阴性菌数 ginhib_gpositive=13, # 抑制的革兰阳性菌数 broad_spectrum=1 # 是否广谱(≥10个菌株) ) ``` 结果解读: ## BroadSpectrumResult 字段说明表 | 字段名 | 数据类型 | 计算方法 | 含义说明 | |--------|----------|----------|---------| | `chem_id` | 字符串 | 输入的化合物标识符 | 化合物的唯一标识,如 "mol1"、"compound_001" 等 | | `apscore_total` | 浮点数 | `log(gmean(所有40个菌株的预测概率))` | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低 | | `apscore_gnegative` | 浮点数 | `log(gmean(革兰阴性菌株的预测概率))` | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性 | | `apscore_gpositive` | 浮点数 | `log(gmean(革兰阳性菌株的预测概率))` | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性 | | `ginhib_total` | 整数 | `sum(所有菌株的二值化预测)` | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40 | | `ginhib_gnegative` | 整数 | `sum(革兰阴性菌株的二值化预测)` | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20 | | `ginhib_gpositive` | 整数 | `sum(革兰阳性菌株的二值化预测)` | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20 | | `broad_spectrum` | 整数 (0/1) | `1 if ginhib_total >= 10 else 0` | 广谱抗菌标志:如果抑制菌株数 ≥ 10,判定为广谱抗菌药物(1),否则为窄谱(0) | 说明 - **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度 - **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围 - **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性 --- ## 菌株级别预测详情 ### 使用 `--include-strain-predictions` 参数 启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。 **命令示例**: ```bash python utils/mole_predictor.py input.csv output.csv --include-strain-predictions ``` ### 菌株级别输出格式 每个分子会产生 40 行数据,每行对应一个菌株的预测结果: | 列名 | 数据类型 | 说明 | 示例值 | |------|----------|------|--------| | `pred_id` | 字符串 | 预测ID,格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` | | `chem_id` | 字符串 | 化合物标识符 | `mol1` | | `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` | | `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率(0-1) | `0.000102` | | `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability) | `0.999898` | | `growth_inhibition` | 整数 (0/1) | 二值化抑制结果(1=抑制,0=不抑制) | `0` | | `gram_stain` | 字符串 | 革兰染色类型 | `negative` 或 `positive` | ### 完整的 40 种测试菌株列表 #### 革兰阴性菌(23 种) | 序号 | 菌株名称 | NT编号 | |------|----------|--------| | 1 | Akkermansia muciniphila | NT5021 | | 2 | Bacteroides caccae | NT5050 | | 3 | Bacteroides fragilis (ET) | NT5033 | | 4 | Bacteroides fragilis (NT) | NT5003 | | 5 | Bacteroides ovatus | NT5054 | | 6 | Bacteroides thetaiotaomicron | NT5004 | | 7 | Bacteroides uniformis | NT5002 | | 8 | Bacteroides vulgatus | NT5001 | | 9 | Bacteroides xylanisolvens | NT5064 | | 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 | | 11 | Klebsiella pneumoniae | NT5049 | | 12 | Parabacteroides distasonis | NT5023 | | 13 | Phocaeicola vulgatus | NT5001 | | 14 | 其他肠道革兰阴性菌 | ... | #### 革兰阳性菌(17 种) | 序号 | 菌株名称 | NT编号 | |------|----------|--------| | 1 | Bifidobacterium adolescentis | NT5022 | | 2 | Bifidobacterium longum | NT5067 | | 3 | Bifidobacterium pseudocatenulatum | NT5058 | | 4 | Clostridium bolteae | NT5005 | | 5 | Clostridium innocuum | NT5026 | | 6 | Clostridium ramosum | NT5027 | | 7 | Clostridium scindens | NT5029 | | 8 | Clostridium symbiosum | NT5006 | | 9 | Enterococcus faecalis | NT5034 | | 10 | Enterococcus faecium | NT5043 | | 11 | Lactobacillus plantarum | NT5035 | | 12 | Lactobacillus reuteri | NT5032 | | 13 | Lactobacillus rhamnosus | NT5037 | | 14 | Streptococcus parasanguinis | NT5041 | | 15 | Streptococcus salivarius | NT5040 | | 16 | 其他肠道革兰阳性菌 | ... | **注**: 完整列表可从 `maier_screening_results.tsv.gz` 和 `strain_info_SF2.xlsx` 文件中查看。 ### 数据访问方式 #### 1. CSV 文件读取 ```python import pandas as pd # 读取包含菌株预测的结果 df = pd.read_csv('output_with_strains.csv') # 查看某个分子的所有菌株预测 mol1_strains = df[df['chem_id'] == 'mol1'] print(f"分子 mol1 的预测:") print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']]) # 筛选被抑制的菌株 inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1] print(f"\n被抑制的菌株数: {len(inhibited)}") print(inhibited[['strain_name', 'antimicrobial_predictive_probability']]) ``` #### 2. Python API 访问 ```python from models import ParallelBroadSpectrumPredictor, MoleculeInput predictor = ParallelBroadSpectrumPredictor() molecule = MoleculeInput(smiles="CCO", chem_id="ethanol") # 包含菌株级别预测 result = predictor.predict_batch([molecule], include_strain_predictions=True)[0] # 访问菌株预测 DataFrame strain_df = result.strain_predictions print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7) print(f"列名: {strain_df.columns.tolist()}") # 提取预测概率向量(用于强化学习) probabilities = strain_df['antimicrobial_predictive_probability'].values print(f"预测概率向量形状: {probabilities.shape}") # (40,) # 筛选特定革兰染色类型 gram_negative = strain_df[strain_df['gram_stain'] == 'negative'] print(f"革兰阴性菌预测数: {len(gram_negative)}") # 转换为类型安全的列表(可选) strain_list = result.to_strain_predictions_list() for strain_pred in strain_list[:3]: print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}") ``` ### 强化学习场景应用 #### 状态表示 ```python # 将 40 个菌株的预测概率作为状态向量 state = result.strain_predictions['antimicrobial_predictive_probability'].values # state.shape = (40,) # 或者包含更多特征 state_features = result.strain_predictions[[ 'antimicrobial_predictive_probability', 'growth_inhibition' ]].values # state_features.shape = (40, 2) ``` #### 奖励函数设计 ```python def calculate_reward(strain_predictions_df): """ 基于菌株级别预测计算奖励 Args: strain_predictions_df: 包含 40 个菌株预测的 DataFrame Returns: reward: 标量奖励值 """ # 方案1: 基于抑制菌株数 reward = strain_predictions_df['growth_inhibition'].sum() / 40.0 # 方案2: 基于预测概率 reward = strain_predictions_df['antimicrobial_predictive_probability'].mean() # 方案3: 加权奖励(考虑革兰染色) gram_negative_score = strain_predictions_df[ strain_predictions_df['gram_stain'] == 'negative' ]['antimicrobial_predictive_probability'].mean() gram_positive_score = strain_predictions_df[ strain_predictions_df['gram_stain'] == 'positive' ]['antimicrobial_predictive_probability'].mean() reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score return reward ``` ### 数据可视化 ```python import matplotlib.pyplot as plt import seaborn as sns # 读取菌株预测数据 strain_df = result.strain_predictions # 按预测概率排序 strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False) # 绘制柱状图 plt.figure(figsize=(15, 6)) plt.bar(range(len(strain_df_sorted)), strain_df_sorted['antimicrobial_predictive_probability'], color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']]) plt.xlabel('菌株索引') plt.ylabel('抗菌预测概率') plt.title('分子对 40 种菌株的抗菌活性预测') plt.xticks(rotation=90) plt.tight_layout() plt.show() # 按革兰染色分组可视化 fig, axes = plt.subplots(1, 2, figsize=(14, 5)) for idx, gram_type in enumerate(['negative', 'positive']): data = strain_df[strain_df['gram_stain'] == gram_type] axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability']) axes[idx].set_xlabel('预测概率') axes[idx].set_title(f'革兰{gram_type}菌') axes[idx].tick_params(axis='y', labelsize=8) plt.tight_layout() plt.show() ``` --- ## 性能和存储建议 - **聚合结果**: 每个分子 1 行,适合快速筛选 - **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习 - **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小 - **推荐做法**: - 初筛时使用聚合结果 - 对候选分子使用菌株级别预测进行深入分析