1. 代码修改

models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
This commit is contained in:
2025-10-17 16:46:04 +08:00
parent 62e0f3d6aa
commit 34102cf459
5 changed files with 716 additions and 21 deletions

View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测 Python API 示例
演示两种预测模式:
1. 聚合结果模式(默认)
2. 菌株级别预测模式
"""
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径(使用 pathlib
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from models import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def print_separator(title):
"""打印分隔线"""
print("\n" + "=" * 70)
print(f" {title}")
print("=" * 70 + "\n")
def demo_aggregated_mode():
"""演示聚合结果模式(默认)"""
print_separator("模式 1: 聚合结果模式(默认)")
# 创建配置
config = PredictionConfig(
batch_size=10,
device="auto" # 自动检测 CUDA
)
# 创建预测器
predictor = ParallelBroadSpectrumPredictor(config)
# 准备测试分子
test_molecules = [
MoleculeInput(smiles="CCO", chem_id="ethanol"),
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
]
print(f"测试分子数: {len(test_molecules)}")
print(f"SMILES 示例: {test_molecules[0].smiles}")
# 执行预测(不包含菌株级别预测)
print("\n开始预测(聚合模式)...")
results = predictor.predict_batch(test_molecules, include_strain_predictions=False)
# 打印结果
print(f"\n预测完成!共 {len(results)} 个结果\n")
for result in results:
print(f"化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None
print()
# 返回结果供后续使用
return results
def demo_strain_level_mode():
"""演示菌株级别预测模式"""
print_separator("模式 2: 菌株级别预测模式")
# 创建预测器
predictor = ParallelBroadSpectrumPredictor()
# 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物)
test_molecule = MoleculeInput(
smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
chem_id="test_antibacterial"
)
print(f"测试分子: {test_molecule.chem_id}")
print(f"SMILES: {test_molecule.smiles}")
# 执行预测(包含菌株级别预测)
print("\n开始预测(菌株级别模式)...")
results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
result = results[0]
# 1. 打印聚合结果
print(f"\n聚合结果:")
print(f" - 化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
# 2. 打印菌株级别预测数据结构
print(f"\n菌株级别预测数据:")
if result.strain_predictions is not None:
strain_df = result.strain_predictions
print(f" - 数据类型: {type(strain_df)}")
print(f" - 数据形状: {strain_df.shape} (行, 列)")
print(f" - 列名: {list(strain_df.columns)}")
print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")
# 3. 展示前 5 个菌株的预测
print(f"\n前 5 个菌株的预测:")
print(strain_df.head(5).to_string(index=False))
# 4. 统计信息
print(f"\n统计信息:")
print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, "
f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]")
print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}")
print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}")
print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}")
# 5. 展示被抑制的菌株(如果有)
inhibited = strain_df[strain_df['growth_inhibition'] == 1]
if len(inhibited) > 0:
print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False))
else:
print(f"\n该分子未预测抑制任何菌株")
# 6. 强化学习应用示例
print(f"\n强化学习应用示例:")
# 提取预测概率作为状态向量
state_vector = strain_df['antimicrobial_predictive_probability'].values
print(f" - 状态向量形状: {state_vector.shape}")
print(f" - 状态向量类型: {type(state_vector)}")
print(f" - 前 10 个值: {state_vector[:10]}")
# 提取多维特征
state_features = strain_df[[
'antimicrobial_predictive_probability',
'growth_inhibition'
]].values
print(f" - 多维特征形状: {state_features.shape}")
# 按革兰染色分组
gram_negative_probs = strain_df[
strain_df['gram_stain'] == 'negative'
]['antimicrobial_predictive_probability'].values
print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")
# 7. 转换为类型安全的列表(可选)
print(f"\n转换为 StrainPrediction 列表:")
strain_list = result.to_strain_predictions_list()
print(f" - 列表长度: {len(strain_list)}")
print(f" - 元素类型: {type(strain_list[0])}")
print(f" - 第一个元素:")
first_strain = strain_list[0]
print(f" * pred_id: {first_strain.pred_id}")
print(f" * strain_name: {first_strain.strain_name}")
print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
print(f" * growth_inhibition: {first_strain.growth_inhibition}")
print(f" * gram_stain: {first_strain.gram_stain}")
else:
print(" 警告: strain_predictions 为 None未启用菌株级别预测")
return result
def main():
"""主函数"""
print("\n" + "🧪" * 35)
print(" MolE 抗菌活性预测 Python API 示例")
print("🧪" * 35)
# 演示模式 1: 聚合结果
aggregated_results = demo_aggregated_mode()
# 演示模式 2: 菌株级别预测
strain_level_result = demo_strain_level_mode()
# 总结
print_separator("总结")
print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标")
print(" - strain_predictions = None")
print()
print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
print(" - 可直接提取为 numpy array 用于 RL")
print()
print("🎯 推荐使用场景:")
print(" - 初筛: 模式 1")
print(" - 详细分析/RL 训练: 模式 2")
print()
if __name__ == "__main__":
main()