Files
mole_broad_spectrum_parallel/detailed_analysis.py
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

132 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
详细分析广谱抗菌预测的计算过程
"""
import numpy as np
import pandas as pd
from scipy.stats.mstats import gmean
from broad_spectrum_api import ParallelBroadSpectrumPredictor, MoleculeInput
def analyze_prediction_process():
    """Print a step-by-step walkthrough of one broad-spectrum prediction.

    A single test molecule (ethanol, SMILES ``CCO``) is pushed through the
    predictor's stages by hand so every intermediate quantity can be
    inspected: strain panel statistics, the molecular representation, the
    per-strain feature matrix, raw model probabilities, the thresholded
    inhibition calls, the geometric-mean score and, finally, the result
    returned by ``predictor.predict_single`` for the same molecule.

    The function takes no arguments, returns ``None`` and reports
    everything through ``print``.
    """
    print("=== 广谱抗菌预测详细分析 ===\n")

    # Build the predictor and the molecule under inspection.
    predictor = ParallelBroadSpectrumPredictor()
    molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")

    config = predictor.config
    strain_count = len(predictor.maier_screen.columns)
    gram_stain = predictor.maier_strains["Gram stain"]

    # 1. Basic information about the strain panel and decision thresholds.
    print("1. 基本信息:")
    print(f" - 总菌株数量: {strain_count}")
    print(f" - 革兰阳性菌: {(gram_stain == 'positive').sum()}")
    print(f" - 革兰阴性菌: {(gram_stain == 'negative').sum()}")
    print(f" - 抑制阈值: {config.app_threshold}")
    print(f" - 广谱标准: 抑制≥{config.min_nkill}个菌株")

    # 2. Molecular representation (calls a private predictor helper).
    print("\n2. 获取分子表示...")
    mole_representation = predictor._get_mole_representation([molecule])
    print(f" - MolE特征维度: {mole_representation.shape}")

    # 3. Feature matrix fed to the classifier.
    # NOTE(review): presumably one row per (molecule, strain) pair, as the
    # printed message suggests - confirm against `_add_strains`.
    print("\n3. 构建预测特征...")
    X_input = predictor._add_strains(mole_representation)
    print(f" - 预测样本数: {len(X_input)} (1个分子 × {strain_count}个菌株)")
    print(f" - 特征维度: {X_input.shape[1]}")

    # 4. Raw model prediction.
    print("\n4. 模型预测...")
    import pickle

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only point `xgboost_model_path` at a trusted model artefact.
    with open(config.xgboost_model_path, "rb") as file:
        model = pickle.load(file)

    y_pred = model.predict_proba(X_input)
    pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
    inhibition_prob = pred_df["1"]  # probability of class "1"

    # Summary statistics of the class-"1" probabilities.
    print(f" - 抑制概率范围: {inhibition_prob.min():.6f} - {inhibition_prob.max():.6f}")
    print(f" - 抑制概率均值: {inhibition_prob.mean():.6f}")
    print(f" - 抑制概率中位数: {inhibition_prob.median():.6f}")

    # Binarise with the configured threshold (vectorised instead of a
    # row-by-row lambda; the resulting 0/1 column is the same).
    pred_df["growth_inhibition"] = (
        inhibition_prob >= config.app_threshold
    ).astype(int)
    inhibited_count = pred_df["growth_inhibition"].sum()
    print(f" - 超过阈值的菌株数: {inhibited_count}")

    # 5. Score computation: geometric mean of all probabilities and its log.
    print("\n5. 抗菌分数计算:")
    geometric_mean = gmean(inhibition_prob)
    log_geometric_mean = np.log(geometric_mean)
    print(f" - 所有概率的几何平均数: {geometric_mean:.10f}")
    print(f" - 几何平均数的对数: {log_geometric_mean:.6f}")

    # 6. Histogram-style breakdown of the probabilities by decade.
    print("\n6. 概率分布分析:")
    print(f" - 概率 < 0.001: {(inhibition_prob < 0.001).sum()} 个菌株")
    print(f" - 概率 0.001-0.01: {((inhibition_prob >= 0.001) & (inhibition_prob < 0.01)).sum()} 个菌株")
    print(f" - 概率 0.01-0.1: {((inhibition_prob >= 0.01) & (inhibition_prob < 0.1)).sum()} 个菌株")
    print(f" - 概率 ≥ 0.1: {(inhibition_prob >= 0.1).sum()} 个菌株")

    # 7. The five highest and five lowest predictions. The text after the
    # first ":" of each index label is shown as the strain name.
    print("\n7. 预测详情 (前5高和前5低):")
    sorted_pred = pred_df.sort_values("1", ascending=False)
    print(" 最高抑制概率:")
    for rank, (idx, row) in enumerate(sorted_pred.head().iterrows(), start=1):
        strain_name = idx.split(":")[1]
        print(f" {rank}. {strain_name}: {row['1']:.6f}")
    print(" 最低抑制概率:")
    for rank, (idx, row) in enumerate(sorted_pred.tail().iterrows(), start=1):
        strain_name = idx.split(":")[1]
        print(f" {rank}. {strain_name}: {row['1']:.6f}")

    # 8. End-to-end result from the public API for comparison.
    print("\n8. 最终结果:")
    result = predictor.predict_single(molecule)
    print(f" - 总抗菌分数: {result.apscore_total:.6f}")
    print(f" - 革兰阴性菌分数: {result.apscore_gnegative:.6f}")
    print(f" - 革兰阳性菌分数: {result.apscore_gpositive:.6f}")
    print(f" - 抑制菌株总数: {result.ginhib_total}")
    print(f" - 抑制革兰阴性菌: {result.ginhib_gnegative}")
    print(f" - 抑制革兰阳性菌: {result.ginhib_gpositive}")
    # Both branches used to be the empty string, so the verdict was never
    # visible; print an explicit yes/no instead.
    print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
def compare_different_molecules():
    """Print a comparison table of predictions for several small molecules.

    Five hard-coded molecules are submitted together through
    ``predictor.predict_batch`` and one table row is printed per returned
    result: identifier, SMILES, total score, number of inhibited strains
    and the broad-spectrum verdict.

    The function takes no arguments, returns ``None`` and reports
    everything through ``print``.
    """
    print("\n\n=== 不同分子对比分析 ===\n")
    predictor = ParallelBroadSpectrumPredictor()

    # A small panel of structurally different molecules.
    molecules = [
        MoleculeInput(smiles="CCO", chem_id="ethanol"),
        MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
        MoleculeInput(smiles="CCN", chem_id="ethylamine"),
        MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
        MoleculeInput(smiles="CC(C)O", chem_id="isopropanol"),
    ]
    results = predictor.predict_batch(molecules)

    # Build the id -> SMILES lookup once instead of scanning the list for
    # every result row.
    smiles_by_id = {mol.chem_id: mol.smiles for mol in molecules}

    separator = "-" * 80
    print("分子对比结果:")
    print(separator)
    print(f"{'分子':<15} {'SMILES':<12} {'抗菌分数':<10} {'抑制数':<8} {'广谱':<6}")
    print(separator)
    for result in results:
        # Fall back to a placeholder rather than aborting the whole table
        # if a result carries an identifier that was not submitted.
        smiles = smiles_by_id.get(result.chem_id, "N/A")
        # Both branches used to be the empty string, leaving the column
        # blank; print an explicit yes/no instead.
        verdict = "是" if result.broad_spectrum else "否"
        print(f"{result.chem_id:<15} {smiles:<12} {result.apscore_total:<10.3f} "
              f"{result.ginhib_total:<8} {verdict:<6}")
if __name__ == "__main__":
    # Script entry point: run the single-molecule walkthrough first, then
    # the multi-molecule comparison table.
    for run_report in (analyze_prediction_process, compare_different_molecules):
        run_report()