#!/usr/bin/env python # -*- coding: utf-8 -*- """ MolE 抗菌活性预测工具 这个脚本提供了使用 MolE 模型预测小分子 SMILES 抗菌活性的功能。 支持命令行和 Python API 调用两种方式。 命令行示例: python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id Python API 示例: from utils.mole_predictor import predict_csv_file predict_csv_file( input_path="input.csv", output_path="output.csv", smiles_column="smiles", id_column="chem_id" ) """ import sys import os from pathlib import Path # 添加项目根目录到 Python 路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) import click import pandas as pd from typing import Optional, List from datetime import datetime from tqdm import tqdm from models.broad_spectrum_predictor import ( ParallelBroadSpectrumPredictor, PredictionConfig, MoleculeInput, BroadSpectrumResult ) def predict_csv_file( input_path: str, output_path: Optional[str] = None, smiles_column: str = "smiles", id_column: str = "chem_id", batch_size: int = 100, n_workers: Optional[int] = None, device: str = "auto", add_suffix: bool = True, include_strain_predictions: bool = False ) -> pd.DataFrame: """ 预测 CSV 文件中的分子抗菌活性 Args: input_path: 输入 CSV 文件路径 output_path: 输出 CSV 文件路径,如果为 None 则自动生成 smiles_column: SMILES 列名 id_column: 化合物 ID 列名 batch_size: 批处理大小 n_workers: 工作进程数 device: 计算设备 ("auto", "cpu", "cuda:0" 等) add_suffix: 是否在输出文件名后添加预测后缀 include_strain_predictions: 是否在输出中包含40种菌株的预测详情 Returns: 包含预测结果的 DataFrame """ print(f"开始处理文件: {input_path}") # 读取输入文件 input_path_obj = Path(input_path) if not input_path_obj.exists(): raise FileNotFoundError(f"输入文件不存在: {input_path}") # 读取 CSV try: df_input = pd.read_csv(input_path) except Exception as e: raise RuntimeError(f"读取 CSV 文件失败: {e}") print(f"读取了 {len(df_input)} 条数据") # 检查列是否存在(大小写不敏感) columns_lower = {col.lower(): col for col in df_input.columns} smiles_col_actual = columns_lower.get(smiles_column.lower()) if smiles_col_actual is None: raise ValueError( f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}" ) # 处理 ID 列 id_col_actual = columns_lower.get(id_column.lower()) if id_col_actual is None: print(f"未找到 ID 列 '{id_column}',将自动生成 ID") df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))] id_col_actual = id_column # 创建预测器配置 config = PredictionConfig( batch_size=batch_size, n_workers=n_workers, device=device ) # 初始化预测器 print("初始化预测器...") predictor = ParallelBroadSpectrumPredictor(config) # 准备分子输入 molecules = [ MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual])) for _, row in df_input.iterrows() ] # 执行预测 print("开始预测...") results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions) # 转换结果为 DataFrame results_dicts = [r.to_dict() for r in results] df_results = pd.DataFrame(results_dicts) # 合并原始数据和预测结果 # 使用 chem_id 作为键进行合并 df_input['_merge_id'] = df_input[id_col_actual].astype(str) df_results['_merge_id'] = df_results['chem_id'].astype(str) df_output = df_input.merge( df_results.drop(columns=['chem_id']), on='_merge_id', how='left' ) df_output = df_output.drop(columns=['_merge_id']) # 如果包含菌株级别预测,将其添加到输出中 if include_strain_predictions: print("合并菌株级别预测数据...") # 收集所有菌株级别预测 all_strain_predictions = [] for result in results: if result.strain_predictions is not None and not result.strain_predictions.empty: all_strain_predictions.append(result.strain_predictions) if all_strain_predictions: # 合并所有菌株预测 df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True) # 将聚合结果和菌株预测合并 # 为了在同一个 CSV 中展示,我们使用重复行的方式 # 每个分子的聚合结果会重复40次(每个菌株一次) df_output_expanded = df_output.merge( df_strain_predictions, left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)], right_on='chem_id', how='left', suffixes=('', '_strain') ) # 移除重复的 chem_id 列 if 'chem_id_strain' in df_output_expanded.columns: df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain']) df_output = df_output_expanded # 生成输出路径 if output_path is None: if add_suffix: output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}") else: output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}") elif add_suffix: output_path_obj = Path(output_path) output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}") # 保存结果 print(f"保存结果到: {output_path}") df_output.to_csv(output_path, index=False) print(f"完成! 预测了 {len(results)} 个分子") print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌") return df_output def predict_multiple_files( input_paths: List[str], output_dir: Optional[str] = None, smiles_column: str = "smiles", id_column: str = "chem_id", batch_size: int = 100, n_workers: Optional[int] = None, device: str = "auto", add_suffix: bool = True, include_strain_predictions: bool = False ) -> List[pd.DataFrame]: """ 批量预测多个 CSV 文件 Args: input_paths: 输入 CSV 文件路径列表 output_dir: 输出目录,如果为 None 则在原文件目录生成 smiles_column: SMILES 列名 id_column: 化合物 ID 列名 batch_size: 批处理大小 n_workers: 工作进程数 device: 计算设备 add_suffix: 是否在输出文件名后添加预测后缀 include_strain_predictions: 是否在输出中包含40种菌株的预测详情 Returns: 包含预测结果的 DataFrame 列表 """ results = [] for input_path in input_paths: input_path_obj = Path(input_path) # 确定输出路径 if output_dir is not None: output_dir_obj = Path(output_dir) output_dir_obj.mkdir(parents=True, exist_ok=True) if add_suffix: output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}") else: output_path = str(output_dir_obj / input_path_obj.name) else: output_path = None # 预测单个文件 try: df_result = predict_csv_file( input_path=input_path, output_path=output_path, smiles_column=smiles_column, id_column=id_column, batch_size=batch_size, n_workers=n_workers, device=device, add_suffix=add_suffix, include_strain_predictions=include_strain_predictions ) results.append(df_result) except Exception as e: print(f"处理文件 {input_path} 时出错: {e}") continue return results # ============================================================================ # 命令行接口 # ============================================================================ @click.command() @click.argument('input_path', type=click.Path(exists=True)) @click.argument('output_path', type=click.Path(), required=False) @click.option('--smiles-column', '-s', default='smiles', help='SMILES 列名 (默认: smiles)') @click.option('--id-column', '-i', default='chem_id', help='化合物 ID 列名 (默认: chem_id)') @click.option('--batch-size', '-b', default=100, type=int, help='批处理大小 (默认: 100)') @click.option('--n-workers', '-w', default=None, type=int, help='工作进程数 (默认: CPU 核心数)') @click.option('--device', '-d', default='auto', type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False), help='计算设备 (默认: auto)') @click.option('--add-suffix/--no-add-suffix', default=True, help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)') @click.option('--include-strain-predictions', is_flag=True, default=False, help='在输出中包含40种菌株的详细预测数据(每个分子将产生40行数据,对应每个菌株的预测概率和抑制情况)') def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions): """ 使用 MolE 模型预测小分子 SMILES 的抗菌活性 INPUT_PATH: 输入 CSV 文件路径 OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成) 默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。 使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。 示例: # 基本用法(仅输出聚合结果) python mole_predictor.py input.csv output.csv # 包含40种菌株的详细预测数据 python mole_predictor.py input.csv output.csv --include-strain-predictions # 指定列名和设备 python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0 # 自定义批处理大小 python mole_predictor.py input.csv --device cuda:0 --batch-size 200 """ try: predict_csv_file( input_path=input_path, output_path=output_path, smiles_column=smiles_column, id_column=id_column, batch_size=batch_size, n_workers=n_workers, device=device, add_suffix=add_suffix, include_strain_predictions=include_strain_predictions ) except Exception as e: click.echo(f"错误: {e}", err=True) sys.exit(1) if __name__ == '__main__': cli()