models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
209 lines
7.9 KiB
Python
209 lines
7.9 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
MolE 抗菌活性预测 Python API 示例
|
||
|
||
演示两种预测模式:
|
||
1. 聚合结果模式(默认)
|
||
2. 菌株级别预测模式
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到 Python 路径(使用 pathlib)
|
||
project_root = Path(__file__).resolve().parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from models import (
|
||
ParallelBroadSpectrumPredictor,
|
||
PredictionConfig,
|
||
MoleculeInput
|
||
)
|
||
|
||
|
||
def print_separator(title):
|
||
"""打印分隔线"""
|
||
print("\n" + "=" * 70)
|
||
print(f" {title}")
|
||
print("=" * 70 + "\n")
|
||
|
||
|
||
def demo_aggregated_mode():
|
||
"""演示聚合结果模式(默认)"""
|
||
print_separator("模式 1: 聚合结果模式(默认)")
|
||
|
||
# 创建配置
|
||
config = PredictionConfig(
|
||
batch_size=10,
|
||
device="auto" # 自动检测 CUDA
|
||
)
|
||
|
||
# 创建预测器
|
||
predictor = ParallelBroadSpectrumPredictor(config)
|
||
|
||
# 准备测试分子
|
||
test_molecules = [
|
||
MoleculeInput(smiles="CCO", chem_id="ethanol"),
|
||
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
|
||
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
|
||
]
|
||
|
||
print(f"测试分子数: {len(test_molecules)}")
|
||
print(f"SMILES 示例: {test_molecules[0].smiles}")
|
||
|
||
# 执行预测(不包含菌株级别预测)
|
||
print("\n开始预测(聚合模式)...")
|
||
results = predictor.predict_batch(test_molecules, include_strain_predictions=False)
|
||
|
||
# 打印结果
|
||
print(f"\n预测完成!共 {len(results)} 个结果\n")
|
||
|
||
for result in results:
|
||
print(f"化合物ID: {result.chem_id}")
|
||
print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
|
||
print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
|
||
print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
|
||
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
|
||
print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
|
||
print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
|
||
print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None
|
||
print()
|
||
|
||
# 返回结果供后续使用
|
||
return results
|
||
|
||
|
||
def demo_strain_level_mode():
|
||
"""演示菌株级别预测模式"""
|
||
print_separator("模式 2: 菌株级别预测模式")
|
||
|
||
# 创建预测器
|
||
predictor = ParallelBroadSpectrumPredictor()
|
||
|
||
# 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物)
|
||
test_molecule = MoleculeInput(
|
||
smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
|
||
chem_id="test_antibacterial"
|
||
)
|
||
|
||
print(f"测试分子: {test_molecule.chem_id}")
|
||
print(f"SMILES: {test_molecule.smiles}")
|
||
|
||
# 执行预测(包含菌株级别预测)
|
||
print("\n开始预测(菌株级别模式)...")
|
||
results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
|
||
result = results[0]
|
||
|
||
# 1. 打印聚合结果
|
||
print(f"\n聚合结果:")
|
||
print(f" - 化合物ID: {result.chem_id}")
|
||
print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
|
||
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
|
||
|
||
# 2. 打印菌株级别预测数据结构
|
||
print(f"\n菌株级别预测数据:")
|
||
if result.strain_predictions is not None:
|
||
strain_df = result.strain_predictions
|
||
print(f" - 数据类型: {type(strain_df)}")
|
||
print(f" - 数据形状: {strain_df.shape} (行, 列)")
|
||
print(f" - 列名: {list(strain_df.columns)}")
|
||
print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")
|
||
|
||
# 3. 展示前 5 个菌株的预测
|
||
print(f"\n前 5 个菌株的预测:")
|
||
print(strain_df.head(5).to_string(index=False))
|
||
|
||
# 4. 统计信息
|
||
print(f"\n统计信息:")
|
||
print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, "
|
||
f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]")
|
||
print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}")
|
||
print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
|
||
print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}")
|
||
print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}")
|
||
|
||
# 5. 展示被抑制的菌株(如果有)
|
||
inhibited = strain_df[strain_df['growth_inhibition'] == 1]
|
||
if len(inhibited) > 0:
|
||
print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
|
||
print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False))
|
||
else:
|
||
print(f"\n该分子未预测抑制任何菌株")
|
||
|
||
# 6. 强化学习应用示例
|
||
print(f"\n强化学习应用示例:")
|
||
|
||
# 提取预测概率作为状态向量
|
||
state_vector = strain_df['antimicrobial_predictive_probability'].values
|
||
print(f" - 状态向量形状: {state_vector.shape}")
|
||
print(f" - 状态向量类型: {type(state_vector)}")
|
||
print(f" - 前 10 个值: {state_vector[:10]}")
|
||
|
||
# 提取多维特征
|
||
state_features = strain_df[[
|
||
'antimicrobial_predictive_probability',
|
||
'growth_inhibition'
|
||
]].values
|
||
print(f" - 多维特征形状: {state_features.shape}")
|
||
|
||
# 按革兰染色分组
|
||
gram_negative_probs = strain_df[
|
||
strain_df['gram_stain'] == 'negative'
|
||
]['antimicrobial_predictive_probability'].values
|
||
print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")
|
||
|
||
# 7. 转换为类型安全的列表(可选)
|
||
print(f"\n转换为 StrainPrediction 列表:")
|
||
strain_list = result.to_strain_predictions_list()
|
||
print(f" - 列表长度: {len(strain_list)}")
|
||
print(f" - 元素类型: {type(strain_list[0])}")
|
||
print(f" - 第一个元素:")
|
||
first_strain = strain_list[0]
|
||
print(f" * pred_id: {first_strain.pred_id}")
|
||
print(f" * strain_name: {first_strain.strain_name}")
|
||
print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
|
||
print(f" * growth_inhibition: {first_strain.growth_inhibition}")
|
||
print(f" * gram_stain: {first_strain.gram_stain}")
|
||
else:
|
||
print(" 警告: strain_predictions 为 None(未启用菌株级别预测)")
|
||
|
||
return result
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("\n" + "🧪" * 35)
|
||
print(" MolE 抗菌活性预测 Python API 示例")
|
||
print("🧪" * 35)
|
||
|
||
# 演示模式 1: 聚合结果
|
||
aggregated_results = demo_aggregated_mode()
|
||
|
||
# 演示模式 2: 菌株级别预测
|
||
strain_level_result = demo_strain_level_mode()
|
||
|
||
# 总结
|
||
print_separator("总结")
|
||
print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
|
||
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
|
||
print(" - 包含 8 个聚合指标")
|
||
print(" - strain_predictions = None")
|
||
print()
|
||
print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
|
||
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
|
||
print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
|
||
print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
|
||
print(" - 可直接提取为 numpy array 用于 RL")
|
||
print()
|
||
print("🎯 推荐使用场景:")
|
||
print(" - 初筛: 模式 1")
|
||
print(" - 详细分析/RL 训练: 模式 2")
|
||
print()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|