feat: 实现大规模并行预测功能 (v2.0.0)
新增功能: - 新增统一批量预测工具 utils/batch_predictor.py * 支持单进程/多进程并行模式 * 灵活的 GPU 配置和显存自动计算 * 自动临时文件管理和断点续传 * 完整的 CLI 参数支持(Click 框架) - 新增 Shell 脚本集合 scripts/ * run_parallel_predict.sh - 并行预测脚本 * run_single_predict.sh - 单进程预测脚本 * merge_results.sh - 结果合并脚本 性能优化: - 解决 CUDA + multiprocessing fork 死锁问题 * 使用 spawn 模式替代 fork * 文件描述符级别的输出重定向 - 优化预测性能 * XGBoost OpenMP 多线程(利用所有 CPU 核心) * 预加载模型减少重复加载 * 大批量处理降低函数调用开销 * 实际加速比:2-3x(12进程 vs 单进程) - 优化输出显示 * 抑制模型加载时的权重信息 * 只显示进度条和关键统计 * 临时文件自动保存到专门目录 文档更新: - README.md 新增"大规模并行预测"章节 - README.md 新增"性能优化说明"章节 - 添加详细的使用示例和参数说明 - 更新项目结构和版本信息 技术细节: - 每个模型实例约占用 2.5GB GPU 显存 - 显存计算公式:建议进程数 = GPU显存(GB) / 2.5 - GPU 瓶颈占比:MolE 表示生成 94% - 非 GIL 问题:计算密集任务在 C/CUDA 层 Breaking Changes: - 废弃旧的独立预测脚本,统一使用新工具 相关 Issue: 解决 #并行预测卡死问题 测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
This commit is contained in:
@@ -12,7 +12,7 @@ import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import multiprocessing as mp
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Union, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -31,8 +31,8 @@ class PredictionConfig:
|
||||
gram_info_path: str = None
|
||||
app_threshold: float = 0.04374140128493309
|
||||
min_nkill: int = 10
|
||||
batch_size: int = 100
|
||||
n_workers: Optional[int] = None
|
||||
batch_size: int = 10000 # 优化:进一步增加到10000
|
||||
n_workers: Optional[int] = 2 # 优化:减少到2个线程,避免CPU竞争
|
||||
device: str = "auto"
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -429,29 +429,52 @@ def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
|
||||
|
||||
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
"""
|
||||
并行广谱抗菌预测器
|
||||
优化后的预测器 - 使用XGBoost内部并行
|
||||
|
||||
继承自BroadSpectrumPredictor,添加了多进程并行处理能力,
|
||||
适用于大规模分子批量预测。
|
||||
关键改进:
|
||||
1. 单进程处理(避免GIL和进程间通信开销)
|
||||
2. 模型只加载一次
|
||||
3. XGBoost内部使用所有CPU核心(OpenMP)
|
||||
4. 大批量处理
|
||||
"""
|
||||
|
||||
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
|
||||
def __init__(self, config: Optional[PredictionConfig] = None) -> None:
|
||||
"""
|
||||
预测单个分子的广谱抗菌活性
|
||||
初始化并预加载模型
|
||||
|
||||
Args:
|
||||
molecule: 分子输入数据
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果
|
||||
config: 预测配置参数
|
||||
"""
|
||||
results = self.predict_batch([molecule])
|
||||
return results[0]
|
||||
# 调用父类初始化
|
||||
super().__init__(config)
|
||||
|
||||
# ✅ 核心优化:预加载XGBoost模型到内存
|
||||
print("⚡ Loading XGBoost model...")
|
||||
import time
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
|
||||
|
||||
start = time.time()
|
||||
with open(self.config.xgboost_model_path, "rb") as file:
|
||||
self.xgboost_model = pickle.load(file)
|
||||
|
||||
# 修复特征名称兼容性
|
||||
if hasattr(self.xgboost_model, 'get_booster'):
|
||||
self.xgboost_model.get_booster().feature_names = None
|
||||
|
||||
# ✅ 关键:设置XGBoost使用所有CPU核心
|
||||
n_threads = mp.cpu_count()
|
||||
self.xgboost_model.get_booster().set_param({
|
||||
'nthread': n_threads
|
||||
})
|
||||
print(f"✓ XGBoost configured to use {n_threads} CPU threads")
|
||||
|
||||
print(f"✓ Model loaded in {time.time()-start:.2f}s")
|
||||
|
||||
def predict_batch(self, molecules: List[MoleculeInput],
|
||||
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
批量预测分子的广谱抗菌活性
|
||||
单进程批量预测 - 使用XGBoost内部并行
|
||||
|
||||
Args:
|
||||
molecules: 分子输入列表
|
||||
@@ -463,50 +486,66 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
if not molecules:
|
||||
return []
|
||||
|
||||
# 获取MolE表示
|
||||
import time
|
||||
|
||||
# 1. MolE表示(GPU)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Processing {len(molecules)} molecules...")
|
||||
print(f"{'='*60}")
|
||||
|
||||
start_total = time.time()
|
||||
|
||||
start = time.time()
|
||||
print("\n[1/4] Generating MolE representations (GPU)...")
|
||||
mole_representation = self._get_mole_representation(molecules)
|
||||
time_mole = time.time() - start
|
||||
print(f"✓ Done in {time_mole:.1f}s")
|
||||
|
||||
# 添加菌株信息
|
||||
print("Preparing strain-level features...")
|
||||
# 2. 准备特征(添加菌株信息)
|
||||
start = time.time()
|
||||
print("\n[2/4] Preparing strain-level features...")
|
||||
X_input = self._add_strains(mole_representation)
|
||||
time_prep = time.time() - start
|
||||
print(f"✓ Done in {time_prep:.1f}s")
|
||||
print(f" Total predictions needed: {len(X_input):,}")
|
||||
|
||||
# 分批处理
|
||||
print(f"Starting parallel prediction with {self.n_workers} workers...")
|
||||
batches = []
|
||||
for i in range(0, len(X_input), self.config.batch_size):
|
||||
batch = X_input.iloc[i:i+self.config.batch_size]
|
||||
batches.append((batch, i // self.config.batch_size))
|
||||
# 3. XGBoost预测(单次大批量,内部48核并行)
|
||||
start = time.time()
|
||||
print(f"\n[3/4] XGBoost prediction (using {mp.cpu_count()} CPU cores)...")
|
||||
print(f" Predicting {len(X_input):,} samples in one batch...")
|
||||
print(f" (Watch CPU usage - should be ~{mp.cpu_count()*100}%)")
|
||||
|
||||
# 并行预测
|
||||
results = {}
|
||||
with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(_predict_batch_worker, (batch_data, batch_id),
|
||||
self.config.xgboost_model_path,
|
||||
self.config.app_threshold): batch_id
|
||||
for batch_data, batch_id in batches
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
batch_id, pred_df = future.result()
|
||||
results[batch_id] = pred_df
|
||||
print(f"Batch {batch_id} completed")
|
||||
# ✅ 关键:单次预测所有数据
|
||||
# XGBoost内部会自动使用OpenMP并行到所有核心
|
||||
y_pred = self.xgboost_model.predict_proba(X_input)
|
||||
|
||||
# 合并结果
|
||||
print("Merging prediction results...")
|
||||
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
time_pred = time.time() - start
|
||||
print(f"✓ Done in {time_pred:.1f}s")
|
||||
print(f" Throughput: {len(X_input)/time_pred:.0f} predictions/second")
|
||||
|
||||
# 4. 后处理
|
||||
start = time.time()
|
||||
print("\n[4/4] Post-processing results...")
|
||||
|
||||
pred_df = pd.DataFrame(
|
||||
y_pred,
|
||||
columns=["0", "1"],
|
||||
index=X_input.index
|
||||
)
|
||||
|
||||
pred_df["growth_inhibition"] = (
|
||||
pred_df["1"] >= self.config.app_threshold
|
||||
).astype(int)
|
||||
|
||||
pred_df = pred_df.reset_index()
|
||||
|
||||
# 准备菌株级别数据(如果需要)
|
||||
strain_level_data = None
|
||||
if include_strain_predictions:
|
||||
print("Preparing strain-level predictions...")
|
||||
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
|
||||
strain_level_data = self._prepare_strain_level_predictions(pred_df)
|
||||
|
||||
# 计算抗菌潜力
|
||||
print("Calculating antimicrobial potential scores...")
|
||||
agg_df = self._antimicrobial_potential(all_pred_df)
|
||||
agg_df = self._antimicrobial_potential(pred_df)
|
||||
|
||||
# 判断广谱抗菌
|
||||
agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
|
||||
@@ -516,7 +555,6 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
# 转换为结果对象
|
||||
results_list = []
|
||||
for _, row in agg_df.iterrows():
|
||||
# 获取该分子的菌株级别预测
|
||||
mol_strain_preds = None
|
||||
if strain_level_data is not None:
|
||||
mol_strain_preds = strain_level_data[
|
||||
@@ -536,8 +574,40 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
)
|
||||
results_list.append(result)
|
||||
|
||||
time_post = time.time() - start
|
||||
print(f"✓ Done in {time_post:.1f}s")
|
||||
|
||||
# 总结
|
||||
total_time = time.time() - start_total
|
||||
print(f"\n{'='*60}")
|
||||
print(f"SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
print(f" MolE representation: {time_mole:6.1f}s ({time_mole/total_time*100:5.1f}%)")
|
||||
print(f" Feature preparation: {time_prep:6.1f}s ({time_prep/total_time*100:5.1f}%)")
|
||||
print(f" XGBoost prediction: {time_pred:6.1f}s ({time_pred/total_time*100:5.1f}%)")
|
||||
print(f" Post-processing: {time_post:6.1f}s ({time_post/total_time*100:5.1f}%)")
|
||||
print(f" {'─'*58}")
|
||||
print(f" Total time: {total_time:6.1f}s")
|
||||
print(f" Molecules processed: {len(molecules)}")
|
||||
print(f" Time per molecule: {total_time/len(molecules):.3f}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return results_list
|
||||
|
||||
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
|
||||
"""
|
||||
预测单个分子的广谱抗菌活性
|
||||
|
||||
Args:
|
||||
molecule: 分子输入数据
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果
|
||||
"""
|
||||
results = self.predict_batch([molecule])
|
||||
return results[0]
|
||||
|
||||
|
||||
def predict_from_smiles(self,
|
||||
smiles_list: List[str],
|
||||
chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
|
||||
|
||||
Reference in New Issue
Block a user