120 lines
3.4 KiB
Python
120 lines
3.4 KiB
Python
"""
广谱抗菌预测API使用示例
(Usage examples for the broad-spectrum antibacterial prediction API.)
"""
|
|
|
|
from typing import List
|
|
from broad_spectrum_api import (
|
|
ParallelBroadSpectrumPredictor,
|
|
PredictionConfig,
|
|
MoleculeInput,
|
|
BroadSpectrumResult,
|
|
predict_smiles,
|
|
predict_file
|
|
)
|
|
|
|
|
|
def example_single_prediction() -> None:
    """Run a default-configured predictor on one molecule and print the result.

    Predicts ethanol (SMILES ``CCO``) via ``predict_single`` and prints the
    compound id, the broad-spectrum verdict, the inhibited-strain count and
    the score (formatted to three decimals).
    """
    print("=== 单分子预测示例 ===")

    # No config argument: the predictor is built with its own defaults.
    single_predictor = ParallelBroadSpectrumPredictor()

    ethanol = MoleculeInput(smiles="CCO", chem_id="ethanol")
    outcome = single_predictor.predict_single(ethanol)

    print(f"化合物: {outcome.chem_id}")
    verdict = "是" if outcome.broad_spectrum else "否"
    print(f"广谱抗菌: {verdict}")
    print(f"抑制菌株数: {outcome.ginhib_total}")
    print(f"抗菌分数: {outcome.apscore_total:.3f}")
|
|
|
|
|
|
def example_batch_prediction() -> None:
    """Predict a small batch of molecules with an explicit worker/batch config.

    Three molecules (ethanol, ethylamine, acetic acid) are sent through
    ``predict_batch``. One line per returned result is printed, showing its
    id, broad-spectrum flag and inhibited-strain count.
    """
    print("\n=== 批量预测示例 ===")

    batch_config = PredictionConfig(n_workers=4, batch_size=50)
    batch_predictor = ParallelBroadSpectrumPredictor(batch_config)

    # (SMILES, compound id) pairs, converted in order into MoleculeInput objects.
    samples = (
        ("CCO", "ethanol"),
        ("CCN", "ethylamine"),
        ("CC(=O)O", "acetic_acid"),
    )
    batch = [MoleculeInput(smiles=smi, chem_id=cid) for smi, cid in samples]

    predictions = batch_predictor.predict_batch(batch)

    for item in predictions:
        line = f"{item.chem_id}: 广谱={item.broad_spectrum}, 抑制数={item.ginhib_total}"
        print(line)
|
|
|
|
|
|
def example_smiles_list_prediction() -> None:
    """Predict from parallel SMILES/id lists using the ``predict_smiles`` helper.

    Prints how many of the returned results are flagged broad-spectrum, out
    of the total number of results.
    """
    print("\n=== SMILES列表预测示例 ===")

    ids = ["ethanol", "ethylamine", "acetic_acid"]
    smiles = ["CCO", "CCN", "CC(=O)O"]

    # Convenience wrapper: no explicit predictor or config is created here.
    predictions = predict_smiles(smiles, ids)

    n_broad = len([p for p in predictions if p.broad_spectrum])
    print(f"广谱抗菌化合物: {n_broad}/{len(predictions)}")
|
|
|
|
|
|
def example_file_prediction(
    input_path: str = "molecules.tsv",
    output_path: str = "broad_spectrum_results.csv",
) -> None:
    """Predict every molecule listed in an input file and save results to CSV.

    Args:
        input_path: Path of the molecule table handed to ``predict_file``.
            SMILES are taken from its ``smiles`` column and ids from its
            ``compound_id`` column.
        output_path: Destination CSV file. It receives one row per result,
            built from ``result.to_dict()``, with no index column.

    If ``predict_file`` raises ``FileNotFoundError``, a notice is printed and
    the example is skipped. pandas is imported only after predictions exist,
    so it is not needed when the input file is missing.
    """
    print("\n=== 文件预测示例 ===")

    # Guard only the prediction step. A FileNotFoundError raised later (e.g.
    # while writing output_path) must not be misreported as a missing input.
    try:
        results = predict_file(
            input_path,
            smiles_column="smiles",
            id_column="compound_id",
        )
    except FileNotFoundError:
        print("输入文件不存在,跳过文件预测示例")
        return

    import pandas as pd  # local import: only this path needs pandas

    results_df = pd.DataFrame([r.to_dict() for r in results])
    results_df.to_csv(output_path, index=False)
    print(f"预测完成,结果保存到 {output_path}")
|
|
|
|
|
|
def example_custom_config() -> None:
    """Run a batch prediction using a hand-tuned ``PredictionConfig``.

    Predicts ethanol alone and prints the ``to_dict()`` form of the first
    returned result.
    """
    print("\n=== 自定义配置示例 ===")

    # The "stricter"/"higher"/"more"/"larger" notes reflect the author's intent
    # relative to PredictionConfig's defaults, which are not visible here.
    strict_config = PredictionConfig(
        app_threshold=0.1,  # intended as a stricter inhibition threshold
        min_nkill=15,       # intended as a higher bar for "broad-spectrum"
        n_workers=8,        # more parallel workers
        batch_size=200,     # larger batches
    )
    custom_predictor = ParallelBroadSpectrumPredictor(strict_config)

    batch = [MoleculeInput(smiles="CCO", chem_id="ethanol")]
    predictions = custom_predictor.predict_batch(batch)

    first_payload = predictions[0].to_dict()
    print(f"使用自定义配置预测结果: {first_payload}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Execute each example in sequence. An exception stops the remaining ones.
    for run_example in (
        example_single_prediction,
        example_batch_prediction,
        example_smiles_list_prediction,
        example_file_prediction,
        example_custom_config,
    ):
        run_example()