"""Parallel broad-spectrum antimicrobial predictor.

Predicts the broad-spectrum antimicrobial activity of molecules by combining
MolE molecular representations with a pickled XGBoost classifier.  Molecules
can be supplied as ``MoleculeInput`` objects, SMILES strings, or a CSV/TSV
file; results are aggregated per compound across all screened strains.
"""

import os
import re
import pickle
import time
import warnings
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Union, Optional, Tuple, Any

import numpy as np
import pandas as pd
import torch
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder

from .mole_representation import process_representation


# Extracts the "NT<digits>" identifier found in parentheses in a strain name.
_NT_NUMBER_RE = re.compile(r".*?\((NT\d+)\)")

# Gram-stain labels that the aggregation step turns into dedicated columns.
_GRAM_LABELS = ("negative", "positive")


def _extract_nt_number(strain_name: Any) -> Optional[str]:
    """Return the ``NT<digits>`` id embedded in *strain_name*, or ``None``.

    Non-string values (e.g. NaN) yield ``None`` instead of raising.
    """
    if not isinstance(strain_name, str):
        return None
    match = _NT_NUMBER_RE.search(strain_name)
    return match.group(1) if match else None


def _with_gram_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Return *frame* guaranteed to have a column for each Gram label.

    After ``unstack()`` a Gram class that never occurs has no column at all;
    adding it as NaN keeps the downstream column accesses from raising
    ``KeyError``.  Columns already present are left untouched.
    """
    frame = frame.copy()
    for label in _GRAM_LABELS:
        if label not in frame.columns:
            frame[label] = np.nan
    return frame


@dataclass
class PredictionConfig:
    """Configuration for the predictor.

    Any path left as ``None`` is filled in by ``__post_init__`` with a
    location inside the bundled pretrained-model directory.
    """

    xgboost_model_path: Optional[str] = None      # pickled XGBoost model
    mole_model_path: Optional[str] = None         # MolE pretrained model directory
    strain_categories_path: Optional[str] = None  # tab-separated screening table
    gram_info_path: Optional[str] = None          # Excel sheet with Gram stain info
    app_threshold: float = 0.04374140128493309    # probability cut-off for "growth inhibited"
    min_nkill: int = 10                           # inhibited strains needed for "broad spectrum"
    batch_size: int = 10000                       # NOTE(review): not read by any code in this module
    n_workers: Optional[int] = 2                  # stored on the predictor; XGBoost threads use mp.cpu_count()
    device: str = "auto"                          # "auto" picks cuda:0 when available, else cpu

    def __post_init__(self):
        """Fill in default paths relative to this file's grandparent directory."""
        current_file = Path(__file__).resolve()
        project_root = current_file.parent.parent
        data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"

        if self.mole_model_path is None:
            self.mole_model_path = str(data_dir)
        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")
        if self.strain_categories_path is None:
            self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")
        if self.gram_info_path is None:
            self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")


@dataclass
class MoleculeInput:
    """A molecule to predict: its SMILES string and an optional identifier."""

    smiles: str
    chem_id: Optional[str] = None


@dataclass
class StrainPrediction:
    """Prediction for one compound/strain pair."""

    pred_id: str                                  # "<chem_id>:<strain_name>"
    chem_id: str
    strain_name: str
    antimicrobial_predictive_probability: float   # model output for class "1"
    no_growth_probability: float                  # model output for class "0"
    growth_inhibition: int                        # 1 if class-"1" probability >= threshold, else 0
    gram_stain: Optional[str] = None              # Gram stain looked up from the strain table


@dataclass
class BroadSpectrumResult:
    """Per-compound aggregate of the strain-level predictions."""

    chem_id: str
    apscore_total: float
    apscore_gnegative: float
    apscore_gpositive: float
    ginhib_total: int
    ginhib_gnegative: int
    ginhib_gpositive: int
    broad_spectrum: int
    strain_predictions: Optional[pd.DataFrame] = None  # strain-level rows, if requested

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the aggregate fields as a dict (strain-level data excluded)."""
        return {
            'chem_id': self.chem_id,
            'apscore_total': self.apscore_total,
            'apscore_gnegative': self.apscore_gnegative,
            'apscore_gpositive': self.apscore_gpositive,
            'ginhib_total': self.ginhib_total,
            'ginhib_gnegative': self.ginhib_gnegative,
            'ginhib_gpositive': self.ginhib_gpositive,
            'broad_spectrum': self.broad_spectrum
        }

    def to_strain_predictions_list(self) -> List['StrainPrediction']:
        """Convert the strain-level DataFrame into ``StrainPrediction`` objects.

        Returns:
            One object per row, or an empty list when no strain-level data
            is attached.
        """
        if self.strain_predictions is None or self.strain_predictions.empty:
            return []

        return [
            StrainPrediction(
                pred_id=row['pred_id'],
                chem_id=row['chem_id'],
                strain_name=row['strain_name'],
                antimicrobial_predictive_probability=row['antimicrobial_predictive_probability'],
                no_growth_probability=row['no_growth_probability'],
                growth_inhibition=row['growth_inhibition'],
                gram_stain=row.get('gram_stain', None)
            )
            for _, row in self.strain_predictions.iterrows()
        ]


class BroadSpectrumPredictor:
    """Base predictor: loads shared strain data and provides feature helpers.

    Holds the screening table, the strain one-hot encoding and the Gram
    stain table, and implements the feature-building and aggregation steps
    used by ``ParallelBroadSpectrumPredictor``.
    """

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """Validate the configured paths and load the shared strain data.

        Args:
            config: Prediction configuration; defaults to ``PredictionConfig()``.

        Raises:
            ValueError: A required path is ``None``.
            FileNotFoundError: A required path does not exist.
            RuntimeError: Loading the shared data failed.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()

        self._validate_paths()
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Raise if any required model or data path is unset or missing."""
        required_files = {
            "mole_model": self.config.mole_model_path,
            "xgboost_model": self.config.xgboost_model_path,
            "strain_categories": self.config.strain_categories_path,
            "gram_info": self.config.gram_info_path,
        }

        for name, file_path in required_files.items():
            if file_path is None:
                raise ValueError(f"{name} is None! Check __post_init__ configuration")
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required {name} not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load the screening table, strain one-hot encoding and Gram table.

        Raises:
            RuntimeError: Any step failed; the original exception is chained
                as ``__cause__``.
        """
        try:
            # Screening results; the columns are the strain names.
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )

            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)

            # The skipped rows are hard-coded for this specific spreadsheet
            # layout (presumably header/footer rows -- confirm against the file).
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """Build the one-hot encoding table for the strain categories.

        Args:
            categories: Strain names.

        Returns:
            A square DataFrame indexed and labelled by *categories*.
        """
        try:
            # Newer scikit-learn spells the option ``sparse_output``.
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn only knows ``sparse``.
            ohe = OneHotEncoder(sparse=False)

        category_frame = pd.DataFrame(categories)
        ohe.fit(category_frame)

        # NOTE(review): OneHotEncoder orders its output columns by sorted
        # category while the labels below follow the input order; they agree
        # only if ``categories`` is already sorted.  Left unchanged because
        # the pickled model presumably expects this exact layout -- confirm
        # before altering.
        return pd.DataFrame(
            ohe.transform(category_frame),
            columns=categories,
            index=categories
        )

    def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame:
        """Compute MolE representations for *molecules*.

        Molecules without a ``chem_id`` get ``mol<position>`` (1-based).

        Args:
            molecules: Molecules to featurise.

        Returns:
            Whatever ``process_representation`` returns for the input table.
        """
        df = pd.DataFrame([
            {"smiles": mol.smiles, "chem_id": mol.chem_id or f"mol{i+1}"}
            for i, mol in enumerate(molecules)
        ])

        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """Cross-join chemical features with the strain one-hot encoding.

        Args:
            chemfeats_df: Chemical features indexed by compound id.

        Returns:
            One row per compound/strain pair, indexed by
            ``"<chem_id>:<strain_name>"``.
        """
        # Name the index explicitly so this works whether or not the incoming
        # index already carries a name.
        chemfe = chemfeats_df.rename_axis(index="chem_id").reset_index()
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)

        sohe = self.strain_ohe.rename_axis(index="strain_name").reset_index()

        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
        xpred = xpred.set_index("pred_id")
        return xpred.drop(columns=["chem_id", "strain_name"])

    @staticmethod
    def _split_pred_id(pred_ids: pd.Series) -> Tuple[pd.Series, pd.Series]:
        """Split ``"<chem_id>:<strain_name>"`` into its two components.

        Splits on the LAST colon so compound ids that themselves contain a
        colon (e.g. ``CHEBI:1234``) stay intact.  This assumes strain names
        never contain a colon.

        Returns:
            ``(chem_id, strain_name)`` Series aligned with *pred_ids*.
        """
        parts = pred_ids.str.rsplit(":", n=1, expand=True)
        return parts[0], parts[1]

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *label_df* with ``nt_number`` and ``gram_stain``.

        Args:
            label_df: Frame with a ``strain_name`` column.

        Returns:
            Copy with the NT id parsed from each strain name and the Gram
            stain looked up for it (``None`` when either is unavailable).
        """
        df_label = label_df.copy()

        df_label["nt_number"] = df_label["strain_name"].apply(_extract_nt_number)

        gram_dict = self.maier_strains["Gram stain"].to_dict()
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)

        return df_label

    def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Format raw scores as a user-facing strain-level table.

        Args:
            score_df: Frame with ``pred_id``, ``"0"``, ``"1"`` and
                ``growth_inhibition`` columns.  Not modified.

        Returns:
            Frame with the identifier, probability, inhibition and Gram
            stain columns in a fixed order.
        """
        chem_ids, strain_names = self._split_pred_id(score_df["pred_id"])
        strain_df = self._gram_stain(
            score_df.assign(chem_id=chem_ids, strain_name=strain_names)
        )

        strain_df = strain_df.rename(columns={
            "1": "antimicrobial_predictive_probability",
            "0": "no_growth_probability"
        })

        output_columns = [
            "pred_id",
            "chem_id",
            "strain_name",
            "antimicrobial_predictive_probability",
            "no_growth_probability",
            "growth_inhibition",
            "gram_stain"
        ]
        return strain_df[output_columns]

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate strain-level scores into per-compound statistics.

        The antimicrobial potential score is the log of the geometric mean
        of the class-"1" probabilities, overall and per Gram stain; the
        ``ginhib_*`` columns count inhibited strains.

        Args:
            score_df: Frame with ``pred_id``, ``"1"`` and
                ``growth_inhibition`` columns.  Not modified.

        Returns:
            Frame indexed by ``chem_id`` with the ``apscore_*`` and
            ``ginhib_*`` columns; missing values are replaced by 0.
        """
        chem_ids, strain_names = self._split_pred_id(score_df["pred_id"])
        # ``assign`` builds a new frame, so the caller's data stays untouched.
        pred_df = self._gram_stain(
            score_df.assign(chem_id=chem_ids, strain_name=strain_names)
        )

        by_chem = pred_df.groupby("chem_id")
        by_chem_gram = pred_df.groupby(["chem_id", "gram_stain"])

        apscore_total = by_chem["1"].apply(gmean).to_frame().rename(
            columns={"1": "apscore_total"}
        )
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])

        apscore_gram = _with_gram_columns(
            by_chem_gram["1"].apply(gmean).unstack()
        ).rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])

        inhibited_total = by_chem["growth_inhibition"].sum().to_frame().rename(
            columns={"growth_inhibition": "ginhib_total"}
        )

        inhibited_gram = _with_gram_columns(
            by_chem_gram["growth_inhibition"].sum().unstack()
        ).rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )

        agg_pred = (
            apscore_total
            .join(apscore_gram)
            .join(inhibited_total)
            .join(inhibited_gram)
        )
        return agg_pred.fillna(0)


def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
                          model_path: str,
                          app_threshold: float) -> Tuple[int, pd.DataFrame]:
    """Predict one batch with a freshly loaded model (multiprocess worker).

    Args:
        batch_data: ``(features, batch_id)``.
        model_path: Path to the pickled XGBoost model.
        app_threshold: Class-"1" probability at or above which growth counts
            as inhibited.

    Returns:
        ``(batch_id, predictions)`` where predictions has columns ``"0"``,
        ``"1"`` and ``growth_inhibition``, indexed like the input features.
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")

    X_input, batch_id = batch_data

    # SECURITY: unpickling executes arbitrary code -- only load model files
    # that come from a trusted source.
    with open(model_path, "rb") as file:
        model = pickle.load(file)

    # The model was saved with tuple-style feature names (e.g.
    # "('bacteria_name',)") that newer XGBoost rejects on a strict name
    # check.  Clearing the stored names skips that check; the input is then
    # matched by column position.
    if hasattr(model, 'get_booster'):
        model.get_booster().feature_names = None

    y_pred = model.predict_proba(X_input)
    pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)

    pred_df["growth_inhibition"] = (pred_df["1"] >= app_threshold).astype(int)

    return batch_id, pred_df


class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """Predictor that loads the model once and predicts in a single batch.

    The XGBoost model is unpickled at construction time and asked to use
    ``mp.cpu_count()`` threads; all compound/strain rows are then scored in
    one ``predict_proba`` call.
    """

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """Initialise the base predictor and preload the XGBoost model.

        Args:
            config: Prediction configuration; defaults to ``PredictionConfig()``.
        """
        super().__init__(config)

        print("⚡ Loading XGBoost model...")
        warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")

        start = time.perf_counter()
        # SECURITY: unpickling executes arbitrary code -- only load model
        # files that come from a trusted source.
        with open(self.config.xgboost_model_path, "rb") as file:
            self.xgboost_model = pickle.load(file)

        # Both booster tweaks sit under the same guard so that a model
        # without ``get_booster`` does not crash on the second one.
        if hasattr(self.xgboost_model, 'get_booster'):
            booster = self.xgboost_model.get_booster()
            # Skip XGBoost's strict feature-name check (see
            # ``_predict_batch_worker`` for the background).
            booster.feature_names = None

            n_threads = mp.cpu_count()
            booster.set_param({'nthread': n_threads})
            print(f"✓ XGBoost configured to use {n_threads} CPU threads")

        print(f"✓ Model loaded in {time.perf_counter()-start:.2f}s")

    def _build_results(self,
                       agg_df: pd.DataFrame,
                       strain_level_data: Optional[pd.DataFrame]) -> List[BroadSpectrumResult]:
        """Turn the aggregated frame into ``BroadSpectrumResult`` objects."""
        strain_groups: Dict[Any, pd.DataFrame] = {}
        empty_strains = None
        if strain_level_data is not None:
            # Group once instead of filtering the whole table per compound.
            strain_groups = {
                chem_id: group.reset_index(drop=True)
                for chem_id, group in strain_level_data.groupby("chem_id", sort=False)
            }
            empty_strains = strain_level_data.iloc[0:0].reset_index(drop=True)

        results_list = []
        for chem_id, row in agg_df.iterrows():
            mol_strain_preds = None
            if strain_level_data is not None:
                mol_strain_preds = strain_groups.get(chem_id, empty_strains)

            results_list.append(BroadSpectrumResult(
                chem_id=chem_id,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"]),
                strain_predictions=mol_strain_preds
            ))
        return results_list

    def predict_batch(self, molecules: List[MoleculeInput],
                      include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for a list of molecules.

        Progress and timing information is printed to stdout.

        Args:
            molecules: Molecules to predict.
            include_strain_predictions: Attach the strain-level table to
                each result.

        Returns:
            One result per compound id in the aggregated table; an empty
            list when *molecules* is empty.
        """
        if not molecules:
            return []

        # 1. MolE representation
        print(f"\n{'='*60}")
        print(f"Processing {len(molecules)} molecules...")
        print(f"{'='*60}")

        start_total = time.perf_counter()

        start = time.perf_counter()
        print("\n[1/4] Generating MolE representations (GPU)...")
        mole_representation = self._get_mole_representation(molecules)
        time_mole = time.perf_counter() - start
        print(f"✓ Done in {time_mole:.1f}s")

        # 2. Feature preparation (cross-join with strains)
        start = time.perf_counter()
        print("\n[2/4] Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)
        time_prep = time.perf_counter() - start
        print(f"✓ Done in {time_prep:.1f}s")
        print(f" Total predictions needed: {len(X_input):,}")

        # 3. XGBoost prediction: one call over all rows
        start = time.perf_counter()
        print(f"\n[3/4] XGBoost prediction (using {mp.cpu_count()} CPU cores)...")
        print(f" Predicting {len(X_input):,} samples in one batch...")
        print(f" (Watch CPU usage - should be ~{mp.cpu_count()*100}%)")

        y_pred = self.xgboost_model.predict_proba(X_input)

        time_pred = time.perf_counter() - start
        print(f"✓ Done in {time_pred:.1f}s")
        # Guard against a zero elapsed time on coarse clocks / tiny inputs.
        print(f" Throughput: {len(X_input)/max(time_pred, 1e-9):.0f} predictions/second")

        # 4. Post-processing
        start = time.perf_counter()
        print("\n[4/4] Post-processing results...")

        pred_df = pd.DataFrame(
            y_pred,
            columns=["0", "1"],
            index=X_input.index
        )
        pred_df["growth_inhibition"] = (
            pred_df["1"] >= self.config.app_threshold
        ).astype(int)
        pred_df = pred_df.reset_index()

        strain_level_data = None
        if include_strain_predictions:
            strain_level_data = self._prepare_strain_level_predictions(pred_df)

        agg_df = self._antimicrobial_potential(pred_df)

        # A compound is broad-spectrum when it inhibits at least min_nkill strains.
        agg_df["broad_spectrum"] = (
            agg_df["ginhib_total"] >= self.config.min_nkill
        ).astype(int)

        results_list = self._build_results(agg_df, strain_level_data)

        time_post = time.perf_counter() - start
        print(f"✓ Done in {time_post:.1f}s")

        # Summary
        total_time = time.perf_counter() - start_total
        safe_total = max(total_time, 1e-9)  # avoids division by zero in the percentages
        print(f"\n{'='*60}")
        print("SUMMARY")
        print(f"{'='*60}")
        print(f" MolE representation: {time_mole:6.1f}s ({time_mole/safe_total*100:5.1f}%)")
        print(f" Feature preparation: {time_prep:6.1f}s ({time_prep/safe_total*100:5.1f}%)")
        print(f" XGBoost prediction: {time_pred:6.1f}s ({time_pred/safe_total*100:5.1f}%)")
        print(f" Post-processing: {time_post:6.1f}s ({time_post/safe_total*100:5.1f}%)")
        print(f" {'─'*58}")
        print(f" Total time: {total_time:6.1f}s")
        print(f" Molecules processed: {len(molecules)}")
        print(f" Time per molecule: {total_time/len(molecules):.3f}s")
        print(f"{'='*60}\n")

        return results_list

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """Predict broad-spectrum activity for one molecule.

        Args:
            molecule: Molecule to predict.

        Returns:
            The result for that molecule.

        Raises:
            IndexError: The pipeline produced no result for the molecule.
        """
        results = self.predict_batch([molecule])
        if not results:
            raise IndexError(
                f"No prediction was produced for molecule with SMILES {molecule.smiles!r}"
            )
        return results[0]

    def predict_from_smiles(self, smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity from SMILES strings.

        Args:
            smiles_list: SMILES strings.
            chem_ids: Compound ids; ``mol1``, ``mol2``, ... when ``None``.

        Returns:
            Prediction results.

        Raises:
            ValueError: *smiles_list* and *chem_ids* differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]

        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
        return self.predict_batch(molecules)

    def predict_from_file(self, file_path: Union[str, "os.PathLike[str]"],
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for the molecules in a file.

        Files whose name ends in ``.tsv`` or ``.tsv.gz`` (case-insensitive)
        are read as tab-separated; anything else as comma-separated.
        Column names are matched case-insensitively.

        Args:
            file_path: Input file (string or path-like).
            smiles_column: Name of the SMILES column.
            id_column: Name of the id column; when absent, ids
                ``mol1``, ``mol2``, ... are generated.

        Returns:
            Prediction results.

        Raises:
            ValueError: The SMILES column is not present.
        """
        lowered_name = os.fspath(file_path).lower()
        separator = '\t' if lowered_name.endswith(('.tsv', '.tsv.gz')) else ','
        df = pd.read_csv(file_path, sep=separator)

        columns_lower = {col.lower(): col for col in df.columns}

        smiles_col_actual = columns_lower.get(smiles_column.lower())
        if smiles_col_actual is None:
            raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")

        id_col_actual = columns_lower.get(id_column.lower())
        if id_col_actual is None:
            chem_ids = [f"mol{i+1}" for i in range(len(df))]
        else:
            chem_ids = [str(value) for value in df[id_col_actual]]

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(df[smiles_col_actual], chem_ids)
        ]
        return self.predict_batch(molecules)


def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """Create a ``ParallelBroadSpectrumPredictor``.

    Args:
        config: Prediction configuration; defaults are used when ``None``.

    Returns:
        A ready-to-use predictor instance.
    """
    return ParallelBroadSpectrumPredictor(config)


def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Convenience wrapper: build a predictor and predict from SMILES strings.

    Args:
        smiles_list: SMILES strings.
        chem_ids: Compound ids, or ``None`` to auto-generate.
        config: Prediction configuration.

    Returns:
        Prediction results.
    """
    predictor = create_predictor(config)
    return predictor.predict_from_smiles(smiles_list, chem_ids)


def predict_file(file_path: Union[str, "os.PathLike[str]"],
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Convenience wrapper: build a predictor and predict from a file.

    Args:
        file_path: Input file (string or path-like).
        smiles_column: Name of the SMILES column.
        id_column: Name of the id column.
        config: Prediction configuration.

    Returns:
        Prediction results.
    """
    predictor = create_predictor(config)
    return predictor.predict_from_file(file_path, smiles_column, id_column)