feat: 实现大规模并行预测功能 (v2.0.0)

新增功能:
- 新增统一批量预测工具 utils/batch_predictor.py
  * 支持单进程/多进程并行模式
  * 灵活的 GPU 配置和显存自动计算
  * 自动临时文件管理和断点续传
  * 完整的 CLI 参数支持(Click 框架)

- 新增 Shell 脚本集合 scripts/
  * run_parallel_predict.sh - 并行预测脚本
  * run_single_predict.sh - 单进程预测脚本
  * merge_results.sh - 结果合并脚本

性能优化:
- 解决 CUDA + multiprocessing fork 死锁问题
  * 使用 spawn 模式替代 fork
  * 文件描述符级别的输出重定向

- 优化预测性能
  * XGBoost OpenMP 多线程(利用所有 CPU 核心)
  * 预加载模型减少重复加载
  * 大批量处理降低函数调用开销
  * 实际加速比:2-3x(12进程 vs 单进程)

- 优化输出显示
  * 抑制模型加载时的权重信息
  * 只显示进度条和关键统计
  * 临时文件自动保存到专门目录

文档更新:
- README.md 新增"大规模并行预测"章节
- README.md 新增"性能优化说明"章节
- 添加详细的使用示例和参数说明
- 更新项目结构和版本信息

技术细节:
- 每个模型实例约占用 2.5GB GPU 显存
- 显存计算公式:建议进程数 = GPU显存(GB) / 2.5
- GPU 瓶颈占比:MolE 表示生成 94%
- 非 GIL 问题:计算密集任务在 C/CUDA 层

Breaking Changes:
- 废弃旧的独立预测脚本,统一使用新工具

相关 Issue: 解决 #并行预测卡死问题
测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
This commit is contained in:
2025-10-18 20:53:39 +08:00
parent 4745ce3884
commit a8fea027ac
8 changed files with 1202 additions and 51 deletions

View File

