Files
mole_broad_spectrum_parallel/cli.py
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

199 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
命令行接口基于并行广谱抗菌预测API
"""
import argparse
import pandas as pd
from typing import List
from .broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def parse_arguments(argv=None) -> argparse.Namespace:
    """
    Build the command-line parser and parse the arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call. Passing an explicit
            list makes the function usable from other code and from tests.

    Returns:
        The parsed arguments. Attributes: ``input_filepath``,
        ``output_filepath``, ``smiles_input``, ``smiles_colname``,
        ``chemid_colname``, ``xgboost_model``, ``mole_model``,
        ``aggregate_scores``, ``app_threshold``, ``min_nkill``,
        ``batch_size``, ``n_workers``, ``strain_categories``,
        ``gram_information`` and ``device``.

    Side effects:
        Prints one line to stdout stating whether predictions will be
        aggregated. On invalid arguments argparse prints an error and raises
        ``SystemExit``.
    """
    parser = argparse.ArgumentParser(
        prog="Prediction of antimicrobial activity.",
        description="This program receives a collection of molecules as input. "
                    "If it receives SMILES, it first featurizes the molecules using MolE, "
                    "then makes predictions of antimicrobial activity. "
                    "By default, the program returns the antimicrobial predictive probabilities "
                    "for each compound-strain pair. "
                    "If the --aggregate_scores flag is set, then the program aggregates the predictions "
                    "into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
        usage="python cli.py input_filepath output_filepath [options]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Positional: input file.
    parser.add_argument("input_filepath", help="Complete path to input file. Can be a file with SMILES "
                                               "(make sure to set the --smiles_input flag) or a file with MolE representation.")

    # Positional: output file.
    parser.add_argument("output_filepath", help="Complete path for output file")

    # Group: how to interpret the input file.
    inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")

    # Set when the input holds SMILES strings.
    inputargs.add_argument("-s", "--smiles_input",
                           help="Flag variable. Indicates if the input_filepath contains SMILES "
                                "that have to be first represented using a MolE pre-trained model.",
                           action="store_true")

    # Name of the SMILES column.
    inputargs.add_argument("-c", "--smiles_colname",
                           help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
                           default="smiles")

    # Name of the compound-ID column. The help text previously referred to a
    # non-existent "smiles_filepath"; the positional is "input_filepath".
    inputargs.add_argument("-i", "--chemid_colname",
                           help="Column name in input_filepath that contains the ID string of each chemical. "
                                "Only used if --smiles_input is set",
                           default="chem_id")

    # Group: model locations.
    modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")

    # Path to the XGBoost model.
    modelargs.add_argument("-x", "--xgboost_model",
                           help="Path to the pickled XGBoost model that makes predictions (.pkl).",
                           default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")

    # Path to the MolE model directory.
    modelargs.add_argument("-m", "--mole_model",
                           help="Path to the directory containing the config.yaml and model.pth files "
                                "of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
                           default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")

    # Group: prediction behaviour.
    predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")

    # Aggregate the per-strain predictions into per-compound scores.
    predargs.add_argument("-a", "--aggregate_scores",
                          help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
                               "If set, then prediction scores of each compound is aggregated into the antimicrobial "
                               "potential score and the total number of strains predicted to be inhibited is reported. "
                               "Additionally, the broad spectrum antibiotic prediction is reported.",
                          action="store_true")

    # Probability threshold used to binarize predictions.
    predargs.add_argument("-t", "--app_threshold",
                          help="Threshold score applied to the antimicrobial predictive probabilities "
                               "in order to binarize compound-microbe predictions of growth inhibition. "
                               "Default from original publication.",
                          default=0.04374140128493309, type=float)

    # Minimum number of inhibited strains for the broad-spectrum call.
    predargs.add_argument("-k", "--min_nkill",
                          help="Minimum number of microbes predicted to be inhibited "
                               "in order to consider the compound a broad spectrum antibiotic.",
                          default=10, type=int)

    # Batch size.
    predargs.add_argument("--batch_size",
                          help="Batch size for processing molecules.",
                          default=100, type=int)

    # Number of worker processes.
    predargs.add_argument("--n_workers",
                          help="Number of worker processes for parallel processing. "
                               "If not set, the number of CPU cores will be used.",
                          default=None, type=int)

    # Group: metadata files.
    metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")

    # Maier et al. strain screening results.
    metadataargs.add_argument("-b", "--strain_categories",
                              help="Path to the Maier et.al. screening results.",
                              default="data/01.prepare_training_data/maier_screening_results.tsv.gz")

    # Strain (Gram) metadata.
    metadataargs.add_argument("-g", "--gram_information",
                              help="Path to strain metadata.",
                              default="raw_data/maier_microbiome/strain_info_SF2.xlsx")

    # Compute device. No `choices` restriction is applied, so any string is
    # accepted and passed on unchanged.
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. "
                             "Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
                             "then cuda:0 device is selected if a GPU is detected.",
                        default="auto")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Tell the user which output mode was selected.
    if args.aggregate_scores:
        print("Aggregating predictions of antimicrobial activity.")
    else:
        print("Returning predictions of antimicrobial activity for each compound-strain pair.")

    return args
def _aggregated_frame(results):
    """Build the per-compound table (one row per result) indexed by ``chem_id``."""
    records = [result.to_dict() for result in results]
    if not records:
        # A DataFrame built from an empty list has no 'chem_id' column, so
        # set_index would raise KeyError. Return an empty, correctly indexed
        # frame instead so the output file still gets its header.
        return pd.DataFrame(columns=['chem_id']).set_index('chem_id')
    return pd.DataFrame(records).set_index('chem_id')


def _pairwise_frame(results):
    """Build the long-format table with one row per compound-strain pair."""
    rows = []
    for result in results:
        rows.extend(
            {
                'chem_id': result.chem_id,
                'strain': strain,
                'prediction_probability': prob,
            }
            for strain, prob in result.predictions.items()
        )
    # Explicit columns keep the header present even when there are no rows.
    return pd.DataFrame(rows, columns=['chem_id', 'strain', 'prediction_probability'])


def main():
    """
    Run the command-line prediction workflow.

    Parses the command line, builds a ``PredictionConfig`` and a
    ``ParallelBroadSpectrumPredictor``, runs ``predict_from_file`` on the
    input file and writes the results to ``args.output_filepath`` as a
    tab-separated file.

    Output layout:
        * with ``--aggregate_scores``: one row per result, built from each
          result's ``to_dict()`` and indexed by ``chem_id``;
        * otherwise: one row per compound-strain pair with the columns
          ``chem_id``, ``strain`` and ``prediction_probability``.

    Raises:
        NotImplementedError: if ``--smiles_input`` is not set; reading
            pre-computed MolE representations is not implemented yet.
    """
    # Parse the command-line arguments.
    args = parse_arguments()

    # Fail fast: the only supported input is SMILES, so reject the other mode
    # before constructing the configuration and the predictor.
    # TODO: implement predict_from_mole_representation (or call the
    # corresponding API) for pre-computed MolE feature vectors.
    if not args.smiles_input:
        raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")

    # Build the configuration object.
    config = PredictionConfig(
        xgboost_model_path=args.xgboost_model,
        mole_model_path=args.mole_model,
        strain_categories_path=args.strain_categories,
        gram_info_path=args.gram_information,
        app_threshold=args.app_threshold,
        min_nkill=args.min_nkill,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
        device=args.device
    )

    # Create the predictor.
    predictor = ParallelBroadSpectrumPredictor(config)

    # Read the SMILES from the input file and predict.
    results = predictor.predict_from_file(
        args.input_filepath,
        smiles_column=args.smiles_colname,
        id_column=args.chemid_colname
    )

    # Write the output in the requested layout.
    if args.aggregate_scores:
        print("Aggregating Antimicrobial potential")
        # Aggregated mode: chem_id is the index, so it is written as the
        # first column.
        _aggregated_frame(results).to_csv(args.output_filepath, sep='\t')
    else:
        print("Generating non-aggregated predictions (compound-strain pairs)")
        _pairwise_frame(results).to_csv(args.output_filepath, sep='\t', index=False)
# Script entry point: run the CLI only when this module is executed directly,
# not when it is imported.
# NOTE(review): this module uses a relative import
# (`from .broad_spectrum_api import ...`), so it presumably has to be started
# as part of its package (e.g. `python -m <package>.cli`); launching the file
# as a plain script would fail on that import. Confirm the intended invocation.
if __name__ == "__main__":
    main()