#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE antibacterial-activity prediction: Python API example.

Demonstrates the two prediction modes:
1. Aggregated-result mode (the default)
2. Strain-level prediction mode
"""

import sys
from pathlib import Path

# Put the project root (two levels above this file) on sys.path so that the
# ``models`` package resolves when this script is run directly.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from models import (  # noqa: E402 - must come after the sys.path tweak above
    ParallelBroadSpectrumPredictor,
    PredictionConfig,
    MoleculeInput,
)

# Column of the strain-level DataFrame that the demo reads repeatedly.
_PROB_COLUMN = 'antimicrobial_predictive_probability'


def print_separator(title: str) -> None:
    """Print *title* framed by two 70-character ``=`` rules."""
    print("\n" + "=" * 70)
    print(f" {title}")
    print("=" * 70 + "\n")


def demo_aggregated_mode():
    """Demonstrate the aggregated-result mode (the default).

    Builds a predictor from an explicit ``PredictionConfig``, predicts three
    small molecules with ``include_strain_predictions=False`` and prints the
    aggregated metrics of every result.

    Returns:
        The object returned by ``predictor.predict_batch`` so callers can
        reuse the results.
    """
    print_separator("模式 1: 聚合结果模式(默认)")

    config = PredictionConfig(
        batch_size=10,
        device="auto",  # per the original author's note: auto-detect CUDA
    )
    predictor = ParallelBroadSpectrumPredictor(config)

    # Molecules used for the demo run.
    test_molecules = [
        MoleculeInput(smiles="CCO", chem_id="ethanol"),
        MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
        MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
    ]

    print(f"测试分子数: {len(test_molecules)}")
    print(f"SMILES 示例: {test_molecules[0].smiles}")

    # Run the prediction without strain-level output.
    print("\n开始预测(聚合模式)...")
    results = predictor.predict_batch(
        test_molecules, include_strain_predictions=False
    )

    print(f"\n预测完成!共 {len(results)} 个结果\n")
    for result in results:
        print(f"化合物ID: {result.chem_id}")
        print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
        print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
        print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
        print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
        print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
        print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
        print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
        # Expected to be None in this mode (strain predictions not requested).
        print(f" - strain_predictions: {result.strain_predictions}")
        print()

    return results


def _print_aggregate_summary(result) -> None:
    """Print the aggregated metrics of a single prediction result."""
    print("\n聚合结果:")
    print(f" - 化合物ID: {result.chem_id}")
    print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
    print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
    print(f" - 抑制菌株总数: {result.ginhib_total} / 40")


def _print_strain_table(strain_df) -> None:
    """Print the structure of the strain-level DataFrame and its first rows."""
    print(f" - 数据类型: {type(strain_df)}")
    print(f" - 数据形状: {strain_df.shape} (行, 列)")
    print(f" - 列名: {list(strain_df.columns)}")
    memory_kb = strain_df.memory_usage(deep=True).sum() / 1024
    print(f" - 内存占用: {memory_kb:.2f} KB")

    print("\n前 5 个菌株的预测:")
    print(strain_df.head(5).to_string(index=False))


def _print_strain_statistics(strain_df) -> None:
    """Print summary statistics and the rows flagged as inhibited."""
    print("\n统计信息:")
    probabilities = strain_df[_PROB_COLUMN]
    print(f" - 预测概率范围: [{probabilities.min():.6f}, "
          f"{probabilities.max():.6f}]")
    print(f" - 预测概率平均值: {probabilities.mean():.6f}")
    print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
    gram_stain = strain_df['gram_stain']
    print(f" - 革兰阴性菌数: {len(strain_df[gram_stain == 'negative'])}")
    print(f" - 革兰阳性菌数: {len(strain_df[gram_stain == 'positive'])}")

    # Rows whose growth_inhibition flag equals 1.
    inhibited = strain_df[strain_df['growth_inhibition'] == 1]
    if len(inhibited) > 0:
        print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
        columns = ['strain_name', _PROB_COLUMN, 'gram_stain']
        print(inhibited[columns].to_string(index=False))
    else:
        print("\n该分子未预测抑制任何菌株")


def _print_rl_examples(strain_df) -> None:
    """Show how to pull arrays out of the DataFrame for RL-style use."""
    print("\n强化学习应用示例:")

    # Prediction probabilities as a flat state vector.
    state_vector = strain_df[_PROB_COLUMN].values
    print(f" - 状态向量形状: {state_vector.shape}")
    print(f" - 状态向量类型: {type(state_vector)}")
    print(f" - 前 10 个值: {state_vector[:10]}")

    # Two columns at once give a multi-feature array.
    state_features = strain_df[[_PROB_COLUMN, 'growth_inhibition']].values
    print(f" - 多维特征形状: {state_features.shape}")

    # Probabilities restricted to rows whose gram_stain is 'negative'.
    gram_negative_rows = strain_df[strain_df['gram_stain'] == 'negative']
    gram_negative_probs = gram_negative_rows[_PROB_COLUMN].values
    print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")


def _print_strain_list(result) -> None:
    """Print the list form returned by ``result.to_strain_predictions_list()``.

    Only the length is printed when the list is empty, so indexing the first
    element can no longer raise ``IndexError``.
    """
    print("\n转换为 StrainPrediction 列表:")
    strain_list = result.to_strain_predictions_list()
    print(f" - 列表长度: {len(strain_list)}")
    if len(strain_list) == 0:
        return

    first_strain = strain_list[0]
    print(f" - 元素类型: {type(first_strain)}")
    print(" - 第一个元素:")
    print(f" * pred_id: {first_strain.pred_id}")
    print(f" * strain_name: {first_strain.strain_name}")
    print(f" * antimicrobial_predictive_probability: "
          f"{first_strain.antimicrobial_predictive_probability:.6f}")
    print(f" * growth_inhibition: {first_strain.growth_inhibition}")
    print(f" * gram_stain: {first_strain.gram_stain}")


def demo_strain_level_mode():
    """Demonstrate the strain-level prediction mode.

    Predicts a single molecule with ``include_strain_predictions=True`` using
    a predictor built with default arguments, then prints the aggregated
    metrics followed by the strain-level data (structure, statistics, array
    extraction examples and the list conversion). A warning is printed
    instead when ``strain_predictions`` is ``None``.

    Returns:
        The first element of the object returned by
        ``predictor.predict_batch``.
    """
    print_separator("模式 2: 菌株级别预测模式")

    predictor = ParallelBroadSpectrumPredictor()

    # Described by the original author as an interesting antibacterial
    # molecule (a fluoroquinolone analogue).
    test_molecule = MoleculeInput(
        smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
        chem_id="test_antibacterial",
    )

    print(f"测试分子: {test_molecule.chem_id}")
    print(f"SMILES: {test_molecule.smiles}")

    # Run the prediction with strain-level output enabled.
    print("\n开始预测(菌株级别模式)...")
    results = predictor.predict_batch(
        [test_molecule], include_strain_predictions=True
    )
    result = results[0]

    # 1. Aggregated metrics.
    _print_aggregate_summary(result)

    # 2-7. Strain-level data, when present.
    print("\n菌株级别预测数据:")
    strain_df = result.strain_predictions
    if strain_df is not None:
        _print_strain_table(strain_df)
        _print_strain_statistics(strain_df)
        _print_rl_examples(strain_df)
        _print_strain_list(result)
    else:
        print(" 警告: strain_predictions 为 None(未启用菌株级别预测)")

    return result


def main() -> None:
    """Run both demos in order and print a closing summary."""
    print("\n" + "🧪" * 35)
    print(" MolE 抗菌活性预测 Python API 示例")
    print("🧪" * 35)

    # Demo 1: aggregated results.
    demo_aggregated_mode()

    # Demo 2: strain-level predictions.
    demo_strain_level_mode()

    # Closing summary.
    print_separator("总结")
    print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
    print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
    print(" - 包含 8 个聚合指标")
    print(" - strain_predictions = None")
    print()
    print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
    print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
    print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
    print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
    print(" - 可直接提取为 numpy array 用于 RL")
    print()
    print("🎯 推荐使用场景:")
    print(" - 初筛: 模式 1")
    print(" - 详细分析/RL 训练: 模式 2")
    print()


if __name__ == "__main__":
    main()