Files
SIME/models/broad_spectrum_predictor.py
hotwa a8fea027ac feat: 实现大规模并行预测功能 (v2.0.0)
新增功能:
- 新增统一批量预测工具 utils/batch_predictor.py
  * 支持单进程/多进程并行模式
  * 灵活的 GPU 配置和显存自动计算
  * 自动临时文件管理和断点续传
  * 完整的 CLI 参数支持(Click 框架)

- 新增 Shell 脚本集合 scripts/
  * run_parallel_predict.sh - 并行预测脚本
  * run_single_predict.sh - 单进程预测脚本
  * merge_results.sh - 结果合并脚本

性能优化:
- 解决 CUDA + multiprocessing fork 死锁问题
  * 使用 spawn 模式替代 fork
  * 文件描述符级别的输出重定向

- 优化预测性能
  * XGBoost OpenMP 多线程(利用所有 CPU 核心)
  * 预加载模型减少重复加载
  * 大批量处理降低函数调用开销
  * 实际加速比:2-3x(12进程 vs 单进程)

- 优化输出显示
  * 抑制模型加载时的权重信息
  * 只显示进度条和关键统计
  * 临时文件自动保存到专门目录

文档更新:
- README.md 新增"大规模并行预测"章节
- README.md 新增"性能优化说明"章节
- 添加详细的使用示例和参数说明
- 更新项目结构和版本信息

技术细节:
- 每个模型实例约占用 2.5GB GPU 显存
- 显存计算公式:建议进程数 = GPU显存(GB) / 2.5
- GPU 瓶颈占比:MolE 表示生成 94%
- 非 GIL 问题:计算密集任务在 C/CUDA 层

Breaking Changes:
- 废弃旧的独立预测脚本,统一使用新工具

相关 Issue: 解决 #并行预测卡死问题
测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
2025-10-18 20:53:39 +08:00

731 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
并行广谱抗菌预测器模块
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
基于MolE分子表示和XGBoost模型进行预测。
"""
import os
import re
import pickle
import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder
from .mole_representation import process_representation
@dataclass
class PredictionConfig:
    """Configuration parameters for broad-spectrum antimicrobial prediction.

    Any path left as ``None`` is filled in by ``__post_init__`` with a
    default located relative to the project root (two directory levels
    above this file).
    """
    xgboost_model_path: Optional[str] = None      # pickled XGBoost classifier
    mole_model_path: Optional[str] = None         # pretrained MolE model directory
    strain_categories_path: Optional[str] = None  # strain screening results (gzipped TSV)
    gram_info_path: Optional[str] = None          # strain Gram-stain info (xlsx)
    app_threshold: float = 0.04374140128493309    # probability cutoff for calling growth inhibition
    min_nkill: int = 10                           # min inhibited strains to call a compound broad-spectrum
    batch_size: int = 10000                       # tuned up to 10000 for throughput
    n_workers: Optional[int] = 2                  # kept low (2) to avoid CPU contention with XGBoost's OpenMP threads
    device: str = "auto"                          # "auto" resolves to cuda:0 when available, else cpu

    def __post_init__(self) -> None:
        """Fill in a default path for every path field that is still None."""
        # Resolve the project root from this file's location: models -> project root.
        # (Path is already imported at module level; no local re-import needed.)
        current_file = Path(__file__).resolve()
        project_root = current_file.parent.parent
        data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
        if self.mole_model_path is None:
            self.mole_model_path = str(data_dir)
        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")
        if self.strain_categories_path is None:
            self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")
        if self.gram_info_path is None:
            self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")
@dataclass
class MoleculeInput:
    """Input record for a single molecule to be predicted."""
    smiles: str  # SMILES string of the molecule
    chem_id: Optional[str] = None  # compound identifier; auto-generated ("molN") downstream when None
@dataclass
class StrainPrediction:
    """Prediction result for a single (compound, strain) pair."""
    pred_id: str  # formatted as "chem_id:strain_name"
    chem_id: str
    strain_name: str
    antimicrobial_predictive_probability: float  # predicted probability of antimicrobial activity
    no_growth_probability: float  # complement: 1 - antimicrobial_predictive_probability
    growth_inhibition: int  # binarized call (0/1) against the APP threshold
    gram_stain: Optional[str] = None  # Gram-stain type of the strain, when known
@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum antimicrobial prediction for one compound."""
    chem_id: str
    apscore_total: float
    apscore_gnegative: float
    apscore_gpositive: float
    ginhib_total: int
    ginhib_gnegative: int
    ginhib_gpositive: int
    broad_spectrum: int
    strain_predictions: Optional[pd.DataFrame] = None  # per-strain predictions (one row per strain)

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return only the aggregate fields as a plain dict."""
        keys = (
            'chem_id',
            'apscore_total',
            'apscore_gnegative',
            'apscore_gpositive',
            'ginhib_total',
            'ginhib_gnegative',
            'ginhib_gpositive',
            'broad_spectrum',
        )
        return {key: getattr(self, key) for key in keys}

    def to_strain_predictions_list(self) -> List['StrainPrediction']:
        """
        Convert the attached strain_predictions DataFrame into typed objects.

        Returns:
            A list of StrainPrediction instances; empty when no per-strain
            data is attached.
        """
        frame = self.strain_predictions
        if frame is None or frame.empty:
            return []
        return [
            StrainPrediction(
                pred_id=record['pred_id'],
                chem_id=record['chem_id'],
                strain_name=record['strain_name'],
                antimicrobial_predictive_probability=record['antimicrobial_predictive_probability'],
                no_growth_probability=record['no_growth_probability'],
                growth_inhibition=record['growth_inhibition'],
                gram_stain=record.get('gram_stain', None),
            )
            for _, record in frame.iterrows()
        ]
class BroadSpectrumPredictor:
    """
    Broad-spectrum antimicrobial predictor.

    Predicts the antimicrobial activity of molecules against a panel of
    strains using MolE molecular representations and an XGBoost model,
    supporting single-molecule and batch prediction with detailed
    antimicrobial-potential analysis.
    """

    # Extracts the NT accession from a strain name,
    # e.g. "Escherichia coli (NT5078)" -> "NT5078". Compiled once.
    _NT_PATTERN = re.compile(r".*?\((NT\d+)\)")

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """
        Initialize the predictor.

        Args:
            config: Prediction configuration; defaults are used when None.

        Raises:
            ValueError: If a required path is None.
            FileNotFoundError: If a required file does not exist.
            RuntimeError: If the shared strain/Gram-stain data cannot be loaded.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()
        # Fail fast on missing files, then pre-load shared reference data.
        self._validate_paths()
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Ensure every required model/data file path is set and exists."""
        required_files = {
            "mole_model": self.config.mole_model_path,
            "xgboost_model": self.config.xgboost_model_path,
            "strain_categories": self.config.strain_categories_path,
            "gram_info": self.config.gram_info_path,
        }
        for name, file_path in required_files.items():
            if file_path is None:
                raise ValueError(f"{name} is None! Check __post_init__ configuration")
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required {name} not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load shared data: strain screening matrix and Gram-stain annotations."""
        try:
            # Strain screening results (strains are the columns).
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )
            # One-hot encoding of the strain panel.
            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
            # Gram-stain annotations; skiprows drops the sheet's header/footer rows.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """
        Build a one-hot encoding of the strain panel.

        Args:
            categories: Strain name index.

        Returns:
            Square DataFrame indexed and columned by strain name.
        """
        try:
            # sklearn >= 1.2 renamed the keyword to sparse_output.
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older sklearn versions use sparse.
            ohe = OneHotEncoder(sparse=False)
        ohe.fit(pd.DataFrame(categories))
        return pd.DataFrame(
            ohe.transform(pd.DataFrame(categories)),
            columns=categories,
            index=categories
        )

    def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame:
        """
        Compute MolE representations for the given molecules.

        Args:
            molecules: Molecule inputs; missing chem_ids become "molN".

        Returns:
            DataFrame of MolE features indexed by chem_id.
        """
        df = pd.DataFrame([
            {"smiles": mol.smiles, "chem_id": mol.chem_id or f"mol{i+1}"}
            for i, mol in enumerate(molecules)
        ])
        device = self.config.device
        if device == "auto":
            # Prefer the first GPU when available.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """
        Cross-join chemical features with the strain one-hot encoding.

        Args:
            chemfeats_df: Chemical features indexed by chem_id.

        Returns:
            Feature matrix indexed by "chem_id:strain_name", one row per
            (compound, strain) pair.
        """
        chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)
        sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})
        # Cartesian product: every compound paired with every strain.
        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
        xpred = xpred.set_index("pred_id")
        return xpred.drop(columns=["chem_id", "strain_name"])

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """
        Annotate rows with the Gram-stain type of their strain.

        Args:
            label_df: DataFrame with a "strain_name" column. Not modified.

        Returns:
            Copy of label_df with "nt_number" and "gram_stain" columns added;
            both are None for strains without an NT accession in their name.
        """
        df_label = label_df.copy()

        def extract_nt(name: str) -> Optional[str]:
            # Single regex pass per row (the original ran the search twice).
            match = self._NT_PATTERN.search(name)
            return match.group(1) if match else None

        df_label["nt_number"] = df_label["strain_name"].apply(extract_nt)
        gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
        # dict.get returns None for strains missing from the annotation table.
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)
        return df_label

    def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """
        Format raw scores into user-facing strain-level predictions.

        Args:
            score_df: Raw scores with "pred_id", "0", "1" and
                "growth_inhibition" columns. Not modified.

        Returns:
            DataFrame with one row per (compound, strain) pair and
            user-friendly column names.
        """
        strain_df = score_df.copy()
        # pred_id is "chem_id:strain_name"; split once instead of twice.
        split_ids = strain_df["pred_id"].str.split(":", expand=True)
        strain_df["chem_id"] = split_ids[0]
        strain_df["strain_name"] = split_ids[1]
        strain_df = self._gram_stain(strain_df)
        strain_df = strain_df.rename(columns={
            "1": "antimicrobial_predictive_probability",
            "0": "no_growth_probability"
        })
        output_columns = [
            "pred_id",
            "chem_id",
            "strain_name",
            "antimicrobial_predictive_probability",
            "no_growth_probability",
            "growth_inhibition",
            "gram_stain"
        ]
        return strain_df[output_columns]

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregate per-strain scores into antimicrobial-potential metrics.

        Args:
            score_df: Raw scores with "pred_id", "1" and "growth_inhibition"
                columns. Not modified (the original mutated it in place).

        Returns:
            DataFrame indexed by chem_id with log geometric-mean activity
            scores and inhibited-strain counts, overall and per Gram group.
        """
        # Work on a copy so the caller's frame is not mutated.
        score_df = score_df.copy()
        split_ids = score_df["pred_id"].str.split(":", expand=True)
        score_df["chem_id"] = split_ids[0]
        score_df["strain_name"] = split_ids[1]
        pred_df = self._gram_stain(score_df)

        # Antimicrobial potential score: log of the geometric mean probability.
        apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
            columns={"1": "apscore_total"}
        )
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])

        # Same score, split by Gram stain.
        apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])

        # Counts of inhibited strains, overall and by Gram stain.
        inhibited_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
            columns={"growth_inhibition": "ginhib_total"}
        )
        inhibited_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )

        agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)
        # A compound missing one Gram group yields NaN there; treat as zero.
        return agg_pred.fillna(0)
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
                          model_path: str,
                          app_threshold: float) -> Tuple[int, pd.DataFrame]:
    """
    Score one batch with a freshly loaded XGBoost model.

    Intended for use from a multiprocessing pool, so the model is loaded
    inside the worker rather than shared across processes.

    Args:
        batch_data: Tuple of (feature DataFrame, batch id).
        model_path: Path to the pickled XGBoost model.
        app_threshold: Probability cutoff for calling growth inhibition.

    Returns:
        Tuple of (batch id, prediction DataFrame with columns "0", "1",
        "growth_inhibition").
    """
    import warnings

    # Silence XGBoost version-compatibility warnings in the worker.
    warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")

    X_input, batch_id = batch_data

    # Load the model inside the worker process.
    with open(model_path, "rb") as handle:
        model = pickle.load(handle)

    # Feature-name compatibility shim: models saved by older XGBoost carry
    # tuple-style feature names (e.g. "('bacteria_name',)") that newer
    # XGBoost rejects during its strict name check. Clearing the booster's
    # feature names disables only that check; weights and prediction logic
    # are untouched, so results are unchanged.
    if hasattr(model, 'get_booster'):
        model.get_booster().feature_names = None

    probabilities = model.predict_proba(X_input)
    scores = pd.DataFrame(probabilities, columns=["0", "1"], index=X_input.index)

    # Binarize: inhibition is called when P(class 1) reaches the threshold.
    scores["growth_inhibition"] = scores["1"].apply(
        lambda p: 1 if p >= app_threshold else 0
    )
    return batch_id, scores
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """
    Optimized predictor that relies on XGBoost's internal parallelism.

    Key differences from the base class:
    1. Single-process execution avoids GIL and inter-process overhead.
    2. The XGBoost model is loaded only once, at construction.
    3. XGBoost fans out over all CPU cores via OpenMP.
    4. All samples are scored in one large batch.
    """

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """
        Initialize and pre-load the XGBoost model into memory.

        Args:
            config: Prediction configuration; defaults are used when None.
        """
        super().__init__(config)
        # Core optimization: load the XGBoost model once and keep it resident.
        print("⚡ Loading XGBoost model...")
        import time
        import warnings
        warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
        start = time.time()
        with open(self.config.xgboost_model_path, "rb") as file:
            self.xgboost_model = pickle.load(file)
        # Feature-name compatibility shim (see _predict_batch_worker for why).
        if hasattr(self.xgboost_model, 'get_booster'):
            self.xgboost_model.get_booster().feature_names = None
        # Let XGBoost's OpenMP layer use every CPU core.
        n_threads = mp.cpu_count()
        self.xgboost_model.get_booster().set_param({
            'nthread': n_threads
        })
        print(f"✓ XGBoost configured to use {n_threads} CPU threads")
        print(f"✓ Model loaded in {time.time()-start:.2f}s")

    def predict_batch(self, molecules: List[MoleculeInput],
                      include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
        """
        Score a batch of molecules in a single process.

        Args:
            molecules: Molecule inputs.
            include_strain_predictions: When True, attach the per-strain
                prediction rows to each result.

        Returns:
            One BroadSpectrumResult per compound (empty list for empty input).
        """
        if not molecules:
            return []
        import time
        print(f"\n{'='*60}")
        print(f"Processing {len(molecules)} molecules...")
        print(f"{'='*60}")
        start_total = time.time()

        # 1. MolE representation (GPU when available).
        start = time.time()
        print("\n[1/4] Generating MolE representations (GPU)...")
        mole_representation = self._get_mole_representation(molecules)
        time_mole = time.time() - start
        print(f"✓ Done in {time_mole:.1f}s")

        # 2. Cross-join features with the strain panel.
        start = time.time()
        print("\n[2/4] Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)
        time_prep = time.time() - start
        print(f"✓ Done in {time_prep:.1f}s")
        print(f" Total predictions needed: {len(X_input):,}")

        # 3. One big predict_proba call; XGBoost parallelizes internally
        #    across all cores via OpenMP.
        start = time.time()
        print(f"\n[3/4] XGBoost prediction (using {mp.cpu_count()} CPU cores)...")
        print(f" Predicting {len(X_input):,} samples in one batch...")
        print(f" (Watch CPU usage - should be ~{mp.cpu_count()*100}%)")
        y_pred = self.xgboost_model.predict_proba(X_input)
        time_pred = time.time() - start
        print(f"✓ Done in {time_pred:.1f}s")
        print(f" Throughput: {len(X_input)/time_pred:.0f} predictions/second")

        # 4. Post-processing: binarize, optionally keep strain detail, aggregate.
        start = time.time()
        print("\n[4/4] Post-processing results...")
        pred_df = pd.DataFrame(
            y_pred,
            columns=["0", "1"],
            index=X_input.index
        )
        # Vectorized binarization against the APP threshold.
        pred_df["growth_inhibition"] = (
            pred_df["1"] >= self.config.app_threshold
        ).astype(int)
        pred_df = pred_df.reset_index()

        # Per-strain detail rows (only when requested).
        strain_level_data = None
        if include_strain_predictions:
            strain_level_data = self._prepare_strain_level_predictions(pred_df)

        # Aggregate to one row per compound.
        agg_df = self._antimicrobial_potential(pred_df)
        # Broad-spectrum call: at least min_nkill strains inhibited.
        agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
            lambda x: 1 if x >= self.config.min_nkill else 0
        )

        # Convert aggregate rows to result objects.
        results_list = []
        for _, row in agg_df.iterrows():
            mol_strain_preds = None
            if strain_level_data is not None:
                mol_strain_preds = strain_level_data[
                    strain_level_data['chem_id'] == row.name
                ].reset_index(drop=True)
            results_list.append(BroadSpectrumResult(
                chem_id=row.name,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"]),
                strain_predictions=mol_strain_preds
            ))
        time_post = time.time() - start
        print(f"✓ Done in {time_post:.1f}s")

        # Timing summary.
        total_time = time.time() - start_total
        print(f"\n{'='*60}")
        print(f"SUMMARY")
        print(f"{'='*60}")
        print(f" MolE representation: {time_mole:6.1f}s ({time_mole/total_time*100:5.1f}%)")
        print(f" Feature preparation: {time_prep:6.1f}s ({time_prep/total_time*100:5.1f}%)")
        print(f" XGBoost prediction: {time_pred:6.1f}s ({time_pred/total_time*100:5.1f}%)")
        print(f" Post-processing: {time_post:6.1f}s ({time_post/total_time*100:5.1f}%)")
        # BUG FIX: the original printed f" {''*58}" — repeating the EMPTY
        # string, so the intended separator rendered as a lone space.
        print(f" {'-'*58}")
        print(f" Total time: {total_time:6.1f}s")
        print(f" Molecules processed: {len(molecules)}")
        print(f" Time per molecule: {total_time/len(molecules):.3f}s")
        print(f"{'='*60}\n")
        return results_list

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """
        Predict broad-spectrum antimicrobial activity for a single molecule.

        Args:
            molecule: Molecule input.

        Returns:
            The broad-spectrum prediction result for that molecule.
        """
        results = self.predict_batch([molecule])
        return results[0]

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """
        Predict broad-spectrum antimicrobial activity from SMILES strings.

        Args:
            smiles_list: SMILES strings.
            chem_ids: Compound identifiers; auto-generated ("molN") when None.

        Returns:
            Prediction results, one per compound.

        Raises:
            ValueError: If smiles_list and chem_ids differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")
        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
        return self.predict_batch(molecules)

    def predict_from_file(self,
                          file_path: str,
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """
        Predict broad-spectrum antimicrobial activity from a CSV/TSV file.

        Args:
            file_path: Input file path; ".tsv" is read tab-separated,
                anything else as CSV.
            smiles_column: SMILES column name (matched case-insensitively).
            id_column: Compound-ID column name (matched case-insensitively);
                sequential ids are generated when absent.

        Returns:
            Prediction results, one per row of the file.

        Raises:
            ValueError: If the SMILES column is not present in the file.
        """
        if file_path.endswith('.tsv'):
            df = pd.read_csv(file_path, sep='\t')
        else:
            df = pd.read_csv(file_path)
        # Case-insensitive column lookup.
        columns_lower = {col.lower(): col for col in df.columns}
        smiles_col_actual = columns_lower.get(smiles_column.lower())
        if smiles_col_actual is None:
            raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")
        id_col_actual = columns_lower.get(id_column.lower())
        if id_col_actual is None:
            # No ID column in the file: generate sequential ids.
            df[id_column] = [f"mol{i+1}" for i in range(len(df))]
            id_col_actual = id_column
        molecules = [
            MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
            for _, row in df.iterrows()
        ]
        return self.predict_batch(molecules)
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """
    Create a parallel broad-spectrum antimicrobial predictor instance.

    Args:
        config: Prediction configuration; defaults are used when None.

    Returns:
        A ready-to-use ParallelBroadSpectrumPredictor (construction
        pre-loads the XGBoost model).
    """
    return ParallelBroadSpectrumPredictor(config)
# Convenience functions
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    Convenience wrapper: predict broad-spectrum activity straight from SMILES.

    Args:
        smiles_list: SMILES strings.
        chem_ids: Compound identifiers; auto-generated when None.
        config: Prediction configuration.

    Returns:
        Prediction results, one per compound.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)
def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    Convenience wrapper: predict broad-spectrum activity from a CSV/TSV file.

    Args:
        file_path: Input file path.
        smiles_column: SMILES column name.
        id_column: Compound-ID column name.
        config: Prediction configuration.

    Returns:
        Prediction results, one per row of the file.
    """
    return create_predictor(config).predict_from_file(file_path, smiles_column, id_column)