Files
SIME/utils/batch_predictor.py
hotwa a8fea027ac feat: 实现大规模并行预测功能 (v2.0.0)
新增功能:
- 新增统一批量预测工具 utils/batch_predictor.py
  * 支持单进程/多进程并行模式
  * 灵活的 GPU 配置和显存自动计算
  * 自动临时文件管理和断点续传
  * 完整的 CLI 参数支持(Click 框架)

- 新增 Shell 脚本集合 scripts/
  * run_parallel_predict.sh - 并行预测脚本
  * run_single_predict.sh - 单进程预测脚本
  * merge_results.sh - 结果合并脚本

性能优化:
- 解决 CUDA + multiprocessing fork 死锁问题
  * 使用 spawn 模式替代 fork
  * 文件描述符级别的输出重定向

- 优化预测性能
  * XGBoost OpenMP 多线程(利用所有 CPU 核心)
  * 预加载模型减少重复加载
  * 大批量处理降低函数调用开销
  * 实际加速比:2-3x(12进程 vs 单进程)

- 优化输出显示
  * 抑制模型加载时的权重信息
  * 只显示进度条和关键统计
  * 临时文件自动保存到专门目录

文档更新:
- README.md 新增"大规模并行预测"章节
- README.md 新增"性能优化说明"章节
- 添加详细的使用示例和参数说明
- 更新项目结构和版本信息

技术细节:
- 每个模型实例约占用 2.5GB GPU 显存
- 显存计算公式:建议进程数 = GPU显存(GB) / 2.5
- GPU 瓶颈占比:MolE 表示生成 94%
- 非 GIL 问题:计算密集任务在 C/CUDA 层

Breaking Changes:
- 废弃旧的独立预测脚本,统一使用新工具

相关 Issue: 解决 #并行预测卡死问题
测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
2025-10-18 20:53:39 +08:00

508 lines
15 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
批量抗菌活性预测工具
这个工具用于大规模预测分子的抗菌活性,支持:
- 单进程或多进程并行预测
- 自动处理大型 CSV 文件
- 灵活的 GPU 配置
- 临时文件管理和断点续传
技术细节:
- 使用 ParallelBroadSpectrumPredictor单进程 + XGBoost OpenMP 多线程)
- 避免 CUDA fork 死锁问题
- 每个模型实例约占用 2.5GB GPU 显存
- XGBoost 自动利用所有 CPU 核心
作者: AI Assistant
日期: 2025-10-17
"""
import sys
import os
import time
import click
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from contextlib import redirect_stdout, redirect_stderr
import multiprocessing as mp
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from models.broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
MoleculeInput,
PredictionConfig
)
def predict_single_process(
input_path: str,
output_path: str,
smiles_column: str,
id_column: str,
device: str,
batch_size: int,
start_from: int,
max_molecules: int,
temp_dir: Path,
verbose: bool = True
) -> pd.DataFrame:
"""
单进程预测分子抗菌活性
Args:
input_path: 输入 CSV 文件路径
output_path: 输出 CSV 文件路径
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
device: GPU 设备(如 'cuda:0''cpu'
batch_size: 批处理大小
start_from: 从第几行开始
max_molecules: 最多处理多少个分子
temp_dir: 临时文件目录
verbose: 是否显示详细信息
Returns:
预测结果 DataFrame
"""
# 读取数据
df_input = pd.read_csv(input_path)
# 检查列是否存在(大小写不敏感)
columns_lower = {col.lower(): col for col in df_input.columns}
smiles_col_actual = columns_lower.get(smiles_column.lower())
if smiles_col_actual is None:
raise ValueError(
f"SMILES 列 '{smiles_column}' 不存在。可用列: {list(df_input.columns)}"
)
# 处理 ID 列
id_col_actual = columns_lower.get(id_column.lower())
if id_col_actual is None:
if verbose:
print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
id_col_actual = id_column
# 应用限制
if start_from > 0:
df_input = df_input.iloc[start_from:]
if max_molecules:
df_input = df_input.iloc[:max_molecules]
if verbose:
print(f" 处理 {len(df_input):,} 个分子")
# 初始化预测器
config = PredictionConfig(
batch_size=10000,
device=device
)
# 抑制模型加载时的输出
if not verbose:
with open(os.devnull, 'w') as devnull:
with redirect_stdout(devnull):
predictor = ParallelBroadSpectrumPredictor(config)
else:
predictor = ParallelBroadSpectrumPredictor(config)
# 分批处理
all_results = []
n_batches = (len(df_input) + batch_size - 1) // batch_size
iterator = range(0, len(df_input), batch_size)
if verbose:
iterator = tqdm(iterator, desc="处理进度", unit="")
for i in iterator:
batch_df = df_input.iloc[i:i+batch_size]
# 准备分子输入
molecules = [
MoleculeInput(
smiles=row[smiles_col_actual],
chem_id=str(row[id_col_actual])
)
for _, row in batch_df.iterrows()
]
# 执行预测
try:
# 抑制详细输出
with open(os.devnull, 'w') as devnull:
with redirect_stdout(devnull):
results = predictor.predict_batch(
molecules,
include_strain_predictions=False
)
# 转换结果
for result in results:
result_dict = result.to_dict()
mol_idx = int(result.chem_id.replace('mol', '')) - 1
if mol_idx < len(batch_df):
result_dict['smiles'] = batch_df.iloc[mol_idx][smiles_col_actual]
all_results.append(result_dict)
except Exception as e:
if verbose:
print(f"\n❌ 批次 {i//batch_size + 1} 失败: {e}")
continue
# 定期保存临时结果(每 10 批)
if (i // batch_size + 1) % 10 == 0:
temp_df = pd.DataFrame(all_results)
temp_file = temp_dir / f"batch_{i//batch_size+1}.csv"
temp_df.to_csv(temp_file, index=False)
if verbose:
print(f"\n💾 临时保存: {temp_file}")
# 转换为 DataFrame
df_results = pd.DataFrame(all_results)
# 重新排列列顺序
if 'smiles' in df_results.columns:
cols = ['smiles', 'chem_id'] + [col for col in df_results.columns
if col not in ['smiles', 'chem_id']]
df_results = df_results[cols]
# 保存结果
df_results.to_csv(output_path, index=False)
return df_results
def predict_chunk_worker(args):
"""
多进程工作函数:处理单个数据块
Args:
args: (chunk_data, chunk_id, output_file, params_dict)
Returns:
(chunk_id, output_file, success)
"""
chunk_data, chunk_id, output_file, params = args
try:
# 保存 chunk 数据到临时文件
temp_input = params['temp_dir'] / f"chunk_{chunk_id}_input.csv"
chunk_data.to_csv(temp_input, index=False)
# 调用单进程预测
predict_single_process(
input_path=str(temp_input),
output_path=str(output_file),
smiles_column=params['smiles_column'],
id_column=params['id_column'],
device=params['device'],
batch_size=params['batch_size'],
start_from=0,
max_molecules=None,
temp_dir=params['temp_dir'],
verbose=(chunk_id == 0) # 只有第一个进程显示详细信息
)
# 清理临时输入文件
temp_input.unlink()
return chunk_id, output_file, True
except Exception as e:
print(f"❌ Chunk {chunk_id} 失败: {e}")
import traceback
traceback.print_exc()
return chunk_id, output_file, False
@click.command()
@click.option(
'--input', '-i',
required=True,
type=click.Path(exists=True),
help='输入 CSV 文件路径'
)
@click.option(
'--output', '-o',
required=True,
type=click.Path(),
help='输出 CSV 文件路径'
)
@click.option(
'--smiles-column', '-s',
default='smiles',
help='SMILES 列名(默认: smiles'
)
@click.option(
'--id-column', '-d',
default='chem_id',
help='化合物 ID 列名(默认: chem_id如不存在则自动生成'
)
@click.option(
'--device', '-g',
default='cuda:0',
help='GPU 设备(默认: cuda:0。可选: cuda:0, cuda:1, cpu'
)
@click.option(
'--n-processes', '-n',
default=1,
type=int,
help='并行进程数(默认: 1。建议值 = GPU显存(GB) / 2.5。例如 32GB 显存可用 ~12 个进程'
)
@click.option(
'--batch-size', '-b',
default=1000,
type=int,
help='每批处理的分子数量(默认: 1000'
)
@click.option(
'--start-from',
default=0,
type=int,
help='从第几行开始处理(默认: 0用于断点续传'
)
@click.option(
'--max-molecules', '-m',
default=None,
type=int,
help='最多处理多少个分子(默认: None处理全部'
)
@click.option(
'--temp-dir',
default=None,
type=click.Path(),
help='临时文件目录(默认: {输入文件名}_temp'
)
@click.option(
'--keep-temp/--no-keep-temp',
default=True,
help='是否保留临时文件(默认: 保留)'
)
@click.option(
'--verbose/--quiet',
default=True,
help='是否显示详细信息(默认: 显示)'
)
def main(
input: str,
output: str,
smiles_column: str,
id_column: str,
device: str,
n_processes: int,
batch_size: int,
start_from: int,
max_molecules: int,
temp_dir: str,
keep_temp: bool,
verbose: bool
):
"""
批量预测分子抗菌活性
\b
示例用法:
1. 单进程预测(最稳定):
pixi run python utils/batch_predictor.py -i data.csv -o output.csv
2. 多进程并行4个进程适合32GB显存
pixi run python utils/batch_predictor.py -i data.csv -o output.csv -n 4
3. 指定 GPU 和列名:
pixi run python utils/batch_predictor.py -i data.csv -o output.csv \\
-g cuda:1 -s SMILES -d ID -n 8
4. 断点续传(从第 100000 行开始):
pixi run python utils/batch_predictor.py -i data.csv -o output.csv \\
--start-from 100000
\b
显存计算公式:
- 每个模型实例约占用 2.5GB GPU 显存
- 建议并行进程数 = GPU显存(GB) / 2.5
- 例如:
* 12GB 显存 → 建议 4 个进程
* 24GB 显存 → 建议 9 个进程
* 32GB 显存 → 建议 12 个进程
* 48GB 显存 → 建议 19 个进程
\b
注意事项:
- 单 GPU 上的多进程会串行使用 GPU不是真正的并行
- 预期加速比约 2-3x而非线性加速
- 如果 GPU 内存不足,减少 n-processes
- 临时文件保存在 {输入文件名}_temp/ 目录
"""
print("=" * 80)
print("🚀 批量抗菌活性预测")
print("=" * 80)
# 设置路径
input_path = Path(input)
output_path = Path(output)
if temp_dir is None:
temp_dir = input_path.parent / f"{input_path.stem}_temp"
else:
temp_dir = Path(temp_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
# 显示配置
if verbose:
print(f"\n配置:")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" SMILES 列: {smiles_column}")
print(f" ID 列: {id_column}")
print(f" GPU 设备: {device}")
print(f" 并行进程数: {n_processes}")
print(f" 批处理大小: {batch_size}")
print(f" 临时目录: {temp_dir}")
if start_from > 0:
print(f" 开始行: {start_from}")
if max_molecules:
print(f" 最多处理: {max_molecules:,} 个分子")
print("=" * 80)
start_time = time.time()
# 单进程模式
if n_processes == 1:
if verbose:
print("\n📦 使用单进程模式")
df_results = predict_single_process(
input_path=str(input_path),
output_path=str(output_path),
smiles_column=smiles_column,
id_column=id_column,
device=device,
batch_size=batch_size,
start_from=start_from,
max_molecules=max_molecules,
temp_dir=temp_dir,
verbose=verbose
)
# 多进程模式
else:
if verbose:
print(f"\n📦 使用多进程模式({n_processes} 个进程)")
# 读取数据
df_input = pd.read_csv(input_path)
# 应用限制
if start_from > 0:
df_input = df_input.iloc[start_from:]
if max_molecules:
df_input = df_input.iloc[:max_molecules]
# 分割数据
chunk_size = len(df_input) // n_processes
chunks = []
for i in range(n_processes):
start_idx = i * chunk_size
if i == n_processes - 1:
end_idx = len(df_input)
else:
end_idx = (i + 1) * chunk_size
chunk = df_input.iloc[start_idx:end_idx].copy()
output_file = temp_dir / f"part_{i}.csv"
params = {
'smiles_column': smiles_column,
'id_column': id_column,
'device': device,
'batch_size': batch_size,
'temp_dir': temp_dir
}
chunks.append((chunk, i, output_file, params))
if verbose:
print(f" 数据分成 {n_processes}")
print(f" 每块约 {chunk_size:,} 个分子")
print(f"\n🔄 开始并行处理...")
# 使用 spawn 模式避免 CUDA fork 问题
mp.set_start_method('spawn', force=True)
# 并行处理
with mp.Pool(processes=n_processes) as pool:
results = list(tqdm(
pool.imap(predict_chunk_worker, chunks),
total=len(chunks),
desc="处理进度",
disable=not verbose
))
# 合并结果
if verbose:
print("\n📦 合并结果...")
all_dfs = []
for chunk_id, output_file, success in sorted(results, key=lambda x: x[0]):
if success and output_file.exists():
df = pd.read_csv(output_file)
all_dfs.append(df)
if verbose:
print(f" ✓ part_{chunk_id}.csv: {len(df):,}")
else:
print(f" ❌ part_{chunk_id}.csv: 失败")
df_results = pd.concat(all_dfs, ignore_index=True)
df_results.to_csv(output_path, index=False)
# 统计信息
elapsed_time = time.time() - start_time
if verbose:
print("\n" + "=" * 80)
print("✅ 预测完成")
print("=" * 80)
print(f"\n📈 统计信息:")
print(f" 处理分子数: {len(df_results):,}")
print(f" 总耗时: {elapsed_time:.2f} 秒 ({elapsed_time/60:.2f} 分钟)")
print(f" 平均速度: {len(df_results)/elapsed_time:.2f} 分子/秒")
# 广谱抗菌统计
n_broad = df_results['broad_spectrum'].sum()
print(f"\n🎯 预测结果:")
print(f" 广谱抗菌: {n_broad:,} 个 ({n_broad/len(df_results)*100:.2f}%)")
print(f" 非广谱: {len(df_results)-n_broad:,}")
# 抑制菌株数分布
print(f"\n📊 抑制菌株数分布:")
for threshold in [0, 5, 10, 15, 20, 30]:
n = (df_results['ginhib_total'] >= threshold).sum()
print(f"{threshold:2d} 个菌株: {n:,} ({n/len(df_results)*100:.2f}%)")
print("\n" + "=" * 80)
if not keep_temp:
print(f"\n🗑️ 清理临时文件...")
import shutil
shutil.rmtree(temp_dir)
print(f"✓ 已删除: {temp_dir}")
else:
print(f"\n📁 临时文件保留在: {temp_dir}")
if __name__ == '__main__':
main()