first add
This commit is contained in:
120
example_usage.py
Normal file
120
example_usage.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
广谱抗菌预测API使用示例
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from broad_spectrum_api import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput,
|
||||
BroadSpectrumResult,
|
||||
predict_smiles,
|
||||
predict_file
|
||||
)
|
||||
|
||||
|
||||
def example_single_prediction():
|
||||
"""单分子预测示例"""
|
||||
print("=== 单分子预测示例 ===")
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
|
||||
# 预测单个分子
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
result = predictor.predict_single(molecule)
|
||||
|
||||
print(f"化合物: {result.chem_id}")
|
||||
print(f"广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||||
print(f"抑制菌株数: {result.ginhib_total}")
|
||||
print(f"抗菌分数: {result.apscore_total:.3f}")
|
||||
|
||||
|
||||
def example_batch_prediction():
|
||||
"""批量预测示例"""
|
||||
print("\n=== 批量预测示例 ===")
|
||||
|
||||
# 创建预测器
|
||||
config = PredictionConfig(n_workers=4, batch_size=50)
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 准备多个分子
|
||||
molecules = [
|
||||
MoleculeInput(smiles="CCO", chem_id="ethanol"),
|
||||
MoleculeInput(smiles="CCN", chem_id="ethylamine"),
|
||||
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
|
||||
]
|
||||
|
||||
# 批量预测
|
||||
results = predictor.predict_batch(molecules)
|
||||
|
||||
# 输出结果
|
||||
for result in results:
|
||||
print(f"{result.chem_id}: 广谱={result.broad_spectrum}, 抑制数={result.ginhib_total}")
|
||||
|
||||
|
||||
def example_smiles_list_prediction():
|
||||
"""SMILES列表预测示例"""
|
||||
print("\n=== SMILES列表预测示例 ===")
|
||||
|
||||
smiles_list = ["CCO", "CCN", "CC(=O)O"]
|
||||
chem_ids = ["ethanol", "ethylamine", "acetic_acid"]
|
||||
|
||||
# 使用便捷函数
|
||||
results = predict_smiles(smiles_list, chem_ids)
|
||||
|
||||
# 统计广谱抗菌化合物
|
||||
broad_spectrum_count = sum(1 for r in results if r.broad_spectrum)
|
||||
print(f"广谱抗菌化合物: {broad_spectrum_count}/{len(results)}")
|
||||
|
||||
|
||||
def example_file_prediction():
|
||||
"""文件预测示例"""
|
||||
print("\n=== 文件预测示例 ===")
|
||||
|
||||
# 假设有输入文件 molecules.tsv
|
||||
try:
|
||||
results = predict_file(
|
||||
"molecules.tsv",
|
||||
smiles_column="smiles",
|
||||
id_column="compound_id"
|
||||
)
|
||||
|
||||
# 保存结果
|
||||
import pandas as pd
|
||||
results_df = pd.DataFrame([r.to_dict() for r in results])
|
||||
results_df.to_csv("broad_spectrum_results.csv", index=False)
|
||||
print(f"预测完成,结果保存到 broad_spectrum_results.csv")
|
||||
|
||||
except FileNotFoundError:
|
||||
print("输入文件不存在,跳过文件预测示例")
|
||||
|
||||
|
||||
def example_custom_config():
|
||||
"""自定义配置示例"""
|
||||
print("\n=== 自定义配置示例 ===")
|
||||
|
||||
# 自定义配置
|
||||
config = PredictionConfig(
|
||||
app_threshold=0.1, # 更严格的抑制阈值
|
||||
min_nkill=15, # 更高的广谱标准
|
||||
n_workers=8, # 更多并行进程
|
||||
batch_size=200 # 更大的批次
|
||||
)
|
||||
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 预测
|
||||
molecules = [MoleculeInput(smiles="CCO", chem_id="ethanol")]
|
||||
results = predictor.predict_batch(molecules)
|
||||
|
||||
print(f"使用自定义配置预测结果: {results[0].to_dict()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行所有示例
|
||||
example_single_prediction()
|
||||
example_batch_prediction()
|
||||
example_smiles_list_prediction()
|
||||
example_file_prediction()
|
||||
example_custom_config()
|
||||
Reference in New Issue
Block a user