132 lines
5.3 KiB
Python
132 lines
5.3 KiB
Python
"""
Detailed walkthrough of the broad-spectrum antibacterial prediction computation.
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from scipy.stats.mstats import gmean
|
||
from broad_spectrum_api import ParallelBroadSpectrumPredictor, MoleculeInput
|
||
|
||
def analyze_prediction_process():
    """Print a step-by-step trace of one broad-spectrum prediction.

    Builds a ``ParallelBroadSpectrumPredictor``, runs a single test molecule
    (ethanol, SMILES ``CCO``) through the predictor's representation and
    strain-feature helpers, loads the pickled model referenced by the
    predictor's config, and prints intermediate statistics for every stage.
    Finally calls ``predictor.predict_single`` and prints its summary fields.

    Takes no arguments and returns ``None``; all output goes to stdout.
    """
    import pickle

    print("=== 广谱抗菌预测详细分析 ===\n")

    # Predictor under inspection and the molecule used for the walkthrough.
    predictor = ParallelBroadSpectrumPredictor()
    molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")

    # --- 1. Basic information about the strain panel and thresholds ---
    gram_stain = predictor.maier_strains["Gram stain"]
    n_positive = (gram_stain == "positive").sum()
    n_negative = (gram_stain == "negative").sum()

    print("1. 基本信息:")
    print(f" - 总菌株数量: {len(predictor.maier_screen.columns)}")
    print(f" - 革兰阳性菌: {n_positive}")
    print(f" - 革兰阴性菌: {n_negative}")
    print(f" - 抑制阈值: {predictor.config.app_threshold}")
    print(f" - 广谱标准: 抑制≥{predictor.config.min_nkill}个菌株")

    # --- 2. Molecular representation ---
    print("\n2. 获取分子表示...")
    mole_representation = predictor._get_mole_representation([molecule])
    print(f" - MolE特征维度: {mole_representation.shape}")

    # --- 3. Combine the representation with strain information ---
    print("\n3. 构建预测特征...")
    X_input = predictor._add_strains(mole_representation)
    n_strains = len(predictor.maier_screen.columns)
    print(f" - 预测样本数: {len(X_input)} (1个分子 × {n_strains}个菌株)")
    print(f" - 特征维度: {X_input.shape[1]}")

    # --- 4. Model prediction ---
    print("\n4. 模型预测...")
    # NOTE(review): pickle.load can execute arbitrary code; the model path
    # comes from the predictor's own config and is assumed to be trusted.
    with open(predictor.config.xgboost_model_path, "rb") as model_file:
        model = pickle.load(model_file)

    class_probabilities = model.predict_proba(X_input)
    pred_df = pd.DataFrame(
        class_probabilities, columns=["0", "1"], index=X_input.index
    )
    # Column "1" is what the rest of this report treats as the inhibition
    # probability.
    probs = pred_df["1"]

    lowest, highest = probs.min(), probs.max()
    print(f" - 抑制概率范围: {lowest:.6f} - {highest:.6f}")
    print(f" - 抑制概率均值: {probs.mean():.6f}")
    print(f" - 抑制概率中位数: {probs.median():.6f}")

    # Binarise element-wise against the configured threshold.
    def _is_inhibited(probability):
        if probability >= predictor.config.app_threshold:
            return 1
        return 0

    pred_df["growth_inhibition"] = probs.apply(_is_inhibited)
    inhibited_count = pred_df["growth_inhibition"].sum()
    print(f" - 超过阈值的菌株数: {inhibited_count}")

    # --- 5. Antimicrobial score: log of the geometric mean probability ---
    print("\n5. 抗菌分数计算:")
    geometric_mean = gmean(pred_df["1"])
    log_geometric_mean = np.log(geometric_mean)
    print(f" - 所有概率的几何平均数: {geometric_mean:.10f}")
    print(f" - 几何平均数的对数: {log_geometric_mean:.6f}")

    # --- 6. Distribution of probabilities across decade-wide buckets ---
    below_1e3 = (probs < 0.001).sum()
    from_1e3_to_1e2 = ((probs >= 0.001) & (probs < 0.01)).sum()
    from_1e2_to_1e1 = ((probs >= 0.01) & (probs < 0.1)).sum()
    at_least_1e1 = (probs >= 0.1).sum()

    print("\n6. 概率分布分析:")
    print(f" - 概率 < 0.001: {below_1e3} 个菌株")
    print(f" - 概率 0.001-0.01: {from_1e3_to_1e2} 个菌株")
    print(f" - 概率 0.01-0.1: {from_1e2_to_1e1} 个菌株")
    print(f" - 概率 ≥ 0.1: {at_least_1e1} 个菌株")

    # --- 7. Five highest and five lowest predictions ---
    print("\n7. 预测详情 (前5高和前5低):")
    ranked = pred_df.sort_values("1", ascending=False)

    def _print_ranked(rows):
        # The strain name is the second ":"-separated field of the row label.
        for rank, (label, probability) in enumerate(rows["1"].items(), start=1):
            strain_name = label.split(":")[1]
            print(f" {rank}. {strain_name}: {probability:.6f}")

    print(" 最高抑制概率:")
    _print_ranked(ranked.head())

    print(" 最低抑制概率:")
    _print_ranked(ranked.tail())

    # --- 8. Final result from the predictor's own end-to-end entry point ---
    print("\n8. 最终结果:")
    result = predictor.predict_single(molecule)
    broad_label = "是" if result.broad_spectrum else "否"
    print(f" - 总抗菌分数: {result.apscore_total:.6f}")
    print(f" - 革兰阴性菌分数: {result.apscore_gnegative:.6f}")
    print(f" - 革兰阳性菌分数: {result.apscore_gpositive:.6f}")
    print(f" - 抑制菌株总数: {result.ginhib_total}")
    print(f" - 抑制革兰阴性菌: {result.ginhib_gnegative}")
    print(f" - 抑制革兰阳性菌: {result.ginhib_gpositive}")
    print(f" - 广谱抗菌: {broad_label}")
|
||
|
||
def compare_different_molecules():
    """Print a comparison table of batch predictions for five small molecules.

    Builds a ``ParallelBroadSpectrumPredictor``, runs ``predict_batch`` on a
    fixed list of molecules, and prints one table row per returned result:
    chem id, SMILES, total score, inhibited-strain count, and whether the
    result is flagged broad-spectrum.

    Takes no arguments and returns ``None``; all output goes to stdout.
    """
    print("\n\n=== 不同分子对比分析 ===\n")

    predictor = ParallelBroadSpectrumPredictor()

    # (SMILES, chem_id) pairs for the molecules being compared.
    candidates = (
        ("CCO", "ethanol"),
        ("CC(=O)O", "acetic_acid"),
        ("CCN", "ethylamine"),
        ("c1ccccc1", "benzene"),
        ("CC(C)O", "isopropanol"),
    )
    molecules = [
        MoleculeInput(smiles=smiles, chem_id=chem_id)
        for smiles, chem_id in candidates
    ]

    results = predictor.predict_batch(molecules)

    rule = "-" * 80
    print("分子对比结果:")
    print(rule)
    print(f"{'分子':<15} {'SMILES':<12} {'抗菌分数':<10} {'抑制数':<8} {'广谱':<6}")
    print(rule)

    for outcome in results:
        # Results carry only the chem id, so recover the SMILES from the
        # first input molecule with a matching id.
        source = next(m for m in molecules if m.chem_id == outcome.chem_id)
        broad_label = "是" if outcome.broad_spectrum else "否"
        print(
            f"{outcome.chem_id:<15} {source.smiles:<12} "
            f"{outcome.apscore_total:<10.3f} "
            f"{outcome.ginhib_total:<8} {broad_label:<6}"
        )
|
||
|
||
def _main():
    """Run the single-molecule walkthrough, then the multi-molecule comparison."""
    analyze_prediction_process()
    compare_different_molecules()


if __name__ == "__main__":
    _main()
|