@@ -12,7 +12,7 @@ import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
@@ -31,8 +31,8 @@ class PredictionConfig:
gram_info_path: str = None
app_threshold: float = 0.04374140128493309
min_nkill: int = 10
batch_size: int = 100
n_workers: Optional[int] = None
batch_size: int = 10000 # 优化进一步增加到10000
n_workers: Optional[int] = 2 # 优化减少到2个线程避免CPU竞争
device: str = "auto"
def __post_init__(self):
@@ -429,29 +429,52 @@ def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
"""
并行广谱抗菌预测器
优化后的预测器 - 使用XGBoost内部并行
继承自BroadSpectrumPredictor添加了多进程并行处理能力
适用于大规模分子批量预测。
关键改进:
1. 单进程处理避免GIL和进程间通信开销
2. 模型只加载一次
3. XGBoost内部使用所有CPU核心OpenMP
4. 大批量处理
"""
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
def __init__(self, config: Optional[PredictionConfig] = None) -> None:
"""
预测单个分子的广谱抗菌活性
初始化并预加载模型
Args:
molecule: 分子输入数据
Returns:
广谱抗菌预测结果
config: 预测配置参数
"""
results = self.predict_batch([molecule])
return results[0]
# 调用父类初始化
super().__init__(config)
# ✅ 核心优化预加载XGBoost模型到内存
print("⚡ Loading XGBoost model...")
import time
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
start = time.time()
with open(self.config.xgboost_model_path, "rb") as file:
self.xgboost_model = pickle.load(file)
# 修复特征名称兼容性
if hasattr(self.xgboost_model, 'get_booster'):
self.xgboost_model.get_booster().feature_names = None
# ✅ 关键设置XGBoost使用所有CPU核心
n_threads = mp.cpu_count()
self.xgboost_model.get_booster().set_param({
'nthread': n_threads
})
print(f"✓ XGBoost configured to use {n_threads} CPU threads")
print(f"✓ Model loaded in {time.time()-start:.2f}s")
def predict_batch(self, molecules: List[MoleculeInput],
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
"""
批量预测分子的广谱抗菌活性
单进程批量预测 - 使用XGBoost内部并行
Args:
molecules: 分子输入列表
@@ -463,50 +486,66 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
if not molecules:
return []
# 获取MolE表示
import time
# 1. MolE表示GPU
print(f"\n{'='*60}")
print(f"Processing {len(molecules)} molecules...")
print(f"{'='*60}")
start_total = time.time()
start = time.time()
print("\n[1/4] Generating MolE representations (GPU)...")
mole_representation = self._get_mole_representation(molecules)
time_mole = time.time() - start
print(f"✓ Done in {time_mole:.1f}s")
# 添加菌株信息
print("Preparing strain-level features...")
# 2. 准备特征(添加菌株信息
start = time.time()
print("\n[2/4] Preparing strain-level features...")
X_input = self._add_strains(mole_representation)
time_prep = time.time() - start
print(f"✓ Done in {time_prep:.1f}s")
print(f" Total predictions needed: {len(X_input):,}")
# 分批处理
print(f"Starting parallel prediction with {self.n_workers} workers...")
batches = []
for i in range(0, len(X_input), self.config.batch_size):
batch = X_input.iloc[i:i+self.config.batch_size]
batches.append((batch, i // self.config.batch_size))
# 3. XGBoost预测单次大批量内部48核并行
start = time.time()
print(f"\n[3/4] XGBoost prediction (using {mp.cpu_count()} CPU cores)...")
print(f" Predicting {len(X_input):,} samples in one batch...")
print(f" (Watch CPU usage - should be ~{mp.cpu_count()*100}%)")
# 并行预测
results = {}
with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
futures = {
executor.submit(_predict_batch_worker, (batch_data, batch_id),
self.config.xgboost_model_path,
self.config.app_threshold): batch_id
for batch_data, batch_id in batches
}
for future in as_completed(futures):
batch_id, pred_df = future.result()
results[batch_id] = pred_df
print(f"Batch {batch_id} completed")
# ✅ 关键:单次预测所有数据
# XGBoost内部会自动使用OpenMP并行到所有核心
y_pred = self.xgboost_model.predict_proba(X_input)
# 合并结果
print("Merging prediction results...")
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
all_pred_df = all_pred_df.reset_index()
time_pred = time.time() - start
print(f"✓ Done in {time_pred:.1f}s")
print(f" Throughput: {len(X_input)/time_pred:.0f} predictions/second")
# 4. 后处理
start = time.time()
print("\n[4/4] Post-processing results...")
pred_df = pd.DataFrame(
y_pred,
columns=["0", "1"],
index=X_input.index
)
pred_df["growth_inhibition"] = (
pred_df["1"] >= self.config.app_threshold
).astype(int)
pred_df = pred_df.reset_index()
# 准备菌株级别数据(如果需要)
strain_level_data = None
if include_strain_predictions:
print("Preparing strain-level predictions...")
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
strain_level_data = self._prepare_strain_level_predictions(pred_df)
# 计算抗菌潜力
print("Calculating antimicrobial potential scores...")
agg_df = self._antimicrobial_potential(all_pred_df)
agg_df = self._antimicrobial_potential(pred_df)
# 判断广谱抗菌
agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
@@ -516,7 +555,6 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
# 转换为结果对象
results_list = []
for _, row in agg_df.iterrows():
# 获取该分子的菌株级别预测
mol_strain_preds = None
if strain_level_data is not None:
mol_strain_preds = strain_level_data[
@@ -536,8 +574,40 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
)
results_list.append(result)
time_post = time.time() - start
print(f"✓ Done in {time_post:.1f}s")
# 总结
total_time = time.time() - start_total
print(f"\n{'='*60}")
print(f"SUMMARY")
print(f"{'='*60}")
print(f" MolE representation: {time_mole:6.1f}s ({time_mole/total_time*100:5.1f}%)")
print(f" Feature preparation: {time_prep:6.1f}s ({time_prep/total_time*100:5.1f}%)")
print(f" XGBoost prediction: {time_pred:6.1f}s ({time_pred/total_time*100:5.1f}%)")
print(f" Post-processing: {time_post:6.1f}s ({time_post/total_time*100:5.1f}%)")
print(f" {''*58}")
print(f" Total time: {total_time:6.1f}s")
print(f" Molecules processed: {len(molecules)}")
print(f" Time per molecule: {total_time/len(molecules):.3f}s")
print(f"{'='*60}\n")
return results_list
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
"""
预测单个分子的广谱抗菌活性
Args:
molecule: 分子输入数据
Returns:
广谱抗菌预测结果
"""
results = self.predict_batch([molecule])
return results[0]
def predict_from_smiles(self,
smiles_list: List[str],
chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]: