Files
SIME/test/mole_predict_singal_mole.py
hotwa 34102cf459 1. 代码修改
models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
2025-10-17 16:46:04 +08:00

209 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测 Python API 示例
演示两种预测模式:
1. 聚合结果模式(默认)
2. 菌株级别预测模式
"""
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径(使用 pathlib
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from models import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def print_separator(title):
"""打印分隔线"""
print("\n" + "=" * 70)
print(f" {title}")
print("=" * 70 + "\n")
def demo_aggregated_mode():
"""演示聚合结果模式(默认)"""
print_separator("模式 1: 聚合结果模式(默认)")
# 创建配置
config = PredictionConfig(
batch_size=10,
device="auto" # 自动检测 CUDA
)
# 创建预测器
predictor = ParallelBroadSpectrumPredictor(config)
# 准备测试分子
test_molecules = [
MoleculeInput(smiles="CCO", chem_id="ethanol"),
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
]
print(f"测试分子数: {len(test_molecules)}")
print(f"SMILES 示例: {test_molecules[0].smiles}")
# 执行预测(不包含菌株级别预测)
print("\n开始预测(聚合模式)...")
results = predictor.predict_batch(test_molecules, include_strain_predictions=False)
# 打印结果
print(f"\n预测完成!共 {len(results)} 个结果\n")
for result in results:
print(f"化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None
print()
# 返回结果供后续使用
return results
def demo_strain_level_mode():
"""演示菌株级别预测模式"""
print_separator("模式 2: 菌株级别预测模式")
# 创建预测器
predictor = ParallelBroadSpectrumPredictor()
# 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物)
test_molecule = MoleculeInput(
smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
chem_id="test_antibacterial"
)
print(f"测试分子: {test_molecule.chem_id}")
print(f"SMILES: {test_molecule.smiles}")
# 执行预测(包含菌株级别预测)
print("\n开始预测(菌株级别模式)...")
results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
result = results[0]
# 1. 打印聚合结果
print(f"\n聚合结果:")
print(f" - 化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
# 2. 打印菌株级别预测数据结构
print(f"\n菌株级别预测数据:")
if result.strain_predictions is not None:
strain_df = result.strain_predictions
print(f" - 数据类型: {type(strain_df)}")
print(f" - 数据形状: {strain_df.shape} (行, 列)")
print(f" - 列名: {list(strain_df.columns)}")
print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")
# 3. 展示前 5 个菌株的预测
print(f"\n前 5 个菌株的预测:")
print(strain_df.head(5).to_string(index=False))
# 4. 统计信息
print(f"\n统计信息:")
print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, "
f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]")
print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}")
print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}")
print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}")
# 5. 展示被抑制的菌株(如果有)
inhibited = strain_df[strain_df['growth_inhibition'] == 1]
if len(inhibited) > 0:
print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False))
else:
print(f"\n该分子未预测抑制任何菌株")
# 6. 强化学习应用示例
print(f"\n强化学习应用示例:")
# 提取预测概率作为状态向量
state_vector = strain_df['antimicrobial_predictive_probability'].values
print(f" - 状态向量形状: {state_vector.shape}")
print(f" - 状态向量类型: {type(state_vector)}")
print(f" - 前 10 个值: {state_vector[:10]}")
# 提取多维特征
state_features = strain_df[[
'antimicrobial_predictive_probability',
'growth_inhibition'
]].values
print(f" - 多维特征形状: {state_features.shape}")
# 按革兰染色分组
gram_negative_probs = strain_df[
strain_df['gram_stain'] == 'negative'
]['antimicrobial_predictive_probability'].values
print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")
# 7. 转换为类型安全的列表(可选)
print(f"\n转换为 StrainPrediction 列表:")
strain_list = result.to_strain_predictions_list()
print(f" - 列表长度: {len(strain_list)}")
print(f" - 元素类型: {type(strain_list[0])}")
print(f" - 第一个元素:")
first_strain = strain_list[0]
print(f" * pred_id: {first_strain.pred_id}")
print(f" * strain_name: {first_strain.strain_name}")
print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
print(f" * growth_inhibition: {first_strain.growth_inhibition}")
print(f" * gram_stain: {first_strain.gram_stain}")
else:
print(" 警告: strain_predictions 为 None未启用菌株级别预测")
return result
def main():
"""主函数"""
print("\n" + "🧪" * 35)
print(" MolE 抗菌活性预测 Python API 示例")
print("🧪" * 35)
# 演示模式 1: 聚合结果
aggregated_results = demo_aggregated_mode()
# 演示模式 2: 菌株级别预测
strain_level_result = demo_strain_level_mode()
# 总结
print_separator("总结")
print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标")
print(" - strain_predictions = None")
print()
print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
print(" - 可直接提取为 numpy array 用于 RL")
print()
print("🎯 推荐使用场景:")
print(" - 初筛: 模式 1")
print(" - 详细分析/RL 训练: 模式 2")
print()
if __name__ == "__main__":
main()