1. 代码修改
models/broad_spectrum_predictor.py:
✅ 新增 StrainPrediction dataclass(单个菌株预测结果)
✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
✅ 添加 to_strain_predictions_list() 方法(类型安全转换)
✅ 新增 _prepare_strain_level_predictions() 方法
✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数

utils/mole_predictor.py:
✅ 添加 include_strain_predictions 参数到所有函数
✅ 添加命令行参数 --include-strain-predictions
✅ 实现菌株级别数据与聚合结果的合并逻辑
✅ 更新所有函数签名和文档字符串

2. 测试验证
✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
✅ 验证输出格式正确性
✅ 验证每个分子都有完整的 40 个菌株预测
✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)

3. 文档更新
README.md:
✅ 更新命令行使用示例
✅ 添加 Python API 使用示例(包含菌株预测)
✅ 添加详细的输出格式说明
✅ 添加 40 种菌株列表概览
✅ 添加数据使用场景示例(强化学习、筛选、可视化)

Data/mole/README.md:
✅ 新增"菌株级别预测详情"章节
✅ 完整的 40 种菌株列表(分革兰阴性/阳性)
✅ 数据访问方式示例(CSV 读取、Python API)
✅ 强化学习应用场景(状态表示、奖励函数设计)
✅ 数据可视化代码示例
✅ 性能和存储建议
This commit is contained in:
208
test/mole_predict_singal_mole.py
Normal file
208
test/mole_predict_singal_mole.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE antimicrobial activity prediction: Python API example.

Demonstrates two prediction modes:
1. Aggregated-result mode (default)
2. Strain-level prediction mode
"""

import sys
from pathlib import Path

# Put the project root (two levels above this file) at the front of sys.path
# so that the ``models`` package imported below can be resolved.
project_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(project_root))

from models import MoleculeInput, ParallelBroadSpectrumPredictor, PredictionConfig
|
||||
|
||||
|
||||
def print_separator(title):
    """Print a section banner: *title* framed by two 70-character ``=`` rules.

    One blank line is written before the top rule and one after the bottom
    rule. The title is prefixed with a single space.
    """
    rule = "=" * 70
    # The empty first and last items produce the leading and trailing blank lines.
    print("", rule, f" {title}", rule, "", sep="\n")
|
||||
|
||||
|
||||
def demo_aggregated_mode():
    """Demonstrate the aggregated-result mode (the default).

    Builds a predictor from an explicit ``PredictionConfig``, calls
    ``predict_batch`` on three small molecules with
    ``include_strain_predictions=False`` and prints the aggregate fields of
    every returned result.

    Returns:
        Whatever ``predict_batch`` returned, so callers can reuse it.
    """
    print_separator("模式 1: 聚合结果模式(默认)")

    # Explicit configuration; per the original author's note, "auto" means
    # CUDA is detected automatically.
    settings = PredictionConfig(batch_size=10, device="auto")
    predictor = ParallelBroadSpectrumPredictor(settings)

    # (SMILES, compound id) pairs used as the demo input.
    samples = (
        ("CCO", "ethanol"),
        ("c1ccccc1", "benzene"),
        ("CC(=O)O", "acetic_acid"),
    )
    test_molecules = [
        MoleculeInput(smiles=smiles, chem_id=chem_id)
        for smiles, chem_id in samples
    ]

    print(f"测试分子数: {len(test_molecules)}")
    print(f"SMILES 示例: {test_molecules[0].smiles}")

    # Run the prediction without the per-strain table.
    print("\n开始预测(聚合模式)...")
    results = predictor.predict_batch(
        test_molecules,
        include_strain_predictions=False,
    )

    print(f"\n预测完成!共 {len(results)} 个结果\n")

    for item in results:
        report = [
            f"化合物ID: {item.chem_id}",
            f" - 广谱抗菌: {'是' if item.broad_spectrum else '否'}",
            f" - 总体抗菌得分: {item.apscore_total:.4f}",
            f" - 革兰阴性菌得分: {item.apscore_gnegative:.4f}",
            f" - 革兰阳性菌得分: {item.apscore_gpositive:.4f}",
            f" - 抑制菌株总数: {item.ginhib_total} / 40",
            f" - 抑制革兰阴性菌数: {item.ginhib_gnegative}",
            f" - 抑制革兰阳性菌数: {item.ginhib_gpositive}",
            # Expected to show None, since strain predictions were not requested.
            f" - strain_predictions: {item.strain_predictions}",
            "",  # blank line between compounds
        ]
        print("\n".join(report))

    return results
|
||||
|
||||
|
||||
def _print_strain_table_overview(strain_df):
    """Print the structure, first five rows and summary statistics of *strain_df*."""
    print(f" - 数据类型: {type(strain_df)}")
    print(f" - 数据形状: {strain_df.shape} (行, 列)")
    print(f" - 列名: {list(strain_df.columns)}")
    print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")

    print("\n前 5 个菌株的预测:")
    print(strain_df.head(5).to_string(index=False))

    print("\n统计信息:")
    # Look the probability column up once instead of once per statistic.
    probabilities = strain_df['antimicrobial_predictive_probability']
    print(f" - 预测概率范围: [{probabilities.min():.6f}, "
          f"{probabilities.max():.6f}]")
    print(f" - 预测概率平均值: {probabilities.mean():.6f}")
    print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
    gram_stain = strain_df['gram_stain']
    # Summing a boolean mask counts matching rows without copying the frame.
    print(f" - 革兰阴性菌数: {int((gram_stain == 'negative').sum())}")
    print(f" - 革兰阳性菌数: {int((gram_stain == 'positive').sum())}")


def _print_inhibited_strains(strain_df):
    """Print the rows whose ``growth_inhibition`` equals 1, or a note if none."""
    inhibited = strain_df[strain_df['growth_inhibition'] == 1]
    if inhibited.empty:
        print("\n该分子未预测抑制任何菌株")
        return
    print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
    columns = ['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']
    print(inhibited[columns].to_string(index=False))


def _print_rl_examples(strain_df):
    """Show how the strain table converts to arrays usable as RL state inputs."""
    print("\n强化学习应用示例:")

    # Predicted probabilities as a flat state vector.
    state_vector = strain_df['antimicrobial_predictive_probability'].values
    print(f" - 状态向量形状: {state_vector.shape}")
    print(f" - 状态向量类型: {type(state_vector)}")
    print(f" - 前 10 个值: {state_vector[:10]}")

    # Two features per strain: probability and the inhibition flag.
    state_features = strain_df[[
        'antimicrobial_predictive_probability',
        'growth_inhibition',
    ]].values
    print(f" - 多维特征形状: {state_features.shape}")

    # Probabilities restricted to the Gram-negative rows.
    gram_negative_probs = strain_df.loc[
        strain_df['gram_stain'] == 'negative',
        'antimicrobial_predictive_probability',
    ].values
    print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")


def _print_strain_list_preview(result):
    """Print the length and first element of ``result.to_strain_predictions_list()``."""
    print("\n转换为 StrainPrediction 列表:")
    strain_list = result.to_strain_predictions_list()
    print(f" - 列表长度: {len(strain_list)}")
    if len(strain_list) == 0:
        # Nothing to preview; indexing [0] would raise IndexError.
        return

    first_strain = strain_list[0]
    print(f" - 元素类型: {type(first_strain)}")
    print(" - 第一个元素:")
    print(f" * pred_id: {first_strain.pred_id}")
    print(f" * strain_name: {first_strain.strain_name}")
    print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
    print(f" * growth_inhibition: {first_strain.growth_inhibition}")
    print(f" * gram_stain: {first_strain.gram_stain}")


def demo_strain_level_mode():
    """Demonstrate the strain-level prediction mode.

    Runs ``predict_batch`` on a single molecule with
    ``include_strain_predictions=True``, prints the aggregate fields of the
    first result and, when ``strain_predictions`` is present, walks through
    the per-strain table: structure, statistics, inhibited strains, array
    extraction for reinforcement learning, and conversion to a typed list.

    Returns:
        The first result returned by ``predict_batch``, or ``None`` when the
        call returned no results (a warning is printed in that case instead
        of failing with ``IndexError``).
    """
    print_separator("模式 2: 菌株级别预测模式")

    # Predictor with its default configuration.
    predictor = ParallelBroadSpectrumPredictor()

    # Test molecule; described by the original author as a fluoroquinolone
    # analogue.
    test_molecule = MoleculeInput(
        smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
        chem_id="test_antibacterial",
    )

    print(f"测试分子: {test_molecule.chem_id}")
    print(f"SMILES: {test_molecule.smiles}")

    # Run the prediction including the per-strain table.
    print("\n开始预测(菌株级别模式)...")
    results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
    # len() rather than truthiness: the container type is not visible here.
    if len(results) == 0:
        print(" 警告: 预测未返回任何结果")
        return None
    result = results[0]

    # 1. Aggregate result.
    print("\n聚合结果:")
    print(f" - 化合物ID: {result.chem_id}")
    print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
    print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
    print(f" - 抑制菌株总数: {result.ginhib_total} / 40")

    # 2. Strain-level data; absent when strain predictions were not produced.
    print("\n菌株级别预测数据:")
    strain_df = result.strain_predictions
    if strain_df is None:
        print(" 警告: strain_predictions 为 None(未启用菌株级别预测)")
        return result

    _print_strain_table_overview(strain_df)   # structure, head, statistics
    _print_inhibited_strains(strain_df)       # inhibited strains, if any
    _print_rl_examples(strain_df)             # numpy extraction for RL
    _print_strain_list_preview(result)        # optional typed-list conversion

    return result
|
||||
|
||||
|
||||
def main():
    """Script entry point: print a title, run both demos, then print a summary."""
    flask_row = "🧪" * 35
    print("\n" + flask_row)
    print(" MolE 抗菌活性预测 Python API 示例")
    print(flask_row)

    # Demo 1: aggregated results. Demo 2: strain-level predictions.
    # The return values are not needed here.
    demo_aggregated_mode()
    demo_strain_level_mode()

    # Closing summary; empty strings render as blank lines.
    print_separator("总结")
    summary_lines = (
        "✅ 模式 1 (聚合结果): 适合快速筛选大量分子",
        " - 每个分子返回 1 个 BroadSpectrumResult 对象",
        " - 包含 8 个聚合指标",
        " - strain_predictions = None",
        "",
        "✅ 模式 2 (菌株级别): 适合详细分析和强化学习",
        " - 每个分子返回 1 个 BroadSpectrumResult 对象",
        " - 包含 8 个聚合指标 + 40 行菌株预测数据",
        " - strain_predictions = DataFrame (40 rows × 7 columns)",
        " - 可直接提取为 numpy array 用于 RL",
        "",
        "🎯 推荐使用场景:",
        " - 初筛: 模式 1",
        " - 详细分析/RL 训练: 模式 2",
        "",
    )
    for line in summary_lines:
        print(line)
|
||||
|
||||
|
||||
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user