新增功能:
- 新增统一批量预测工具 utils/batch_predictor.py
  * 支持单进程/多进程并行模式
  * 灵活的 GPU 配置和显存自动计算
  * 自动临时文件管理和断点续传
  * 完整的 CLI 参数支持(Click 框架)
- 新增 Shell 脚本集合 scripts/
  * run_parallel_predict.sh - 并行预测脚本
  * run_single_predict.sh - 单进程预测脚本
  * merge_results.sh - 结果合并脚本

性能优化:
- 解决 CUDA + multiprocessing fork 死锁问题
  * 使用 spawn 模式替代 fork
  * 文件描述符级别的输出重定向
- 优化预测性能
  * XGBoost OpenMP 多线程(利用所有 CPU 核心)
  * 预加载模型减少重复加载
  * 大批量处理降低函数调用开销
  * 实际加速比:2-3x(12进程 vs 单进程)
- 优化输出显示
  * 抑制模型加载时的权重信息
  * 只显示进度条和关键统计
  * 临时文件自动保存到专门目录

文档更新:
- README.md 新增"大规模并行预测"章节
- README.md 新增"性能优化说明"章节
- 添加详细的使用示例和参数说明
- 更新项目结构和版本信息

技术细节:
- 每个模型实例约占用 2.5GB GPU 显存
- 显存计算公式:建议进程数 = GPU显存(GB) / 2.5
- 耗时分布:MolE 表示生成(GPU)约占总耗时的 94%,是主要瓶颈
- 瓶颈并非 GIL:计算密集任务运行在 C/CUDA 层

Breaking Changes:
- 废弃旧的独立预测脚本,统一使用新工具

