"""Command-line interface built on the parallel broad-spectrum antimicrobial prediction API."""

import argparse
from typing import List

import pandas as pd

from .broad_spectrum_api import (
    ParallelBroadSpectrumPredictor,
    PredictionConfig,
    MoleculeInput
)

# Column layout of the non-aggregated (compound-strain pair) output table.
_PAIRWISE_COLUMNS = ["chem_id", "strain", "prediction_probability"]


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments.

    Also prints a one-line notice stating which output mode (aggregated, or
    one row per compound-strain pair) was selected.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        prog="Prediction of antimicrobial activity.",
        description="This program receives a collection of molecules as input. "
                    "If it receives SMILES, it first featurizes the molecules using MolE, "
                    "then makes predictions of antimicrobial activity. "
                    "By default, the program returns the antimicrobial predictive probabilities "
                    "for each compound-strain pair. "
                    "If the --aggregate_scores flag is set, then the program aggregates the predictions "
                    "into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
        usage="python cli.py input_filepath output_filepath [options]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Input file.
    parser.add_argument("input_filepath",
                        help="Complete path to input file. Can be a file with SMILES "
                             "(make sure to set the --smiles_input flag) or a file with MolE representation.")
    # Output file.
    parser.add_argument("output_filepath",
                        help="Complete path for output file")

    # --- Input-type arguments ---
    inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")
    # Whether the input file contains SMILES that must be featurized first.
    inputargs.add_argument("-s", "--smiles_input",
                           help="Flag variable. Indicates if the input_filepath contains SMILES "
                                "that have to be first represented using a MolE pre-trained model.",
                           action="store_true")
    # Name of the SMILES column.
    inputargs.add_argument("-c", "--smiles_colname",
                           help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
                           default="smiles")
    # Name of the compound-ID column.
    inputargs.add_argument("-i", "--chemid_colname",
                           help="Column name in input_filepath that contains the ID string of each chemical. "
                                "Only used if --smiles_input is set",
                           default="chem_id")

    # --- Model arguments ---
    modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")
    # Path to the XGBoost model.
    modelargs.add_argument("-x", "--xgboost_model",
                           help="Path to the pickled XGBoost model that makes predictions (.pkl).",
                           default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")
    # Path to the pre-trained MolE model directory.
    modelargs.add_argument("-m", "--mole_model",
                           help="Path to the directory containing the config.yaml and model.pth files "
                                "of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
                           default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")

    # --- Prediction arguments ---
    predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")
    # Aggregate the per-strain predictions into per-compound scores.
    predargs.add_argument("-a", "--aggregate_scores",
                          help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
                               "If set, then prediction scores of each compound is aggregated into the antimicrobial "
                               "potential score and the total number of strains predicted to be inhibited is reported. "
                               "Additionally, the broad spectrum antibiotic prediction is reported.",
                          action="store_true")
    # Threshold used to binarize the antimicrobial predictive probabilities.
    predargs.add_argument("-t", "--app_threshold",
                          help="Threshold score applied to the antimicrobial predictive probabilities "
                               "in order to binarize compound-microbe predictions of growth inhibition. "
                               "Default from original publication.",
                          default=0.04374140128493309, type=float)
    # Minimum number of inhibited strains for a broad-spectrum call.
    predargs.add_argument("-k", "--min_nkill",
                          help="Minimum number of microbes predicted to be inhibited "
                               "in order to consider the compound a broad spectrum antibiotic.",
                          default=10, type=int)
    # Batch size.
    predargs.add_argument("--batch_size",
                          help="Batch size for processing molecules.",
                          default=100, type=int)
    # Number of worker processes (None lets the API pick its default).
    predargs.add_argument("--n_workers",
                          help="Number of worker processes for parallel processing. "
                               "If not set, the number of CPU cores will be used.",
                          default=None, type=int)

    # --- Metadata arguments ---
    metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")
    # Maier et al. strain screening results.
    metadataargs.add_argument("-b", "--strain_categories",
                              help="Path to the Maier et.al. screening results.",
                              default="data/01.prepare_training_data/maier_screening_results.tsv.gz")
    # Bacterial strain metadata.
    metadataargs.add_argument("-g", "--gram_information",
                              help="Path to strain metadata.",
                              default="raw_data/maier_microbiome/strain_info_SF2.xlsx")

    # Device (registered on the top-level parser, not on a group).
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. "
                             "Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
                             "then cuda:0 device is selected if a GPU is detected.",
                        default="auto")

    args = parser.parse_args()

    # Tell the user which kind of output will be produced.
    if args.aggregate_scores:
        print("Aggregating predictions of antimicrobial activity.")
    else:
        print("Returning predictions of antimicrobial activity for each compound-strain pair.")

    return args


def _write_aggregated(results, output_filepath) -> None:
    """Write one row per compound, indexed by ``chem_id``, as a TSV file."""
    # Each result is assumed to expose ``to_dict()`` returning a flat record
    # that includes a 'chem_id' key — TODO confirm against broad_spectrum_api.
    records = [r.to_dict() for r in results]
    if records:
        results_df = pd.DataFrame(records)
    else:
        # A DataFrame built from an empty list has no 'chem_id' column, so
        # set_index would raise KeyError; emit an empty, header-only table instead.
        results_df = pd.DataFrame(columns=["chem_id"])
    results_df = results_df.set_index("chem_id")
    results_df.to_csv(output_filepath, sep='\t')


def _write_pairwise(results, output_filepath) -> None:
    """Write one row per compound-strain pair with its predicted probability, as a TSV file."""
    # ``result.predictions`` is iterated as a strain -> probability mapping.
    rows: List[dict] = [
        {"chem_id": result.chem_id, "strain": strain, "prediction_probability": prob}
        for result in results
        for strain, prob in result.predictions.items()
    ]
    # Explicit columns keep the header (and column order) even when there are no rows.
    results_df = pd.DataFrame(rows, columns=_PAIRWISE_COLUMNS)
    results_df.to_csv(output_filepath, sep='\t', index=False)


def main() -> None:
    """Run the CLI: parse arguments, predict, and write a tab-separated output file.

    Raises:
        NotImplementedError: if ``--smiles_input`` is not set; predicting from
            pre-computed MolE representations is not supported by the API yet.
    """
    args = parse_arguments()

    # Fail fast, before building the predictor (which may load models), when
    # the requested input mode is unsupported.
    if not args.smiles_input:
        # TODO: implement predict_from_mole_representation or call the matching API.
        raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")

    config = PredictionConfig(
        xgboost_model_path=args.xgboost_model,
        mole_model_path=args.mole_model,
        strain_categories_path=args.strain_categories,
        gram_info_path=args.gram_information,
        app_threshold=args.app_threshold,
        min_nkill=args.min_nkill,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
        device=args.device,
    )
    predictor = ParallelBroadSpectrumPredictor(config)

    # Read SMILES from the input file and predict.
    results = predictor.predict_from_file(
        args.input_filepath,
        smiles_column=args.smiles_colname,
        id_column=args.chemid_colname,
    )

    if args.aggregate_scores:
        # Aggregated mode: one row per compound with summary statistics.
        print("Aggregating Antimicrobial potential")
        _write_aggregated(results, args.output_filepath)
    else:
        # Non-aggregated mode: one row per compound-strain pair.
        print("Generating non-aggregated predictions (compound-strain pairs)")
        _write_pairwise(results, args.output_filepath)


if __name__ == "__main__":
    main()