Files
mole_broad_spectrum_parallel/broad_spectrum_api.py
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

524 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
并行广谱抗菌预测API模块
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
基于MolE分子表示和XGBoost模型进行预测。
"""
import os
import re
import pickle
import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
try:
from mole_representation import process_representation
except ImportError:
print("Warning: mole_representation module not found. Please ensure it's in your Python path.")
@dataclass
class PredictionConfig:
    """Settings that control a broad-spectrum prediction run."""

    # Pickled XGBoost classifier used to score each (compound, strain) pair.
    xgboost_model_path: str = "data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl"
    # Pretrained MolE model location, forwarded to ``process_representation``.
    mole_model_path: str = "pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
    # Tab-separated screening table; its column labels are the strain names.
    strain_categories_path: str = "data/01.prepare_training_data/maier_screening_results.tsv.gz"
    # Excel workbook providing the "Gram stain" value for each strain.
    gram_info_path: str = "raw_data/maier_microbiome/strain_info_SF2.xlsx"
    # Predicted probabilities >= this value count as growth inhibition.
    app_threshold: float = 0.04374140128493309
    # A compound inhibiting at least this many strains is broad spectrum.
    min_nkill: int = 10
    # Number of feature-matrix rows handed to a worker per task.
    batch_size: int = 100
    # Worker process count; ``None`` falls back to the machine's CPU count.
    n_workers: Optional[int] = None
    # Torch device string; "auto" selects "cuda:0" when available, else "cpu".
    device: str = "auto"
@dataclass
class MoleculeInput:
    """A single molecule submitted for prediction."""

    # SMILES string describing the molecule.
    smiles: str
    # Optional caller-supplied identifier; a "mol<N>" id is generated
    # downstream when this is left empty.
    chem_id: Optional[str] = None
@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum prediction for one compound."""

    # Compound identifier.
    chem_id: str
    # Log of the geometric mean of predicted probabilities over all strains.
    apscore_total: float
    # Same score restricted to Gram-negative strains.
    apscore_gnegative: float
    # Same score restricted to Gram-positive strains.
    apscore_gpositive: float
    # Number of strains predicted to be inhibited (all strains).
    ginhib_total: int
    # Number of Gram-negative strains predicted to be inhibited.
    ginhib_gnegative: int
    # Number of Gram-positive strains predicted to be inhibited.
    ginhib_gpositive: int
    # 1 when the compound is classified as broad spectrum, otherwise 0.
    broad_spectrum: int

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the result as a plain dictionary keyed by field name."""
        exported = (
            'chem_id',
            'apscore_total',
            'apscore_gnegative',
            'apscore_gpositive',
            'ginhib_total',
            'ginhib_gnegative',
            'ginhib_gpositive',
            'broad_spectrum',
        )
        return {key: getattr(self, key) for key in exported}
class BroadSpectrumPredictor:
    """Broad-spectrum antimicrobial activity predictor.

    Loads the strain screening table and Gram-stain annotations once, and
    provides the helpers that turn MolE molecular representations into
    strain-level feature rows and that aggregate strain-level predictions
    into per-compound antimicrobial-potential scores.
    """

    # Strain labels are expected to carry their identifier as "(NT<digits>)".
    _NT_NUMBER_PATTERN = re.compile(r".*?\((NT\d+)\)")
    # Gram classes that are reported separately in the aggregated output.
    _GRAM_CLASSES = ("negative", "positive")

    def __init__(self, config: Optional["PredictionConfig"] = None) -> None:
        """Create a predictor and eagerly load the shared strain tables.

        Args:
            config: Prediction settings; the default ``PredictionConfig()``
                is used when ``None``.

        Raises:
            FileNotFoundError: If a required input file does not exist.
            RuntimeError: If the strain tables cannot be loaded.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()
        # Fail fast on missing inputs before any expensive loading happens.
        self._validate_paths()
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Check that the model file and both strain tables exist.

        Raises:
            FileNotFoundError: For the first required file that is missing.
        """
        required_files = (
            self.config.xgboost_model_path,
            self.config.strain_categories_path,
            self.config.gram_info_path,
        )
        for file_path in required_files:
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required file not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load the screening table, strain one-hot encoding and Gram table.

        Raises:
            RuntimeError: If any table cannot be read or encoded; the
                original exception is attached as ``__cause__``.
        """
        try:
            # Screening results: the column labels are the strain names.
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )
            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
            # Rows 0-1 and 43-54 of the sheet are skipped -- presumably
            # title/footnote rows; verify against the workbook if it changes.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, *range(43, 55)],
                index_col="NT data base",
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to load shared data: {e}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """One-hot encode the strain categories.

        Args:
            categories: Index of strain names.

        Returns:
            A square DataFrame with ``categories`` as both row and column
            labels, holding the dense one-hot encoding of each strain.
        """
        category_frame = pd.DataFrame(categories)
        try:
            # scikit-learn >= 1.2 spells the dense-output switch ``sparse_output``.
            encoder = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn releases only know the ``sparse`` keyword.
            encoder = OneHotEncoder(sparse=False)
        encoded = encoder.fit_transform(category_frame)
        # NOTE(review): the encoder emits columns in its own category order,
        # while the labels below follow ``categories`` order.  This is kept
        # unchanged because the trained model presumably expects it -- confirm
        # against the training pipeline.
        return pd.DataFrame(encoded, columns=categories, index=categories)

    def _get_mole_representation(self, molecules: List["MoleculeInput"]) -> pd.DataFrame:
        """Compute the MolE representation of the given molecules.

        Args:
            molecules: Molecules to featurise.  Entries without a ``chem_id``
                get the id ``mol<position>`` (1-based).

        Returns:
            Whatever ``process_representation`` returns for these molecules
            (annotated as a DataFrame of MolE features).
        """
        frame = pd.DataFrame(
            [
                {"smiles": mol.smiles, "chem_id": mol.chem_id or f"mol{position}"}
                for position, mol in enumerate(molecules, start=1)
            ]
        )

        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        return process_representation(
            dataset_path=frame,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device,
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """Cross every compound with every strain (Cartesian product).

        Args:
            chemfeats_df: Chemical features indexed by compound id.

        Returns:
            A DataFrame indexed by ``"<chem_id>:<strain_name>"`` (index name
            ``pred_id``) whose columns are the chemical features followed by
            the strain one-hot columns.
        """
        # Name the index explicitly instead of relying on ``reset_index``
        # producing a column literally called "index"; that only happens for
        # unnamed indexes.
        chem = chemfeats_df.rename_axis("chem_id").reset_index()
        chem["chem_id"] = chem["chem_id"].astype(str)
        strains = self.strain_ohe.rename_axis("strain_name").reset_index()

        xpred = chem.merge(strains, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
        return xpred.set_index("pred_id").drop(columns=["chem_id", "strain_name"])

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """Annotate each row with its strain's NT number and Gram stain.

        Args:
            label_df: DataFrame with a ``strain_name`` column.

        Returns:
            A copy of ``label_df`` with two extra columns: ``nt_number`` (the
            ``NT<digits>`` identifier found in the strain name, or ``None``)
            and ``gram_stain`` (looked up in the Gram table, or ``None``).
        """
        df_label = label_df.copy()
        pattern = self._NT_NUMBER_PATTERN

        def _nt_number(strain_name: Any) -> Optional[str]:
            # Search once per row; non-string values simply have no NT number.
            match = pattern.search(strain_name) if isinstance(strain_name, str) else None
            return match.group(1) if match else None

        df_label["nt_number"] = df_label["strain_name"].apply(_nt_number)
        gram_by_nt = self.maier_strains["Gram stain"].to_dict()
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_by_nt.get)
        return df_label

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate strain-level predictions into per-compound scores.

        Args:
            score_df: DataFrame with columns ``pred_id``
                (``"<chem_id>:<strain_name>"``), ``"1"`` (predicted
                probability of inhibition) and ``growth_inhibition`` (0/1).
                The frame is not modified.

        Returns:
            A DataFrame indexed by ``chem_id`` with the columns
            ``apscore_total``, ``apscore_gnegative``, ``apscore_gpositive``
            (log of the geometric mean probability) and ``ginhib_total``,
            ``ginhib_gnegative``, ``ginhib_gpositive`` (inhibited-strain
            counts).  Missing values are filled with 0.
        """
        # Work on a copy so the caller's frame does not grow extra columns.
        scores = score_df.copy()
        # Split on the LAST colon: compound ids may themselves contain ":"
        # (e.g. "CHEBI:123"); strain names are assumed to contain none.
        id_parts = scores["pred_id"].str.rsplit(":", n=1, expand=True)
        scores["chem_id"] = id_parts[0]
        scores["strain_name"] = id_parts[1]

        pred_df = self._gram_stain(scores)
        gram_classes = list(self._GRAM_CLASSES)

        per_chem = pred_df.groupby("chem_id")
        per_chem_gram = pred_df.groupby(["chem_id", "gram_stain"])

        # Antimicrobial potential: log of the geometric mean probability.
        apscore_total = np.log(per_chem["1"].apply(gmean)).to_frame("apscore_total")
        # ``reindex`` guarantees both Gram columns exist even when one class
        # is absent from this batch, instead of failing with a KeyError.
        apscore_gram = np.log(
            per_chem_gram["1"].apply(gmean).unstack().reindex(columns=gram_classes)
        ).rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )

        # Number of strains predicted to be inhibited.
        ginhib_total = per_chem["growth_inhibition"].sum().to_frame("ginhib_total")
        ginhib_gram = (
            per_chem_gram["growth_inhibition"].sum().unstack()
            .reindex(columns=gram_classes)
            .rename(
                columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
            )
        )

        agg_pred = apscore_total.join(apscore_gram).join(ginhib_total).join(ginhib_gram)
        # NOTE(review): a missing Gram class yields an apscore of 0 (i.e.
        # log(1)); kept for output compatibility.
        return agg_pred.fillna(0)
# Per-process cache of unpickled models, keyed by (path, modification time),
# so a worker that handles many batches reads the model file only once.
_WORKER_MODEL_CACHE: Dict[Tuple[str, int], Any] = {}


def _load_worker_model(model_path: str) -> Any:
    """Return the unpickled model stored at *model_path*, cached per process."""
    cache_key = (model_path, os.stat(model_path).st_mtime_ns)
    if cache_key not in _WORKER_MODEL_CACHE:
        # SECURITY: unpickling executes arbitrary code from the file; only
        # point ``model_path`` at model files from a trusted source.
        with open(model_path, "rb") as file:
            _WORKER_MODEL_CACHE[cache_key] = pickle.load(file)
    return _WORKER_MODEL_CACHE[cache_key]


def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
                          model_path: str,
                          app_threshold: float) -> Tuple[int, pd.DataFrame]:
    """Score one batch of feature rows (runs inside a worker process).

    Kept at module level so it can be pickled and sent to a process pool.

    Args:
        batch_data: ``(features, batch_id)``; ``features`` is the input
            matrix for the model and ``batch_id`` is returned unchanged.
        model_path: Path of the pickled model exposing ``predict_proba``.
        app_threshold: Probabilities ``>=`` this value are flagged as growth
            inhibition.

    Returns:
        ``(batch_id, predictions)`` where ``predictions`` shares the index of
        ``features`` and has the columns ``"0"`` and ``"1"`` (class
        probabilities) plus ``growth_inhibition`` (0/1).
    """
    X_input, batch_id = batch_data
    model = _load_worker_model(model_path)

    y_pred = model.predict_proba(X_input)
    pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)

    # Vectorised binarisation of the positive-class probability.
    pred_df["growth_inhibition"] = (pred_df["1"] >= app_threshold).astype(int)
    return batch_id, pred_df
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """Broad-spectrum predictor that scores feature batches in worker processes.

    Extends ``BroadSpectrumPredictor`` with the public prediction entry
    points.  The strain-level feature matrix is cut into batches of
    ``config.batch_size`` rows that are scored through a
    ``ProcessPoolExecutor``.
    """

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """Predict the broad-spectrum activity of a single molecule.

        Args:
            molecule: Molecule to score.

        Returns:
            The prediction for ``molecule``.

        Raises:
            ValueError: If no prediction could be produced for the molecule.
        """
        results = self.predict_batch([molecule])
        if not results:
            raise ValueError(
                f"No prediction could be produced for molecule: {molecule.smiles!r}"
            )
        return results[0]

    def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
        """Predict the broad-spectrum activity of many molecules.

        Args:
            molecules: Molecules to score.

        Returns:
            One result per compound id that produced predictions, in the
            order the molecules were supplied.  An empty list is returned
            when ``molecules`` is empty or no feature rows could be built.
        """
        if not molecules:
            return []

        print(f"Processing {len(molecules)} molecules...")
        mole_representation = self._get_mole_representation(molecules)

        print("Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)
        if X_input.empty:
            # Nothing to score; ``pd.concat`` below would fail on an empty list.
            return []

        batch_size = self.config.batch_size
        batches = [
            (X_input.iloc[start:start + batch_size], start // batch_size)
            for start in range(0, len(X_input), batch_size)
        ]

        # Never start more processes than there are batches to hand out.
        n_workers = max(1, min(self.n_workers, len(batches)))
        print(f"Starting parallel prediction with {n_workers} workers...")

        results = {}
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(
                    _predict_batch_worker,
                    batch,
                    self.config.xgboost_model_path,
                    self.config.app_threshold,
                )
                for batch in batches
            ]
            for future in as_completed(futures):
                batch_id, pred_df = future.result()
                results[batch_id] = pred_df
                print(f"Batch {batch_id} completed")

        print("Merging prediction results...")
        all_pred_df = pd.concat([results[batch_id] for batch_id in sorted(results)])

        print("Calculating antimicrobial potential scores...")
        agg_df = self._antimicrobial_potential(all_pred_df.reset_index())

        agg_df["broad_spectrum"] = (
            agg_df["ginhib_total"] >= self.config.min_nkill
        ).astype(int)

        # The aggregation is a groupby and therefore sorted lexicographically
        # by compound id ("mol1", "mol10", "mol2", ...).  Restore the order in
        # which the molecules were supplied; ids the aggregation produced but
        # that cannot be matched to an input are kept at the end.
        requested_ids = dict.fromkeys(
            str(mol.chem_id or f"mol{position}")
            for position, mol in enumerate(molecules, start=1)
        )
        scored_ids = set(agg_df.index)
        ordered_ids = [cid for cid in requested_ids if cid in scored_ids]
        ordered_ids += [cid for cid in agg_df.index if cid not in requested_ids]
        agg_df = agg_df.loc[ordered_ids]

        return [
            BroadSpectrumResult(
                chem_id=chem_id,
                apscore_total=float(row["apscore_total"]),
                apscore_gnegative=float(row["apscore_gnegative"]),
                apscore_gpositive=float(row["apscore_gpositive"]),
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"]),
            )
            for chem_id, row in agg_df.iterrows()
        ]

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity from a list of SMILES strings.

        Args:
            smiles_list: SMILES strings to score.
            chem_ids: Matching compound ids; ``mol1``, ``mol2``, ... are
                generated when ``None``.

        Returns:
            The prediction results.

        Raises:
            ValueError: If ``chem_ids`` is given with a different length
                than ``smiles_list``.
        """
        if chem_ids is None:
            chem_ids = [f"mol{position}" for position in range(1, len(smiles_list) + 1)]
        elif len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
        return self.predict_batch(molecules)

    def predict_from_file(self,
                          file_path: str,
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for the molecules listed in a file.

        Args:
            file_path: CSV or TSV file.  Names ending in ``.tsv`` or
                ``.tsv.gz`` (case-insensitive) are read as tab-separated;
                everything else as comma-separated.
            smiles_column: Name of the column holding the SMILES strings.
            id_column: Name of the column holding compound ids; ``mol1``,
                ``mol2``, ... are generated when the column is absent.

        Returns:
            The prediction results.

        Raises:
            ValueError: If ``smiles_column`` is not present in the file.
        """
        # ``os.fspath`` lets ``pathlib.Path`` objects through as well.
        lowered_name = os.fspath(file_path).lower()
        separator = '\t' if lowered_name.endswith(('.tsv', '.tsv.gz')) else ','
        df = pd.read_csv(file_path, sep=separator)

        if smiles_column not in df.columns:
            raise ValueError(f"Column '{smiles_column}' not found in file")

        if id_column in df.columns:
            chem_ids = df[id_column].tolist()
        else:
            chem_ids = [f"mol{position}" for position in range(1, len(df) + 1)]

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(df[smiles_column].tolist(), chem_ids)
        ]
        return self.predict_batch(molecules)
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """Build a parallel broad-spectrum predictor.

    Args:
        config: Prediction settings; passed straight to the predictor's
            constructor (``None`` selects its defaults).

    Returns:
        A ready-to-use ``ParallelBroadSpectrumPredictor``.
    """
    predictor = ParallelBroadSpectrumPredictor(config)
    return predictor
# Convenience functions
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Convenience wrapper: predict broad-spectrum activity from SMILES.

    A fresh predictor is created on every call.

    Args:
        smiles_list: SMILES strings to score.
        chem_ids: Matching compound ids, or ``None``.
        config: Prediction settings, or ``None`` for the defaults.

    Returns:
        The prediction results.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)
def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """Convenience wrapper: predict broad-spectrum activity from a file.

    A fresh predictor is created on every call.

    Args:
        file_path: Path of the input file.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound-id column.
        config: Prediction settings, or ``None`` for the defaults.

    Returns:
        The prediction results.
    """
    return create_predictor(config).predict_from_file(
        file_path, smiles_column, id_column
    )