相关 Issue: 解决 #并行预测卡死问题
测试平台: Linux, 256 CPU cores, NVIDIA RTX 5090 32GB
731 lines
25 KiB
Python
731 lines
25 KiB
Python
"""
|
||
并行广谱抗菌预测器模块
|
||
|
||
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
|
||
基于MolE分子表示和XGBoost模型进行预测。
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import pickle
|
||
import torch
|
||
import numpy as np
|
||
import pandas as pd
|
||
import multiprocessing as mp
|
||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||
from typing import List, Dict, Union, Optional, Tuple, Any
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from scipy.stats.mstats import gmean
|
||
from sklearn.preprocessing import OneHotEncoder
|
||
|
||
from .mole_representation import process_representation
|
||
|
||
|
||
@dataclass
class PredictionConfig:
    """Settings for broad-spectrum antimicrobial prediction.

    Path fields left as ``None`` are filled in by ``__post_init__`` with
    default locations inside the project's pretrained MolE model directory.
    """

    # Pickled XGBoost classifier.
    xgboost_model_path: Optional[str] = None
    # Pretrained MolE model directory.
    mole_model_path: Optional[str] = None
    # Strain screening table (tab-separated, first column is the index);
    # its columns define the strain panel.
    strain_categories_path: Optional[str] = None
    # Excel sheet holding each strain's Gram stain.
    gram_info_path: Optional[str] = None
    # A (molecule, strain) pair counts as inhibited when the class-"1"
    # probability is >= this value.
    app_threshold: float = 0.04374140128493309
    # Minimum number of inhibited strains for a molecule to be flagged
    # broad-spectrum.
    min_nkill: int = 10
    # Batch size (raised to 10000 as an optimisation); not referenced
    # elsewhere in this module.
    batch_size: int = 10000
    # Worker count, kept at 2 to avoid CPU contention; a falsy value makes the
    # predictor fall back to mp.cpu_count().
    n_workers: Optional[int] = 2
    # Torch device for MolE; "auto" selects cuda:0 when CUDA is available,
    # otherwise cpu.
    device: str = "auto"

    def __post_init__(self) -> None:
        """Fill in default paths for every path field that was left as ``None``."""
        # This file sits one directory below the project root (models/ -> root).
        project_root = Path(__file__).resolve().parent.parent
        data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"

        if self.mole_model_path is None:
            self.mole_model_path = str(data_dir)

        if self.xgboost_model_path is None:
            self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl")

        if self.strain_categories_path is None:
            self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz")

        if self.gram_info_path is None:
            self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx")


@dataclass
class MoleculeInput:
    """One molecule to be scored.

    When ``chem_id`` is left empty the predictor assigns a positional
    identifier of the form ``mol<N>`` (1-based).
    """

    smiles: str  # SMILES string describing the molecule
    chem_id: Optional[str] = None  # caller-supplied identifier, optional


@dataclass
class StrainPrediction:
    """Prediction for a single (molecule, strain) pair.

    Mirrors one row of the strain-level prediction table.
    """

    # Composite key of the form "chem_id:strain_name".
    pred_id: str
    chem_id: str
    strain_name: str
    # Model probability of class "1" (the class that is thresholded into
    # growth_inhibition).
    antimicrobial_predictive_probability: float
    # Model probability of class "0", i.e. 1 - antimicrobial_predictive_probability.
    # NOTE(review): despite its name this is the complement of the inhibition
    # probability -- confirm the intended meaning before relying on it.
    no_growth_probability: float
    # Binary call: 1 when the class-"1" probability reaches the threshold, else 0.
    growth_inhibition: int
    # Gram stain label of the strain, when it could be looked up.
    gram_stain: Optional[str] = None


@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum prediction for one molecule.

    Attributes:
        chem_id: Molecule identifier.
        apscore_total: Log of the geometric mean of the per-strain class-"1"
            probabilities over all strains.
        apscore_gnegative: Same score restricted to Gram-negative strains.
        apscore_gpositive: Same score restricted to Gram-positive strains.
        ginhib_total: Number of strains predicted as inhibited.
        ginhib_gnegative: Inhibited Gram-negative strains.
        ginhib_gpositive: Inhibited Gram-positive strains.
        broad_spectrum: 1 if ``ginhib_total`` reaches the configured minimum,
            else 0.
        strain_predictions: Optional strain-level table with one row per
            strain of the panel.
    """

    chem_id: str
    apscore_total: float
    apscore_gnegative: float
    apscore_gpositive: float
    ginhib_total: int
    ginhib_gnegative: int
    ginhib_gpositive: int
    broad_spectrum: int
    strain_predictions: Optional[pd.DataFrame] = None

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the aggregated fields as a dict (``strain_predictions`` excluded)."""
        return {
            'chem_id': self.chem_id,
            'apscore_total': self.apscore_total,
            'apscore_gnegative': self.apscore_gnegative,
            'apscore_gpositive': self.apscore_gpositive,
            'ginhib_total': self.ginhib_total,
            'ginhib_gnegative': self.ginhib_gnegative,
            'ginhib_gpositive': self.ginhib_gpositive,
            'broad_spectrum': self.broad_spectrum
        }

    def to_strain_predictions_list(self) -> List['StrainPrediction']:
        """Convert ``strain_predictions`` into ``StrainPrediction`` objects.

        Values are converted to built-in ``str``/``float``/``int`` so they match
        the dataclass annotations rather than leaking NumPy scalars, and a
        missing (NaN) Gram stain becomes ``None``. A table without a
        ``gram_stain`` column yields ``gram_stain=None``.

        Returns:
            One ``StrainPrediction`` per row, in row order; ``[]`` when there
            are no strain predictions.
        """
        df = self.strain_predictions
        if df is None or df.empty:
            return []

        strain_list = []
        for record in df.to_dict("records"):
            gram = record.get("gram_stain")
            if gram is not None and pd.isna(gram):
                gram = None
            strain_list.append(
                StrainPrediction(
                    pred_id=str(record["pred_id"]),
                    chem_id=str(record["chem_id"]),
                    strain_name=str(record["strain_name"]),
                    antimicrobial_predictive_probability=float(
                        record["antimicrobial_predictive_probability"]
                    ),
                    no_growth_probability=float(record["no_growth_probability"]),
                    growth_inhibition=int(record["growth_inhibition"]),
                    gram_stain=gram,
                )
            )
        return strain_list


class BroadSpectrumPredictor:
    """Broad-spectrum antimicrobial predictor (shared base class).

    Scores molecules against every strain of the screening panel using MolE
    molecular representations plus an XGBoost classifier, and aggregates the
    per-strain scores into per-molecule antimicrobial-potential metrics. This
    class loads the shared strain metadata and provides the feature-building
    and post-processing helpers; the prediction entry points live in
    subclasses.
    """

    # Matches the parenthesised NT number inside a strain name, e.g. "... (NT5001)".
    _NT_NUMBER_RE = re.compile(r"\((NT\d+)\)")

    def __init__(self, config: Optional["PredictionConfig"] = None) -> None:
        """Create the predictor and load the shared strain metadata.

        Args:
            config: Prediction settings; a default ``PredictionConfig()`` is
                used when ``None``.

        Raises:
            ValueError: If a required path is ``None``.
            FileNotFoundError: If a required path does not exist.
            RuntimeError: If the strain / Gram stain tables cannot be loaded.
        """
        self.config = config or PredictionConfig()
        # A falsy n_workers (None or 0) means "use every CPU core".
        self.n_workers = self.config.n_workers or mp.cpu_count()

        self._validate_paths()
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Ensure every configured input path is set and exists.

        Raises:
            ValueError: If a path is ``None``.
            FileNotFoundError: If a path does not exist on disk.
        """
        required_files = {
            "mole_model": self.config.mole_model_path,
            "xgboost_model": self.config.xgboost_model_path,
            "strain_categories": self.config.strain_categories_path,
            "gram_info": self.config.gram_info_path,
        }

        for name, file_path in required_files.items():
            if file_path is None:
                raise ValueError(f"{name} is None! Check __post_init__ configuration")
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required {name} not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load the strain screening table, strain one-hot encoding and Gram stain sheet.

        Sets ``self.maier_screen``, ``self.strain_ohe`` and ``self.maier_strains``.

        Raises:
            RuntimeError: If any table cannot be loaded; the underlying
                exception is chained as ``__cause__``.
        """
        try:
            # Strain screening table; its columns define the strain panel.
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )

            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)

            # Gram stain sheet, indexed by the "NT data base" column. The skipped
            # rows are presumably non-data rows of the sheet -- TODO confirm.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """One-hot encode the strain names.

        Args:
            categories: Strain names.

        Returns:
            Square DataFrame whose index and column labels are ``categories``;
            row ``i`` holds the encoding of ``categories[i]``.
        """
        try:
            # scikit-learn >= 1.2 names the flag ``sparse_output``.
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn only knows ``sparse``.
            ohe = OneHotEncoder(sparse=False)

        ohe.fit(pd.DataFrame(categories))
        # NOTE: OneHotEncoder orders its output columns by sorted category while
        # the column labels below follow the input order. Only the column
        # positions reach the model, so changing this would change the feature
        # layout the model receives -- keep as-is.
        cat_ohe = pd.DataFrame(
            ohe.transform(pd.DataFrame(categories)),
            columns=categories,
            index=categories
        )
        return cat_ohe

    def _get_mole_representation(self, molecules: List["MoleculeInput"]) -> pd.DataFrame:
        """Compute MolE representations for ``molecules``.

        Molecules without a ``chem_id`` get a positional ID ``mol<N>`` (1-based).
        ``config.device == "auto"`` selects ``cuda:0`` when CUDA is available,
        otherwise ``cpu``.

        Args:
            molecules: Molecules to encode.

        Returns:
            The result of ``process_representation`` (consumed downstream as a
            feature DataFrame indexed by chem_id -- assumed, not checked here).
        """
        df = pd.DataFrame([
            {"smiles": mol.smiles, "chem_id": mol.chem_id or f"mol{i+1}"}
            for i, mol in enumerate(molecules)
        ])

        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """Pair every molecule's features with every strain's one-hot encoding.

        Args:
            chemfeats_df: Molecule features indexed by chem_id.

        Returns:
            Cross-product feature matrix (molecule features followed by strain
            one-hot columns) indexed by ``pred_id = "chem_id:strain_name"``.
        """
        chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)

        sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})

        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")

        xpred = xpred.set_index("pred_id")
        xpred = xpred.drop(columns=["chem_id", "strain_name"])

        return xpred

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``label_df`` with ``nt_number`` and ``gram_stain`` columns.

        The NT number is parsed from ``strain_name`` and looked up in the
        ``"Gram stain"`` column of ``self.maier_strains``; strains whose name
        has no NT number, or whose NT number is unknown, get ``None``.

        Args:
            label_df: DataFrame with a ``strain_name`` column. Not modified.
        """
        df_label = label_df.copy()
        pattern = self._NT_NUMBER_RE

        def _nt_number(name):
            # Non-string names (e.g. NaN) have no NT number.
            match = pattern.search(name) if isinstance(name, str) else None
            return match.group(1) if match else None

        df_label["nt_number"] = df_label["strain_name"].apply(_nt_number)

        gram_dict = self.maier_strains["Gram stain"].to_dict()
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)

        return df_label

    @staticmethod
    def _split_pred_id(score_df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``score_df`` with ``chem_id``/``strain_name`` parsed from ``pred_id``.

        Splits on the *last* ``":"`` so chem_ids that contain a colon
        (e.g. ``"ZINC:123"``) stay intact. Assumes strain names contain no
        colon -- TODO confirm against the strain table.
        """
        df = score_df.copy()
        parts = df["pred_id"].astype(str).str.rsplit(":", n=1)
        df["chem_id"] = parts.str[0]
        df["strain_name"] = parts.str[1]
        return df

    def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Format raw per-pair scores as the strain-level prediction table.

        Args:
            score_df: Must contain ``pred_id``, ``"0"``, ``"1"`` and
                ``growth_inhibition`` columns. Not modified.

        Returns:
            DataFrame with columns pred_id, chem_id, strain_name,
            antimicrobial_predictive_probability (class "1"),
            no_growth_probability (class "0"), growth_inhibition, gram_stain.
        """
        strain_df = self._gram_stain(self._split_pred_id(score_df))

        strain_df = strain_df.rename(columns={
            "1": "antimicrobial_predictive_probability",
            "0": "no_growth_probability"
        })

        output_columns = [
            "pred_id",
            "chem_id",
            "strain_name",
            "antimicrobial_predictive_probability",
            "no_growth_probability",
            "growth_inhibition",
            "gram_stain"
        ]

        return strain_df[output_columns]

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate per-strain scores into per-molecule antimicrobial potential.

        Args:
            score_df: Must contain ``pred_id``, ``"1"`` and ``growth_inhibition``
                columns. Not modified.

        Returns:
            DataFrame indexed by chem_id (sorted by groupby) with
            apscore_total / apscore_gnegative / apscore_gpositive (log of the
            geometric mean of the class-"1" probability) and ginhib_total /
            ginhib_gnegative / ginhib_gpositive (sum of growth_inhibition);
            NaN values are filled with 0.
        """
        pred_df = self._gram_stain(self._split_pred_id(score_df))
        by_chem = pred_df.groupby("chem_id")
        by_chem_gram = pred_df.groupby(["chem_id", "gram_stain"])

        apscore_total = by_chem["1"].apply(gmean).to_frame("apscore_total")
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])

        apscore_gram = by_chem_gram["1"].apply(gmean).unstack().rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])

        inhibited_total = by_chem["growth_inhibition"].sum().to_frame("ginhib_total")

        inhibited_gram = by_chem_gram["growth_inhibition"].sum().unstack().rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )

        agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)

        # Missing (chem_id, Gram group) combinations become 0.
        return agg_pred.fillna(0)


def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
|
||
model_path: str,
|
||
app_threshold: float) -> Tuple[int, pd.DataFrame]:
|
||
"""
|
||
批次预测工作函数(用于多进程)
|
||
|
||
Args:
|
||
batch_data: (特征数据, 批次ID)
|
||
model_path: XGBoost模型路径
|
||
app_threshold: 抑制阈值
|
||
|
||
Returns:
|
||
(批次ID, 预测结果DataFrame)
|
||
"""
|
||
import warnings
|
||
# 忽略所有XGBoost版本相关的警告
|
||
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
|
||
X_input, batch_id = batch_data
|
||
|
||
# 加载模型
|
||
with open(model_path, "rb") as file:
|
||
model = pickle.load(file)
|
||
# 修复特征名称兼容性问题
|
||
# 原因:模型使用旧版 XGBoost 保存时,特征列为元组格式(如 "('bacteria_name',)")
|
||
# 新版 XGBoost 严格检查特征名称匹配,导致预测失败。
|
||
# 解决:清除 XGBoost 内部的特征名称验证,直接使用输入特征进行预测
|
||
# 注意:此操作不改变模型权重和预测逻辑,只禁用格式检查,预测结果保持一致
|
||
if hasattr(model, 'get_booster'):
|
||
model.get_booster().feature_names = None
|
||
|
||
# 进行预测
|
||
y_pred = model.predict_proba(X_input)
|
||
pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
|
||
|
||
# 二值化预测结果
|
||
pred_df["growth_inhibition"] = pred_df["1"].apply(
|
||
lambda x: 1 if x >= app_threshold else 0
|
||
)
|
||
|
||
return batch_id, pred_df
|
||
|
||
|
||
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """Predictor that keeps the XGBoost model in memory and lets XGBoost parallelise.

    Design choices:
      1. Scoring runs in a single process (no inter-process communication
         overhead).
      2. The XGBoost model is unpickled once, in ``__init__``.
      3. The booster is configured with ``nthread = mp.cpu_count()`` so its
         OpenMP threads use every core.
      4. All (molecule, strain) pairs are scored in one ``predict_proba`` call.
    """

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """Initialise the base predictor and preload the XGBoost model.

        Args:
            config: Prediction settings; ``None`` uses the defaults.

        Raises:
            Everything ``BroadSpectrumPredictor.__init__`` raises, plus I/O or
            unpickling errors for ``config.xgboost_model_path``.
        """
        super().__init__(config)

        import time
        import warnings

        print("⚡ Loading XGBoost model...")
        # Ignore XGBoost's version-related UserWarnings when loading the pickle.
        warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")

        start = time.perf_counter()
        # SECURITY: pickle.load executes code from the file -- the model path must be trusted.
        with open(self.config.xgboost_model_path, "rb") as file:
            self.xgboost_model = pickle.load(file)

        # Guarded so a model without get_booster() does not raise AttributeError.
        if hasattr(self.xgboost_model, 'get_booster'):
            booster = self.xgboost_model.get_booster()
            # The pickle carries feature names from an older XGBoost that do not
            # match the input columns; clearing them disables only the name check.
            booster.feature_names = None

            # Let XGBoost's OpenMP pool use every core.
            # NOTE(review): when several predictor processes share one machine,
            # each one requests all cores -- consider a lower value there.
            n_threads = mp.cpu_count()
            booster.set_param({'nthread': n_threads})
            print(f"✓ XGBoost configured to use {n_threads} CPU threads")

        print(f"✓ Model loaded in {time.perf_counter()-start:.2f}s")

    def predict_batch(self, molecules: List[MoleculeInput],
                      include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
        """Score a batch of molecules in a single process.

        Args:
            molecules: Molecules to score.
            include_strain_predictions: If True, attach each molecule's
                strain-level DataFrame to its result.

        Returns:
            One ``BroadSpectrumResult`` per chem_id of the aggregation, ordered
            by chem_id (the aggregation groups and sorts by it), not by input
            order; molecules sharing a chem_id are aggregated together. Empty
            input returns ``[]``.
        """
        if not molecules:
            return []

        import time

        print(f"\n{'='*60}")
        print(f"Processing {len(molecules)} molecules...")
        print(f"{'='*60}")

        start_total = time.perf_counter()

        # 1. MolE representations.
        start = time.perf_counter()
        print("\n[1/4] Generating MolE representations (GPU)...")
        mole_representation = self._get_mole_representation(molecules)
        time_mole = time.perf_counter() - start
        print(f"✓ Done in {time_mole:.1f}s")

        # 2. Cross every molecule with every strain.
        start = time.perf_counter()
        print("\n[2/4] Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)
        time_prep = time.perf_counter() - start
        print(f"✓ Done in {time_prep:.1f}s")
        print(f" Total predictions needed: {len(X_input):,}")

        # 3. One large predict_proba call; XGBoost parallelises it internally
        #    across the threads configured in __init__.
        start = time.perf_counter()
        print(f"\n[3/4] XGBoost prediction (using {mp.cpu_count()} CPU cores)...")
        print(f" Predicting {len(X_input):,} samples in one batch...")
        print(f" (Watch CPU usage - should be ~{mp.cpu_count()*100}%)")

        y_pred = self.xgboost_model.predict_proba(X_input)

        time_pred = time.perf_counter() - start
        print(f"✓ Done in {time_pred:.1f}s")
        # max() guards against a zero elapsed time on very small inputs.
        print(f" Throughput: {len(X_input)/max(time_pred, 1e-9):.0f} predictions/second")

        # 4. Post-processing.
        start = time.perf_counter()
        print("\n[4/4] Post-processing results...")

        pred_df = pd.DataFrame(
            y_pred,
            columns=["0", "1"],
            index=X_input.index
        )
        pred_df["growth_inhibition"] = (
            pred_df["1"] >= self.config.app_threshold
        ).astype(int)
        pred_df = pred_df.reset_index()

        # Group the strain-level rows once, so each molecule's slice is a dict
        # lookup instead of a boolean filter over the whole table per molecule.
        strain_level_data = None
        strain_groups = None
        if include_strain_predictions:
            strain_level_data = self._prepare_strain_level_predictions(pred_df)
            strain_groups = {
                chem_id: group.reset_index(drop=True)
                for chem_id, group in strain_level_data.groupby("chem_id", sort=False)
            }

        agg_df = self._antimicrobial_potential(pred_df)
        agg_df["broad_spectrum"] = (
            agg_df["ginhib_total"] >= self.config.min_nkill
        ).astype(int)

        results_list = []
        for chem_id, row in agg_df.iterrows():
            mol_strain_preds = None
            if strain_groups is not None:
                mol_strain_preds = strain_groups.get(chem_id)
                if mol_strain_preds is None:
                    # Same shape as an empty boolean-filter result.
                    mol_strain_preds = strain_level_data.iloc[0:0].reset_index(drop=True)

            results_list.append(BroadSpectrumResult(
                chem_id=chem_id,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"]),
                strain_predictions=mol_strain_preds
            ))

        time_post = time.perf_counter() - start
        print(f"✓ Done in {time_post:.1f}s")

        total_time = time.perf_counter() - start_total
        pct_base = max(total_time, 1e-9)
        print(f"\n{'='*60}")
        print("SUMMARY")
        print(f"{'='*60}")
        print(f" MolE representation: {time_mole:6.1f}s ({time_mole/pct_base*100:5.1f}%)")
        print(f" Feature preparation: {time_prep:6.1f}s ({time_prep/pct_base*100:5.1f}%)")
        print(f" XGBoost prediction: {time_pred:6.1f}s ({time_pred/pct_base*100:5.1f}%)")
        print(f" Post-processing: {time_post:6.1f}s ({time_post/pct_base*100:5.1f}%)")
        print(f" {'─'*58}")
        print(f" Total time: {total_time:6.1f}s")
        print(f" Molecules processed: {len(molecules)}")
        print(f" Time per molecule: {total_time/len(molecules):.3f}s")
        print(f"{'='*60}\n")

        return results_list

    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """Predict one molecule; shorthand for ``predict_batch([molecule])[0]``.

        Raises:
            IndexError: If the batch produced no result for the molecule.
        """
        results = self.predict_batch([molecule])
        return results[0]

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None,
                            include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for a list of SMILES strings.

        Args:
            smiles_list: SMILES strings.
            chem_ids: Matching identifiers; ``mol1``, ``mol2``, ... when ``None``.
            include_strain_predictions: Forwarded to ``predict_batch``.

        Returns:
            Prediction results (see ``predict_batch``).

        Raises:
            ValueError: If ``smiles_list`` and ``chem_ids`` differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]

        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]

        return self.predict_batch(molecules, include_strain_predictions=include_strain_predictions)

    def predict_from_file(self,
                          file_path: str,
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id",
                          include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
        """Predict broad-spectrum activity for molecules listed in a CSV/TSV file.

        Files ending in ``.tsv`` (case-insensitive, optionally compressed, e.g.
        ``.tsv.gz``) are read tab-separated; everything else comma-separated.
        Column names are matched case-insensitively. When the ID column is
        missing, IDs ``mol1``, ``mol2``, ... are generated.

        Args:
            file_path: Input file path (str or path-like).
            smiles_column: Name of the SMILES column.
            id_column: Name of the compound ID column.
            include_strain_predictions: Forwarded to ``predict_batch``.

        Returns:
            Prediction results (see ``predict_batch``).

        Raises:
            ValueError: If the SMILES column is not present.
        """
        path_str = str(file_path)
        tsv_suffixes = ('.tsv', '.tsv.gz', '.tsv.bz2', '.tsv.xz', '.tsv.zip', '.tsv.zst')
        if path_str.lower().endswith(tsv_suffixes):
            df = pd.read_csv(path_str, sep='\t')
        else:
            df = pd.read_csv(path_str)

        columns_lower = {col.lower(): col for col in df.columns}

        smiles_col_actual = columns_lower.get(smiles_column.lower())
        if smiles_col_actual is None:
            raise ValueError(f"Column '{smiles_column}' not found in file. Available columns: {list(df.columns)}")

        id_col_actual = columns_lower.get(id_column.lower())
        if id_col_actual is None:
            df[id_column] = [f"mol{i+1}" for i in range(len(df))]
            id_col_actual = id_column

        molecules = [
            MoleculeInput(smiles=smiles, chem_id=str(chem_id))
            for smiles, chem_id in zip(df[smiles_col_actual], df[id_col_actual])
        ]

        return self.predict_batch(molecules, include_strain_predictions=include_strain_predictions)


def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """Build a ready-to-use parallel broad-spectrum predictor.

    Args:
        config: Prediction settings passed to the predictor; ``None`` means defaults.

    Returns:
        A freshly constructed ``ParallelBroadSpectrumPredictor`` (its XGBoost
        model is loaded during construction).
    """
    predictor = ParallelBroadSpectrumPredictor(config)
    return predictor


# Convenience functions
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """One-shot helper: build a predictor and score a list of SMILES strings.

    A new predictor (and hence a fresh model load) is created on every call;
    keep a predictor from ``create_predictor`` for repeated use.

    Args:
        smiles_list: SMILES strings.
        chem_ids: Matching compound IDs, or ``None`` to auto-generate them.
        config: Prediction settings.

    Returns:
        List of prediction results.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)


def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """One-shot helper: build a predictor and score the molecules in a file.

    A new predictor (and hence a fresh model load) is created on every call.

    Args:
        file_path: Input file path.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound ID column.
        config: Prediction settings.

    Returns:
        List of prediction results.
    """
    predictor = create_predictor(config)
    results = predictor.predict_from_file(file_path, smiles_column, id_column)
    return results