Files
SIME/utils/mole_predictor.py
2025-10-17 15:54:00 +08:00

279 lines
8.7 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测工具
这个脚本提供了使用 MolE 模型预测小分子 SMILES 抗菌活性的功能。
支持命令行和 Python API 调用两种方式。
命令行示例:
python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id
Python API 示例:
from utils.mole_predictor import predict_csv_file
predict_csv_file(
input_path="input.csv",
output_path="output.csv",
smiles_column="smiles",
id_column="chem_id"
)
"""
import sys
import os
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import click
import pandas as pd
from typing import Optional, List
from datetime import datetime
from models.broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult
)
def predict_csv_file(
input_path: str,
output_path: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
) -> pd.DataFrame:
"""
预测 CSV 文件中的分子抗菌活性
Args:
input_path: 输入 CSV 文件路径
output_path: 输出 CSV 文件路径,如果为 None 则自动生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
add_suffix: 是否在输出文件名后添加预测后缀
Returns:
包含预测结果的 DataFrame
"""
print(f"开始处理文件: {input_path}")
# 读取输入文件
input_path_obj = Path(input_path)
if not input_path_obj.exists():
raise FileNotFoundError(f"输入文件不存在: {input_path}")
# 读取 CSV
try:
df_input = pd.read_csv(input_path)
except Exception as e:
raise RuntimeError(f"读取 CSV 文件失败: {e}")
print(f"读取了 {len(df_input)} 条数据")
# 检查列是否存在(大小写不敏感)
columns_lower = {col.lower(): col for col in df_input.columns}
smiles_col_actual = columns_lower.get(smiles_column.lower())
if smiles_col_actual is None:
raise ValueError(
f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}"
)
# 处理 ID 列
id_col_actual = columns_lower.get(id_column.lower())
if id_col_actual is None:
print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
id_col_actual = id_column
# 创建预测器配置
config = PredictionConfig(
batch_size=batch_size,
n_workers=n_workers,
device=device
)
# 初始化预测器
print("初始化预测器...")
predictor = ParallelBroadSpectrumPredictor(config)
# 准备分子输入
molecules = [
MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
for _, row in df_input.iterrows()
]
# 执行预测
print("开始预测...")
results = predictor.predict_batch(molecules)
# 转换结果为 DataFrame
results_dicts = [r.to_dict() for r in results]
df_results = pd.DataFrame(results_dicts)
# 合并原始数据和预测结果
# 使用 chem_id 作为键进行合并
df_input['_merge_id'] = df_input[id_col_actual].astype(str)
df_results['_merge_id'] = df_results['chem_id'].astype(str)
df_output = df_input.merge(
df_results.drop(columns=['chem_id']),
on='_merge_id',
how='left'
)
df_output = df_output.drop(columns=['_merge_id'])
# 生成输出路径
if output_path is None:
if add_suffix:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}")
elif add_suffix:
output_path_obj = Path(output_path)
output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}")
# 保存结果
print(f"保存结果到: {output_path}")
df_output.to_csv(output_path, index=False)
print(f"完成! 预测了 {len(results)} 个分子")
print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌")
return df_output
def predict_multiple_files(
input_paths: List[str],
output_dir: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
) -> List[pd.DataFrame]:
"""
批量预测多个 CSV 文件
Args:
input_paths: 输入 CSV 文件路径列表
output_dir: 输出目录,如果为 None 则在原文件目录生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备
add_suffix: 是否在输出文件名后添加预测后缀
Returns:
包含预测结果的 DataFrame 列表
"""
results = []
for input_path in input_paths:
input_path_obj = Path(input_path)
# 确定输出路径
if output_dir is not None:
output_dir_obj = Path(output_dir)
output_dir_obj.mkdir(parents=True, exist_ok=True)
if add_suffix:
output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(output_dir_obj / input_path_obj.name)
else:
output_path = None
# 预测单个文件
try:
df_result = predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
)
results.append(df_result)
except Exception as e:
print(f"处理文件 {input_path} 时出错: {e}")
continue
return results
# ============================================================================
# 命令行接口
# ============================================================================
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
help='SMILES 列名 (默认: smiles)')
@click.option('--id-column', '-i', default='chem_id',
help='化合物 ID 列名 (默认: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
help='批处理大小 (默认: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
help='工作进程数 (默认: CPU 核心数)')
@click.option('--device', '-d', default='auto',
type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
"""
使用 MolE 模型预测小分子 SMILES 的抗菌活性
INPUT_PATH: 输入 CSV 文件路径
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
示例:
python mole_predictor.py input.csv output.csv
python mole_predictor.py input.csv -s SMILES -i ID
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
"""
try:
predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
)
except Exception as e:
click.echo(f"错误: {e}", err=True)
sys.exit(1)
if __name__ == '__main__':
cli()