199 lines
8.6 KiB
Python
199 lines
8.6 KiB
Python
"""
Command-line interface built on the parallel broad-spectrum antimicrobial prediction API.
"""
|
||
|
||
import argparse
|
||
import pandas as pd
|
||
from typing import List
|
||
from .broad_spectrum_api import (
|
||
ParallelBroadSpectrumPredictor,
|
||
PredictionConfig,
|
||
MoleculeInput
|
||
)
|
||
|
||
|
||
def parse_arguments(argv=None):
    """Parse the command-line arguments of the antimicrobial prediction CLI.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous zero-argument
            behavior; passing an explicit list lets scripts and tests drive
            the parser directly.

    Returns:
        argparse.Namespace holding the parsed options.

    Side effects:
        Prints a one-line message saying whether predictions will be
        aggregated or reported per compound-strain pair. On invalid
        arguments argparse prints the usage and exits with status 2.
    """

    def positive_int(value):
        """argparse ``type`` converter that accepts only integers >= 1."""
        number = int(value)  # a ValueError here is reported by argparse as an invalid value
        if number < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return number

    parser = argparse.ArgumentParser(
        prog="Prediction of antimicrobial activity.",
        description="This program receives a collection of molecules as input. "
                    "If it receives SMILES, it first featurizes the molecules using MolE, "
                    "then makes predictions of antimicrobial activity. "
                    "By default, the program returns the antimicrobial predictive probabilities "
                    "for each compound-strain pair. "
                    "If the --aggregate_scores flag is set, then the program aggregates the predictions "
                    "into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
        # NOTE(review): this module imports its API relatively (``from .broad_spectrum_api``),
        # so running ``python cli.py`` directly fails; it must be run as a package module
        # (``python -m <package>.cli``). Confirm the package name and update this string.
        usage="python cli.py input_filepath output_filepath [options]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input file
    parser.add_argument("input_filepath", help="Complete path to input file. Can be a file with SMILES "
                                               "(make sure to set the --smiles_input flag) or a file with MolE representation.")

    # Output file
    parser.add_argument("output_filepath", help="Complete path for output file")

    # Arguments describing the input file
    inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")

    # Whether the input contains SMILES that must be featurized first
    inputargs.add_argument("-s", "--smiles_input",
                           help="Flag variable. Indicates if the input_filepath contains SMILES "
                                "that have to be first represented using a MolE pre-trained model.",
                           action="store_true")

    # Name of the SMILES column
    inputargs.add_argument("-c", "--smiles_colname",
                           help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
                           default="smiles")

    # Name of the compound-ID column (help previously referred to a non-existent "smiles_filepath")
    inputargs.add_argument("-i", "--chemid_colname",
                           help="Column name in input_filepath that contains the ID string of each chemical. "
                                "Only used if --smiles_input is set",
                           default="chem_id")

    # Arguments describing the models
    modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")

    # Path to the pickled XGBoost model
    modelargs.add_argument("-x", "--xgboost_model",
                           help="Path to the pickled XGBoost model that makes predictions (.pkl).",
                           default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")

    # Directory of the pre-trained MolE model
    modelargs.add_argument("-m", "--mole_model",
                           help="Path to the directory containing the config.yaml and model.pth files "
                                "of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
                           default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")

    # Arguments controlling the prediction process
    predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")

    # Aggregate per-strain predictions into per-compound scores
    predargs.add_argument("-a", "--aggregate_scores",
                          help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
                               "If set, then prediction scores of each compound is aggregated into the antimicrobial "
                               "potential score and the total number of strains predicted to be inhibited is reported. "
                               "Additionally, the broad spectrum antibiotic prediction is reported.",
                          action="store_true")

    # Probability threshold used to binarize growth-inhibition predictions
    predargs.add_argument("-t", "--app_threshold",
                          help="Threshold score applied to the antimicrobial predictive probabilities "
                               "in order to binarize compound-microbe predictions of growth inhibition. "
                               "Default from original publication.",
                          default=0.04374140128493309, type=float)

    # Minimum number of inhibited strains for a broad-spectrum call
    predargs.add_argument("-k", "--min_nkill",
                          help="Minimum number of microbes predicted to be inhibited "
                               "in order to consider the compound a broad spectrum antibiotic.",
                          default=10, type=int)

    # Batch size; values < 1 are rejected because they cannot form a batch
    predargs.add_argument("--batch_size",
                          help="Batch size for processing molecules.",
                          default=100, type=positive_int)

    # Number of worker processes (None -> left to the predictor to decide)
    predargs.add_argument("--n_workers",
                          help="Number of worker processes for parallel processing. "
                               "If not set, the number of CPU cores will be used.",
                          default=None, type=int)

    # Arguments pointing at strain metadata files
    metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")

    # Maier et al. screening results
    metadataargs.add_argument("-b", "--strain_categories",
                              help="Path to the Maier et.al. screening results.",
                              default="data/01.prepare_training_data/maier_screening_results.tsv.gz")

    # Strain metadata
    metadataargs.add_argument("-g", "--gram_information",
                              help="Path to strain metadata.",
                              default="raw_data/maier_microbiome/strain_info_SF2.xlsx")

    # Device for the pre-trained model (not restricted with choices=, so values
    # other than the three listed are passed through unchanged)
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. "
                             "Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
                             "then cuda:0 device is selected if a GPU is detected.",
                        default="auto")

    args = parser.parse_args(argv)

    # Tell the user which kind of output will be produced
    if args.aggregate_scores:
        print("Aggregating predictions of antimicrobial activity.")
    else:
        print("Returning predictions of antimicrobial activity for each compound-strain pair.")

    return args
|
||
|
||
|
||
def _write_aggregated(results, output_filepath):
    """Write one tab-separated row per compound, indexed by ``chem_id``.

    Each row is ``result.to_dict()``; that dict must contain a ``chem_id``
    key because it becomes the index.
    """
    records = [r.to_dict() for r in results]
    if records:
        results_df = pd.DataFrame(records)
    else:
        # pd.DataFrame([]) has no columns, so set_index("chem_id") would raise
        # KeyError; write a header-only table instead.
        results_df = pd.DataFrame(columns=["chem_id"])
    results_df.set_index("chem_id", inplace=True)
    results_df.to_csv(output_filepath, sep="\t")


def _write_pairwise(results, output_filepath):
    """Write one tab-separated row per (compound, strain) pair with its probability."""
    rows = [
        {"chem_id": result.chem_id, "strain": strain, "prediction_probability": prob}
        for result in results
        for strain, prob in result.predictions.items()
    ]
    # Explicit columns keep the header and column order even when there are no rows.
    results_df = pd.DataFrame(rows, columns=["chem_id", "strain", "prediction_probability"])
    results_df.to_csv(output_filepath, sep="\t", index=False)


def main():
    """Run the antimicrobial-activity prediction command-line tool.

    Steps:
    1. Parse the command-line arguments.
    2. Build a ``PredictionConfig`` and a ``ParallelBroadSpectrumPredictor``.
    3. Predict from the SMILES input file.
    4. Write a tab-separated output file. With ``--aggregate_scores`` it has
       one row per compound; otherwise one row per compound-strain pair.

    Raises:
        NotImplementedError: if ``--smiles_input`` is not set. Predicting from
            pre-computed MolE representations is not yet supported by the
            new API.
    """
    args = parse_arguments()

    # Reject unsupported input up front, so the config and predictor are
    # never built for a run that cannot proceed.
    if not args.smiles_input:
        # TODO: implement predict_from_mole_representation (or call the
        # corresponding API) to accept pre-computed MolE feature vectors.
        raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")

    config = PredictionConfig(
        xgboost_model_path=args.xgboost_model,
        mole_model_path=args.mole_model,
        strain_categories_path=args.strain_categories,
        gram_info_path=args.gram_information,
        app_threshold=args.app_threshold,
        min_nkill=args.min_nkill,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
        device=args.device
    )

    predictor = ParallelBroadSpectrumPredictor(config)

    # Read SMILES and compound IDs from the input file and run the predictions.
    results = predictor.predict_from_file(
        args.input_filepath,
        smiles_column=args.smiles_colname,
        id_column=args.chemid_colname
    )

    if args.aggregate_scores:
        print("Aggregating Antimicrobial potential")
        _write_aggregated(results, args.output_filepath)
    else:
        print("Generating non-aggregated predictions (compound-strain pairs)")
        _write_pairwise(results, args.output_filepath)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, never on import.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())