Files
mole_broad_spectrum_parallel/example_usage.py
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

120 lines
3.4 KiB
Python

"""
广谱抗菌预测API使用示例
"""
from typing import List
from broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult,
predict_smiles,
predict_file
)
def example_single_prediction():
    """Single-molecule prediction example.

    Builds a predictor with its default configuration (no PredictionConfig
    passed), predicts one molecule (ethanol, SMILES ``CCO``) via
    ``predict_single`` and prints the result's chem_id, broad-spectrum flag,
    ``ginhib_total`` (number of inhibited strains) and ``apscore_total``.
    """
    print("=== 单分子预测示例 ===")
    # Predictor with default settings.
    predictor = ParallelBroadSpectrumPredictor()
    # Predict a single molecule.
    molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
    result = predictor.predict_single(molecule)
    print(f"化合物: {result.chem_id}")
    # Both branches used to be empty strings, so the verdict was never shown;
    # print an explicit yes (是) / no (否) instead.
    print(f"广谱抗菌: {'是' if result.broad_spectrum else '否'}")
    print(f"抑制菌株数: {result.ginhib_total}")
    print(f"抗菌分数: {result.apscore_total:.3f}")
def example_batch_prediction():
    """Batch prediction example.

    Configures a predictor with 4 workers and a batch size of 50, predicts
    three small molecules in a single ``predict_batch`` call and prints one
    summary line (chem_id, broad-spectrum flag, ginhib_total) per result.
    """
    print("\n=== 批量预测示例 ===")
    predictor = ParallelBroadSpectrumPredictor(
        PredictionConfig(n_workers=4, batch_size=50)
    )
    # (SMILES, chem_id) pairs to predict together.
    inputs = [
        MoleculeInput(smiles=smi, chem_id=name)
        for smi, name in (
            ("CCO", "ethanol"),
            ("CCN", "ethylamine"),
            ("CC(=O)O", "acetic_acid"),
        )
    ]
    for res in predictor.predict_batch(inputs):
        print(f"{res.chem_id}: 广谱={res.broad_spectrum}, 抑制数={res.ginhib_total}")
def example_smiles_list_prediction():
    """SMILES-list prediction example using the ``predict_smiles`` helper.

    Predicts three SMILES strings with matching chem IDs, then prints how many
    of the results are flagged broad-spectrum out of the total returned.
    """
    print("\n=== SMILES列表预测示例 ===")
    pairs = [("CCO", "ethanol"), ("CCN", "ethylamine"), ("CC(=O)O", "acetic_acid")]
    smiles_list = [smi for smi, _ in pairs]
    chem_ids = [cid for _, cid in pairs]
    predictions = predict_smiles(smiles_list, chem_ids)
    # Count results whose broad_spectrum flag is truthy.
    hits = len([p for p in predictions if p.broad_spectrum])
    print(f"广谱抗菌化合物: {hits}/{len(predictions)}")
def example_file_prediction(
    input_path="molecules.tsv",
    output_path="broad_spectrum_results.csv",
):
    """File-based prediction example.

    Runs ``predict_file`` on ``input_path``, reading SMILES from the
    ``smiles`` column and identifiers from the ``compound_id`` column. Each
    result's ``to_dict()`` is then written as one row of a CSV at
    ``output_path``, without an index column.

    Args:
        input_path: Input file passed to ``predict_file``. Presumably
            tab-separated, given the default ``.tsv`` name; confirm against
            ``predict_file``. Defaults to ``"molecules.tsv"``.
        output_path: Destination CSV path. Defaults to
            ``"broad_spectrum_results.csv"``.

    If ``predict_file`` raises ``FileNotFoundError``, a message is printed and
    the example is skipped. Errors while writing the output are no longer
    misreported as a missing input file.
    """
    print("\n=== 文件预测示例 ===")
    # Only the read step is guarded: a missing input file is the expected,
    # skippable case.
    try:
        results = predict_file(
            input_path,
            smiles_column="smiles",
            id_column="compound_id",
        )
    except FileNotFoundError:
        print("输入文件不存在,跳过文件预测示例")
        return

    # Local import kept from the original; presumably so pandas is only
    # required when this example actually writes output.
    import pandas as pd

    results_df = pd.DataFrame([r.to_dict() for r in results])
    results_df.to_csv(output_path, index=False)
    print(f"预测完成,结果保存到 {output_path}")
def example_custom_config():
    """Custom-configuration example.

    Overrides ``app_threshold``, ``min_nkill``, ``n_workers`` and
    ``batch_size`` in a PredictionConfig, predicts a single molecule
    (ethanol) with ``predict_batch`` and prints the first result as a dict.
    """
    print("\n=== 自定义配置示例 ===")
    custom = PredictionConfig(
        app_threshold=0.1,  # stricter inhibition threshold
        min_nkill=15,  # higher bar for calling a compound broad-spectrum
        n_workers=8,  # more parallel worker processes
        batch_size=200,  # larger batches
    )
    predictor = ParallelBroadSpectrumPredictor(custom)
    batch = [MoleculeInput(smiles="CCO", chem_id="ethanol")]
    first = predictor.predict_batch(batch)[0]
    print(f"使用自定义配置预测结果: {first.to_dict()}")
if __name__ == "__main__":
    # Run every example in sequence.
    for _run_example in (
        example_single_prediction,
        example_batch_prediction,
        example_smiles_list_prediction,
        example_file_prediction,
        example_custom_config,
    ):
        _run_example()