Files
SIME/utils/mole_predictor.py
hotwa a8fea027ac feat: 实现大规模并行预测功能 (v2.0.0)
新增功能:
- 新增统一批量预测工具 utils/batch_predictor.py
  * 支持单进程/多进程并行模式
  * 灵活的 GPU 配置和显存自动计算
  * 自动临时文件管理和断点续传
  * 完整的 CLI 参数支持(Click 框架)

- 新增 Shell 脚本集合 scripts/
  * run_parallel_predict.sh - 并行预测脚本
  * run_single_predict.sh - 单进程预测脚本
  * merge_results.sh - 结果合并脚本

性能优化:
- 解决 CUDA + multiprocessing fork 死锁问题
  * 使用 spawn 模式替代 fork
  * 文件描述符级别的输出重定向

- 优化预测性能
  * XGBoost OpenMP 多线程(利用所有 CPU 核心)
  * 预加载模型减少重复加载
  * 大批量处理降低函数调用开销
  * 实际加速比:2-3x(12进程 vs 单进程)

- 优化输出显示
  * 抑制模型加载时的权重信息
  * 只显示进度条和关键统计
  * 临时文件自动保存到专门目录

文档更新:
- README.md 新增"大规模并行预测"章节
- README.md 新增"性能优化说明"章节
- 添加详细的使用示例和参数说明
- 更新项目结构和版本信息

技术细节:
- 每个模型实例约占用 2.5GB GPU 显存
- 显存计算公式:建议进程数 = GPU显存(GB) / 2.5
- GPU 瓶颈占比:MolE 表示生成 94%
- 非 GIL 问题:计算密集任务在 C/CUDA 层

Breaking Changes:
- 废弃旧的独立预测脚本,统一使用新工具

相关 Issue: 解决 #并行预测卡死问题
测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
2025-10-18 20:53:39 +08:00

327 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测工具
这个脚本提供了使用 MolE 模型预测小分子 SMILES 抗菌活性的功能。
支持命令行和 Python API 调用两种方式。
命令行示例:
python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id
Python API 示例:
from utils.mole_predictor import predict_csv_file
predict_csv_file(
input_path="input.csv",
output_path="output.csv",
smiles_column="smiles",
id_column="chem_id"
)
"""
import sys
import os
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import click
import pandas as pd
from typing import Optional, List
from datetime import datetime
from tqdm import tqdm
from models.broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult
)
def predict_csv_file(
input_path: str,
output_path: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> pd.DataFrame:
"""
预测 CSV 文件中的分子抗菌活性
Args:
input_path: 输入 CSV 文件路径
output_path: 输出 CSV 文件路径,如果为 None 则自动生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame
"""
print(f"开始处理文件: {input_path}")
# 读取输入文件
input_path_obj = Path(input_path)
if not input_path_obj.exists():
raise FileNotFoundError(f"输入文件不存在: {input_path}")
# 读取 CSV
try:
df_input = pd.read_csv(input_path)
except Exception as e:
raise RuntimeError(f"读取 CSV 文件失败: {e}")
print(f"读取了 {len(df_input)} 条数据")
# 检查列是否存在(大小写不敏感)
columns_lower = {col.lower(): col for col in df_input.columns}
smiles_col_actual = columns_lower.get(smiles_column.lower())
if smiles_col_actual is None:
raise ValueError(
f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}"
)
# 处理 ID 列
id_col_actual = columns_lower.get(id_column.lower())
if id_col_actual is None:
print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
id_col_actual = id_column
# 创建预测器配置
config = PredictionConfig(
batch_size=batch_size,
n_workers=n_workers,
device=device
)
# 初始化预测器
print("初始化预测器...")
predictor = ParallelBroadSpectrumPredictor(config)
# 准备分子输入
molecules = [
MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
for _, row in df_input.iterrows()
]
# 执行预测
print("开始预测...")
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
# 转换结果为 DataFrame
results_dicts = [r.to_dict() for r in results]
df_results = pd.DataFrame(results_dicts)
# 合并原始数据和预测结果
# 使用 chem_id 作为键进行合并
df_input['_merge_id'] = df_input[id_col_actual].astype(str)
df_results['_merge_id'] = df_results['chem_id'].astype(str)
df_output = df_input.merge(
df_results.drop(columns=['chem_id']),
on='_merge_id',
how='left'
)
df_output = df_output.drop(columns=['_merge_id'])
# 如果包含菌株级别预测,将其添加到输出中
if include_strain_predictions:
print("合并菌株级别预测数据...")
# 收集所有菌株级别预测
all_strain_predictions = []
for result in results:
if result.strain_predictions is not None and not result.strain_predictions.empty:
all_strain_predictions.append(result.strain_predictions)
if all_strain_predictions:
# 合并所有菌株预测
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
# 将聚合结果和菌株预测合并
# 为了在同一个 CSV 中展示,我们使用重复行的方式
# 每个分子的聚合结果会重复40次每个菌株一次
df_output_expanded = df_output.merge(
df_strain_predictions,
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
right_on='chem_id',
how='left',
suffixes=('', '_strain')
)
# 移除重复的 chem_id 列
if 'chem_id_strain' in df_output_expanded.columns:
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
df_output = df_output_expanded
# 生成输出路径
if output_path is None:
if add_suffix:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}")
elif add_suffix:
output_path_obj = Path(output_path)
output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}")
# 保存结果
print(f"保存结果到: {output_path}")
df_output.to_csv(output_path, index=False)
print(f"完成! 预测了 {len(results)} 个分子")
print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌")
return df_output
def predict_multiple_files(
input_paths: List[str],
output_dir: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> List[pd.DataFrame]:
"""
批量预测多个 CSV 文件
Args:
input_paths: 输入 CSV 文件路径列表
output_dir: 输出目录,如果为 None 则在原文件目录生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame 列表
"""
results = []
for input_path in input_paths:
input_path_obj = Path(input_path)
# 确定输出路径
if output_dir is not None:
output_dir_obj = Path(output_dir)
output_dir_obj.mkdir(parents=True, exist_ok=True)
if add_suffix:
output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(output_dir_obj / input_path_obj.name)
else:
output_path = None
# 预测单个文件
try:
df_result = predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
results.append(df_result)
except Exception as e:
print(f"处理文件 {input_path} 时出错: {e}")
continue
return results
# ============================================================================
# 命令行接口
# ============================================================================
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
help='SMILES 列名 (默认: smiles)')
@click.option('--id-column', '-i', default='chem_id',
help='化合物 ID 列名 (默认: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
help='批处理大小 (默认: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
help='工作进程数 (默认: CPU 核心数)')
@click.option('--device', '-d', default='auto',
type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
@click.option('--include-strain-predictions', is_flag=True, default=False,
help='在输出中包含40种菌株的详细预测数据每个分子将产生40行数据对应每个菌株的预测概率和抑制情况')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
"""
使用 MolE 模型预测小分子 SMILES 的抗菌活性
INPUT_PATH: 输入 CSV 文件路径
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
示例:
# 基本用法(仅输出聚合结果)
python mole_predictor.py input.csv output.csv
# 包含40种菌株的详细预测数据
python mole_predictor.py input.csv output.csv --include-strain-predictions
# 指定列名和设备
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
# 自定义批处理大小
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
"""
try:
predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
except Exception as e:
click.echo(f"错误: {e}", err=True)
sys.exit(1)
if __name__ == '__main__':
cli()