first add

This commit is contained in:
mm644706215
2025-10-16 17:21:48 +08:00
commit a56e60e9a3
192 changed files with 32720 additions and 0 deletions

199
cli.py Normal file
View File

@@ -0,0 +1,199 @@
"""
命令行接口基于并行广谱抗菌预测API
"""
import argparse
import pandas as pd
from typing import List
from .broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def parse_arguments(argv=None):
    """Parse the command-line arguments of the antimicrobial prediction CLI.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            original behaviour; passing a list makes the parser usable from
            other code and from tests.

    Returns:
        argparse.Namespace: The parsed arguments, with the attributes
        ``input_filepath``, ``output_filepath``, ``smiles_input``,
        ``smiles_colname``, ``chemid_colname``, ``xgboost_model``,
        ``mole_model``, ``aggregate_scores``, ``app_threshold``,
        ``min_nkill``, ``batch_size``, ``n_workers``, ``strain_categories``,
        ``gram_information`` and ``device``.

    Raises:
        SystemExit: Raised by argparse for ``--help``, for malformed
            arguments, or when ``--batch_size`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        prog="Prediction of antimicrobial activity.",
        description="This program receives a collection of molecules as input. "
                    "If it receives SMILES, it first featurizes the molecules using MolE, "
                    "then makes predictions of antimicrobial activity. "
                    "By default, the program returns the antimicrobial predictive probabilities "
                    "for each compound-strain pair. "
                    "If the --aggregate_scores flag is set, then the program aggregates the predictions "
                    "into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
        usage="python cli.py input_filepath output_filepath [options]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input file
    parser.add_argument("input_filepath", help="Complete path to input file. Can be a file with SMILES "
                        "(make sure to set the --smiles_input flag) or a file with MolE representation.")

    # Output file
    parser.add_argument("output_filepath", help="Complete path for output file")

    # Argument group: input type
    inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")

    # Whether the input holds SMILES strings
    inputargs.add_argument("-s", "--smiles_input",
                           help="Flag variable. Indicates if the input_filepath contains SMILES "
                                "that have to be first represented using a MolE pre-trained model.",
                           action="store_true")

    # Name of the SMILES column
    inputargs.add_argument("-c", "--smiles_colname",
                           help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
                           default="smiles")

    # Name of the compound-ID column (help text now names the real
    # positional argument instead of a non-existent "smiles_filepath")
    inputargs.add_argument("-i", "--chemid_colname",
                           help="Column name in input_filepath that contains the ID string of each chemical. "
                                "Only used if --smiles_input is set",
                           default="chem_id")

    # Argument group: models
    modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")

    # Path to the XGBoost model
    modelargs.add_argument("-x", "--xgboost_model",
                           help="Path to the pickled XGBoost model that makes predictions (.pkl).",
                           default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")

    # Path to the MolE model
    modelargs.add_argument("-m", "--mole_model",
                           help="Path to the directory containing the config.yaml and model.pth files "
                                "of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
                           default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")

    # Argument group: prediction
    predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")

    # Aggregate the predictions per compound
    predargs.add_argument("-a", "--aggregate_scores",
                          help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
                               "If set, then prediction scores of each compound is aggregated into the antimicrobial "
                               "potential score and the total number of strains predicted to be inhibited is reported. "
                               "Additionally, the broad spectrum antibiotic prediction is reported.",
                          action="store_true")

    # Antimicrobial score threshold
    predargs.add_argument("-t", "--app_threshold",
                          help="Threshold score applied to the antimicrobial predictive probabilities "
                               "in order to binarize compound-microbe predictions of growth inhibition. "
                               "Default from original publication.",
                          default=0.04374140128493309, type=float)

    # Broad-spectrum threshold
    predargs.add_argument("-k", "--min_nkill",
                          help="Minimum number of microbes predicted to be inhibited "
                               "in order to consider the compound a broad spectrum antibiotic.",
                          default=10, type=int)

    # Batch size
    predargs.add_argument("--batch_size",
                          help="Batch size for processing molecules.",
                          default=100, type=int)

    # Number of worker processes
    predargs.add_argument("--n_workers",
                          help="Number of worker processes for parallel processing. "
                               "If not set, the number of CPU cores will be used.",
                          default=None, type=int)

    # Argument group: metadata
    metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")

    # Maier strain information
    metadataargs.add_argument("-b", "--strain_categories",
                              help="Path to the Maier et.al. screening results.",
                              default="data/01.prepare_training_data/maier_screening_results.tsv.gz")

    # Bacterial (Gram stain) information
    metadataargs.add_argument("-g", "--gram_information",
                              help="Path to strain metadata.",
                              default="raw_data/maier_microbiome/strain_info_SF2.xlsx")

    # Device
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. "
                             "Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
                             "then cuda:0 device is selected if a GPU is detected.",
                        default="auto")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # A batch must hold at least one molecule; reject nonsense early with a
    # regular argparse usage error instead of failing deep inside prediction.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # Tell the user which kind of output will be produced.
    if args.aggregate_scores:
        print("Aggregating predictions of antimicrobial activity.")
    else:
        print("Returning predictions of antimicrobial activity for each compound-strain pair.")

    return args
def _write_aggregated(results, output_filepath):
    """Write one tab-separated row per compound, indexed by ``chem_id``.

    Each element of ``results`` is converted with its ``to_dict()`` method;
    the resulting columns are whatever that dictionary exposes.
    """
    records = [result.to_dict() for result in results]
    results_df = pd.DataFrame(records)
    if records:
        results_df = results_df.set_index('chem_id')
        results_df.to_csv(output_filepath, sep='\t')
    else:
        # A frame built from no records has no 'chem_id' column, so
        # set_index would raise KeyError; still create the output file.
        results_df.to_csv(output_filepath, sep='\t', index=False)


def _write_pairwise(results, output_filepath):
    """Write one tab-separated row per compound-strain pair with its probability."""
    rows = []
    for result in results:
        for strain, prob in result.predictions.items():
            rows.append({
                'chem_id': result.chem_id,
                'strain': strain,
                'prediction_probability': prob,
            })
    # Pinning the columns keeps the header present even when there are no rows.
    results_df = pd.DataFrame(rows, columns=['chem_id', 'strain', 'prediction_probability'])
    results_df.to_csv(output_filepath, sep='\t', index=False)


def main():
    """Run the CLI: parse arguments, predict from a SMILES file, write a TSV.

    Raises:
        NotImplementedError: If ``--smiles_input`` is not set; prediction
            from pre-computed MolE representations is not available yet.
    """
    # Parse the command-line arguments.
    args = parse_arguments()

    # Fail fast, before the config and predictor are constructed, when the
    # requested input mode is unsupported.
    if not args.smiles_input:
        # Read from an existing representation (MolE feature vectors).
        # TODO: implement a predict_from_mole_representation method or call the corresponding API
        raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")

    # Build the configuration object.
    config = PredictionConfig(
        xgboost_model_path=args.xgboost_model,
        mole_model_path=args.mole_model,
        strain_categories_path=args.strain_categories,
        gram_info_path=args.gram_information,
        app_threshold=args.app_threshold,
        min_nkill=args.min_nkill,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
        device=args.device
    )

    # Create the predictor.
    predictor = ParallelBroadSpectrumPredictor(config)

    # Read the SMILES from the input file and predict.
    results = predictor.predict_from_file(
        args.input_filepath,
        smiles_column=args.smiles_colname,
        id_column=args.chemid_colname
    )

    # Write the output in aggregated or per-pair form.
    if args.aggregate_scores:
        print("Aggregating Antimicrobial potential")
        # Aggregated mode: one row per compound.
        _write_aggregated(results, args.output_filepath)
    else:
        # Non-aggregated mode: one row per compound-strain pair.
        print("Generating non-aggregated predictions (compound-strain pairs)")
        _write_pairwise(results, args.output_filepath)


# Run the CLI only when this module is executed as a script.
if __name__ == "__main__":
    main()