diff --git a/Data/mole/README.md b/Data/mole/README.md new file mode 100644 index 0000000..36d8097 --- /dev/null +++ b/Data/mole/README.md @@ -0,0 +1,163 @@ +## convert old xgboots pickle format + +```bash +cd Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001 +ipython +``` + +```python +import xgboost as xgb +import pickle +from pathlib import Path +ckpt = Path('MolE-XGBoost-08.03.2024_14.20.pkl') +out_ckpt = Path('./') + +# 加载旧模型 +with open(ckpt, 'rb') as f: + model = pickle.load(f) + +# 用新格式保存(推荐) +model.get_booster().save_model(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.json')) + +# 或者继续用pickle但清晰格式 +booster = model.get_booster() +booster.feature_names = None +with open(out_ckpt.joinpath('MolE-XGBoost-08.03.2025_10.17.pkl'), 'wb') as f: + pickle.dump(model, f) +``` + +## 完整预测流程 + +```mermaid +SMILES 分子(输入CSV文件) + ↓ +[MolE 模型] + ├── config.yaml(模型配置) + └── model.pth(模型权重) + ↓ +分子特征表示(1000维向量) + ↓ +构建"分子-菌株对"(笛卡尔积) + └── maier_screening_results.tsv.gz(菌株列表) + ↓ +[XGBoost 模型] + └── MolE-XGBoost-08.03.2025_10.17.json(或.pkl) + ↓ +对每一对预测:是否抑制生长 + ↓ +获得原始预测结果(对每个菌株的预测) + ↓ +[聚合分析] + ├── maier_screening_results.tsv.gz(菌株列表) + └── strain_info_SF2.xlsx(革兰染色信息) + ↓ +最终预测结果 + ↓ +输出CSV文件 +``` + +## 所需文件清单 + +| 步骤 | 文件名 | 用途 | 备注 | +|------|--------|------|------| +| **MolE 模型** | `config.yaml` | 定义MolE网络结构 | YAML配置文件 | +| | `model.pth` | MolE模型权重 | PyTorch格式 | +| **构建菌株对** | `maier_screening_results.tsv.gz` | 提供40个菌株列表 | 压缩的TSV文件 | +| **XGBoost 预测** | `MolE-XGBoost-08.03.2025_10.17.json` | 预测分子-菌株对 | JSON格式(新)或PKL格式(旧) | +| **聚合分析** | `maier_screening_results.tsv.gz` | 菌株名称和统计 | 复用(与构建菌株对同一文件) | +| | `strain_info_SF2.xlsx` | 革兰染色分类信息 | Excel格式 | + +## 文件存放位置 + +所有文件应位于: +``` +Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/ +├── config.yaml +├── model.pth +├── MolE-XGBoost-08.03.2025_10.17.json +├── maier_screening_results.tsv.gz +└── strain_info_SF2.xlsx +``` + +## 代码中的对应关系 + +```python +# PredictionConfig 中的配置 +@dataclass +class 
PredictionConfig: + xgboost_model_path = "MolE-XGBoost-08.03.2025_10.17.json" + mole_model_path = "model_ginconcat_btwin_100k_d8000_l0.0001" # 目录(包含config.yaml + model.pth) + strain_categories_path = "maier_screening_results.tsv.gz" + gram_info_path = "strain_info_SF2.xlsx" +``` + +## 数据流向总结 + +1. **输入**:CSV文件中的SMILES分子 +2. **MolE处理**:分子 → 1000维特征向量 +3. **菌株配对**:1个分子 × 40个菌株 = 40对 +4. **XGBoost预测**:每对 → 抑制概率 +5. **聚合分析**:统计和分类(按革兰染色) +6. **输出**:CSV文件中的预测结果(包含8个指标) + +## 参考文件 + +1. `maier_screening_results.tsv.gz` - 菌株列表和筛选数据 + +```python +self.maier_screen = pd.read_csv( + self.config.strain_categories_path, sep='\t', index_col=0 +) +self.strain_ohe = self._prep_ohe(self.maier_screen.columns) # 独热编码 +``` + +包含所有已知菌株的名称(40个菌株) +用于与每个分子做笛卡尔积(分子×菌株),生成所有"分子-菌株对" +XGBoost为每一对预测:是否能抑制该菌株的生长 + +2. `strain_info_SF2.xlsx` - 革兰染色信息 + +```python +self.maier_strains = pd.read_excel(self.config.gram_info_path, ...) +gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"] +``` + +记录每个菌株的革兰染色属性:阳性(positive) 或 阴性(negative) +用于将预测结果按革兰染色分类统计 + +预测结果示例: +某分子 mol1 的预测结果会包括: + +```python +BroadSpectrumResult( + chem_id='mol1', + apscore_total=2.5, # 对所有菌株的抗菌分数 + apscore_gnegative=2.1, # 仅对革兰阴性菌的分数 + apscore_gpositive=2.8, # 仅对革兰阳性菌的分数 + ginhib_total=25, # 抑制的菌株总数 + ginhib_gnegative=12, # 抑制的革兰阴性菌数 + ginhib_gpositive=13, # 抑制的革兰阳性菌数 + broad_spectrum=1 # 是否广谱(≥10个菌株) +) +``` + +结果解读: + +## BroadSpectrumResult 字段说明表 + +| 字段名 | 数据类型 | 计算方法 | 含义说明 | +|--------|----------|----------|---------| +| `chem_id` | 字符串 | 输入的化合物标识符 | 化合物的唯一标识,如 "mol1"、"compound_001" 等 | +| `apscore_total` | 浮点数 | `log(gmean(所有40个菌株的预测概率))` | 总体抗菌潜力分数:所有菌株预测概率的几何平均数的对数。值越高表示抗菌活性越强;负值表示整体抑制概率较低 | +| `apscore_gnegative` | 浮点数 | `log(gmean(革兰阴性菌株的预测概率))` | 革兰阴性菌抗菌潜力分数:仅针对革兰阴性菌株计算的抗菌分数。用于判断对阴性菌的特异性 | +| `apscore_gpositive` | 浮点数 | `log(gmean(革兰阳性菌株的预测概率))` | 革兰阳性菌抗菌潜力分数:仅针对革兰阳性菌株计算的抗菌分数。用于判断对阳性菌的特异性 | +| `ginhib_total` | 整数 | `sum(所有菌株的二值化预测)` | 总抑制菌株数:预测被抑制的菌株总数(概率 ≥ 0.04374 的菌株数量)。范围 0-40 | +| 
`ginhib_gnegative` | 整数 | `sum(革兰阴性菌株的二值化预测)` | 革兰阴性菌抑制数:预测被抑制的革兰阴性菌株数量。范围 0-20 | +| `ginhib_gpositive` | 整数 | `sum(革兰阳性菌株的二值化预测)` | 革兰阳性菌抑制数:预测被抑制的革兰阳性菌株数量。范围 0-20 | +| `broad_spectrum` | 整数 (0/1) | `1 if ginhib_total >= 10 else 0` | 广谱抗菌标志:如果抑制菌株数 ≥ 10,判定为广谱抗菌药物(1),否则为窄谱(0) | + +说明 + +- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度 +- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围 +- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性 \ No newline at end of file diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2024_14.20.pkl b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2024_14.20.pkl new file mode 100644 index 0000000..bfaef01 Binary files /dev/null and b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2024_14.20.pkl differ diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2025_10.17.pkl b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2025_10.17.pkl new file mode 100644 index 0000000..c90ee63 Binary files /dev/null and b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/MolE-XGBoost-08.03.2025_10.17.pkl differ diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/config.yaml b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/config.yaml new file mode 100644 index 0000000..3c8c032 --- /dev/null +++ b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/config.yaml @@ -0,0 +1,28 @@ +batch_size: 1000 # batch size +warm_up: 10 # warm-up epochs +epochs: 1000 # total number of epochs + +load_model: None # resume training +eval_every_n_epochs: 1 # validation frequency +save_every_n_epochs: 5 # automatic model saving frequecy + +fp16_precision: False # float precision 16 (i.e. 
True/False) +init_lr: 0.0005 # initial learning rate for Adam +weight_decay: 1e-5 # weight decay for Adam +gpu: cuda:0 # training GPU + +model_type: gin_concat # GNN backbone (i.e., gin/gcn) +model: + num_layer: 5 # number of graph conv layers + emb_dim: 200 # embedding dimension in graph conv layers + feat_dim: 8000 # output feature dimention + drop_ratio: 0.0 # dropout ratio + pool: add # readout pooling (i.e., mean/max/add) + +dataset: + num_workers: 50 # dataloader number of workers + valid_size: 0.1 # ratio of validation data + data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data + +loss: + l: 0.0001 # Lambda parameter \ No newline at end of file diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/maier_screening_results.tsv.gz b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/maier_screening_results.tsv.gz new file mode 100644 index 0000000..c736060 Binary files /dev/null and b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/maier_screening_results.tsv.gz differ diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/model.pth b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/model.pth new file mode 100644 index 0000000..d0fee23 Binary files /dev/null and b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/model.pth differ diff --git a/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/strain_info_SF2.xlsx b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/strain_info_SF2.xlsx new file mode 100644 index 0000000..e705cc3 Binary files /dev/null and b/Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/strain_info_SF2.xlsx differ diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..eb8cefe --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,26 @@ +""" +SIME Models Package + +This package contains models for antimicrobial 
activity prediction. +""" + +from .broad_spectrum_predictor import ( + ParallelBroadSpectrumPredictor, + PredictionConfig, + MoleculeInput, + BroadSpectrumResult, + create_predictor, + predict_smiles, + predict_file +) + +__all__ = [ + 'ParallelBroadSpectrumPredictor', + 'PredictionConfig', + 'MoleculeInput', + 'BroadSpectrumResult', + 'create_predictor', + 'predict_smiles', + 'predict_file' +] + diff --git a/models/broad_spectrum_predictor.py b/models/broad_spectrum_predictor.py new file mode 100644 index 0000000..fa86043 --- /dev/null +++ b/models/broad_spectrum_predictor.py @@ -0,0 +1,567 @@ +""" +并行广谱抗菌预测器模块 + +提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。 +基于MolE分子表示和XGBoost模型进行预测。 +""" + +import os +import re +import pickle +import torch +import numpy as np +import pandas as pd +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import List, Dict, Union, Optional, Tuple, Any +from dataclasses import dataclass +from pathlib import Path +from scipy.stats.mstats import gmean +from sklearn.preprocessing import OneHotEncoder + +from .mole_representation import process_representation + + +@dataclass +class PredictionConfig: + """预测配置参数""" + xgboost_model_path: str = None + mole_model_path: str = None + strain_categories_path: str = None + gram_info_path: str = None + app_threshold: float = 0.04374140128493309 + min_nkill: int = 10 + batch_size: int = 100 + n_workers: Optional[int] = None + device: str = "auto" + + def __post_init__(self): + """设置默认路径""" + from pathlib import Path + + # 获取当前文件所在目录 + current_file = Path(__file__).resolve() + project_root = current_file.parent.parent # models -> 项目根 + + data_dir = project_root / "Data/mole/pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001" + + # 设置所有路径 + if self.mole_model_path is None: + self.mole_model_path = str(data_dir) + + if self.xgboost_model_path is None: + self.xgboost_model_path = str(data_dir / "MolE-XGBoost-08.03.2025_10.17.pkl") + + if 
self.strain_categories_path is None: + self.strain_categories_path = str(data_dir / "maier_screening_results.tsv.gz") + + if self.gram_info_path is None: + self.gram_info_path = str(data_dir / "strain_info_SF2.xlsx") + +@dataclass +class MoleculeInput: + """分子输入数据结构""" + smiles: str + chem_id: Optional[str] = None + + +@dataclass +class BroadSpectrumResult: + """广谱抗菌预测结果""" + chem_id: str + apscore_total: float + apscore_gnegative: float + apscore_gpositive: float + ginhib_total: int + ginhib_gnegative: int + ginhib_gpositive: int + broad_spectrum: int + + def to_dict(self) -> Dict[str, Union[str, float, int]]: + """转换为字典格式""" + return { + 'chem_id': self.chem_id, + 'apscore_total': self.apscore_total, + 'apscore_gnegative': self.apscore_gnegative, + 'apscore_gpositive': self.apscore_gpositive, + 'ginhib_total': self.ginhib_total, + 'ginhib_gnegative': self.ginhib_gnegative, + 'ginhib_gpositive': self.ginhib_gpositive, + 'broad_spectrum': self.broad_spectrum + } + + +class BroadSpectrumPredictor: + """ + 广谱抗菌预测器 + + 基于MolE分子表示和XGBoost模型预测分子的广谱抗菌活性。 + 支持单分子和批量预测,提供详细的抗菌潜力分析。 + """ + + def __init__(self, config: Optional[PredictionConfig] = None) -> None: + """ + 初始化预测器 + + Args: + config: 预测配置参数,如果为None则使用默认配置 + """ + self.config = config or PredictionConfig() + self.n_workers = self.config.n_workers or mp.cpu_count() + + # 验证文件路径 + self._validate_paths() + + # 预加载共享数据 + self._load_shared_data() + + def _validate_paths(self) -> None: + """验证必要文件路径是否存在""" + required_files = { + "mole_model": self.config.mole_model_path, + "xgboost_model": self.config.xgboost_model_path, + "strain_categories": self.config.strain_categories_path, + "gram_info": self.config.gram_info_path, + } + + for name, file_path in required_files.items(): + if file_path is None: + raise ValueError(f"{name} is None! 
Check __post_init__ configuration") + if not Path(file_path).exists(): + raise FileNotFoundError(f"Required {name} not found: {file_path}") + + def _load_shared_data(self) -> None: + """加载共享数据(菌株信息、革兰染色信息等)""" + try: + # 加载菌株筛选数据 + self.maier_screen: pd.DataFrame = pd.read_csv( + self.config.strain_categories_path, sep='\t', index_col=0 + ) + + # 准备菌株独热编码 + self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns) + + # 加载革兰染色信息 + self.maier_strains: pd.DataFrame = pd.read_excel( + self.config.gram_info_path, + skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54], + index_col="NT data base" + ) + + except Exception as e: + raise RuntimeError(f"Failed to load shared data: {str(e)}") + + def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame: + """ + 准备菌株的独热编码 + + Args: + categories: 菌株类别索引 + + Returns: + 独热编码后的DataFrame + """ + try: + # 新版本 sklearn 使用 sparse_output + ohe = OneHotEncoder(sparse_output=False) + except TypeError: + # 旧版本 sklearn 使用 sparse + ohe = OneHotEncoder(sparse=False) + + ohe.fit(pd.DataFrame(categories)) + cat_ohe = pd.DataFrame( + ohe.transform(pd.DataFrame(categories)), + columns=categories, + index=categories + ) + return cat_ohe + + def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame: + """ + 获取分子的MolE表示 + + Args: + molecules: 分子输入列表 + + Returns: + MolE特征表示DataFrame + """ + # 准备输入数据 + df_data = [] + for i, mol in enumerate(molecules): + chem_id = mol.chem_id or f"mol{i+1}" + df_data.append({"smiles": mol.smiles, "chem_id": chem_id}) + + df = pd.DataFrame(df_data) + + # 确定设备 + device = self.config.device + if device == "auto": + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # 获取MolE表示 + return process_representation( + dataset_path=df, + smile_column_str="smiles", + id_column_str="chem_id", + pretrained_dir=self.config.mole_model_path, + device=device + ) + + def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame: + """ + 添加菌株信息到化学特征(笛卡尔积) + + Args: + 
chemfeats_df: 化学特征DataFrame + + Returns: + 包含菌株信息的特征DataFrame + """ + # 准备化学特征 + chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"}) + chemfe["chem_id"] = chemfe["chem_id"].astype(str) + + # 准备独热编码 + sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"}) + + # 笛卡尔积合并 + xpred = chemfe.merge(sohe, how="cross") + xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":") + + xpred = xpred.set_index("pred_id") + xpred = xpred.drop(columns=["chem_id", "strain_name"]) + + return xpred + + def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame: + """ + 添加革兰染色信息 + + Args: + label_df: 包含菌株名称的DataFrame + + Returns: + 添加革兰染色信息后的DataFrame + """ + df_label = label_df.copy() + + # 提取NT编号 + df_label["nt_number"] = df_label["strain_name"].apply( + lambda x: re.search(r".*?\((NT\d+)\)", x).group(1) if re.search(r".*?\((NT\d+)\)", x) else None + ) + + # 创建革兰染色字典 + gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"] + + # 添加染色信息 + df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get) + + return df_label + + def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame: + """ + 计算抗菌潜力分数 + + Args: + score_df: 预测分数DataFrame + + Returns: + 聚合后的抗菌潜力DataFrame + """ + # 分离化合物ID和菌株名 + score_df["chem_id"] = score_df["pred_id"].str.split(":", expand=True)[0] + score_df["strain_name"] = score_df["pred_id"].str.split(":", expand=True)[1] + + # 添加革兰染色信息 + pred_df = self._gram_stain(score_df) + + # 计算抗菌潜力分数(几何平均数的对数) + apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename( + columns={"1": "apscore_total"} + ) + apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"]) + + # 按革兰染色分组的抗菌分数 + apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename( + columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"} + ) + apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"]) + 
apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"]) + + # 被抑制菌株数统计 + inhibted_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename( + columns={"growth_inhibition": "ginhib_total"} + ) + + # 按革兰染色分组的被抑制菌株数 + inhibted_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename( + columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"} + ) + + # 合并所有结果 + agg_pred = apscore_total.join(apscore_gram).join(inhibted_total).join(inhibted_gram) + + # 填充NaN值 + agg_pred = agg_pred.fillna(0) + + return agg_pred + + +def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int], + model_path: str, + app_threshold: float) -> Tuple[int, pd.DataFrame]: + """ + 批次预测工作函数(用于多进程) + + Args: + batch_data: (特征数据, 批次ID) + model_path: XGBoost模型路径 + app_threshold: 抑制阈值 + + Returns: + (批次ID, 预测结果DataFrame) + """ + import warnings + # 忽略所有XGBoost版本相关的警告 + warnings.filterwarnings("ignore", category=UserWarning, module="xgboost") + X_input, batch_id = batch_data + + # 加载模型 + with open(model_path, "rb") as file: + model = pickle.load(file) + # 修复特征名称兼容性问题 + # 原因:模型使用旧版 XGBoost 保存时,特征列为元组格式(如 "('bacteria_name',)") + # 新版 XGBoost 严格检查特征名称匹配,导致预测失败。 + # 解决:清除 XGBoost 内部的特征名称验证,直接使用输入特征进行预测 + # 注意:此操作不改变模型权重和预测逻辑,只禁用格式检查,预测结果保持一致 + if hasattr(model, 'get_booster'): + model.get_booster().feature_names = None + + # 进行预测 + y_pred = model.predict_proba(X_input) + pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index) + + # 二值化预测结果 + pred_df["growth_inhibition"] = pred_df["1"].apply( + lambda x: 1 if x >= app_threshold else 0 + ) + + return batch_id, pred_df + + +class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor): + """ + 并行广谱抗菌预测器 + + 继承自BroadSpectrumPredictor,添加了多进程并行处理能力, + 适用于大规模分子批量预测。 + """ + + def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult: + """ + 预测单个分子的广谱抗菌活性 + + Args: + molecule: 分子输入数据 + + Returns: + 广谱抗菌预测结果 + """ + results = 
self.predict_batch([molecule]) + return results[0] + + def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]: + """ + 批量预测分子的广谱抗菌活性 + + Args: + molecules: 分子输入列表 + + Returns: + 广谱抗菌预测结果列表 + """ + if not molecules: + return [] + + # 获取MolE表示 + print(f"Processing {len(molecules)} molecules...") + mole_representation = self._get_mole_representation(molecules) + + # 添加菌株信息 + print("Preparing strain-level features...") + X_input = self._add_strains(mole_representation) + + # 分批处理 + print(f"Starting parallel prediction with {self.n_workers} workers...") + batches = [] + for i in range(0, len(X_input), self.config.batch_size): + batch = X_input.iloc[i:i+self.config.batch_size] + batches.append((batch, i // self.config.batch_size)) + + # 并行预测 + results = {} + with ProcessPoolExecutor(max_workers=self.n_workers) as executor: + futures = { + executor.submit(_predict_batch_worker, (batch_data, batch_id), + self.config.xgboost_model_path, + self.config.app_threshold): batch_id + for batch_data, batch_id in batches + } + + for future in as_completed(futures): + batch_id, pred_df = future.result() + results[batch_id] = pred_df + print(f"Batch {batch_id} completed") + + # 合并结果 + print("Merging prediction results...") + all_pred_df = pd.concat([results[i] for i in sorted(results.keys())]) + + # 计算抗菌潜力 + print("Calculating antimicrobial potential scores...") + all_pred_df = all_pred_df.reset_index() + agg_df = self._antimicrobial_potential(all_pred_df) + + # 判断广谱抗菌 + agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply( + lambda x: 1 if x >= self.config.min_nkill else 0 + ) + + # 转换为结果对象 + results_list = [] + for _, row in agg_df.iterrows(): + result = BroadSpectrumResult( + chem_id=row.name, + apscore_total=row["apscore_total"], + apscore_gnegative=row["apscore_gnegative"], + apscore_gpositive=row["apscore_gpositive"], + ginhib_total=int(row["ginhib_total"]), + ginhib_gnegative=int(row["ginhib_gnegative"]), + 
ginhib_gpositive=int(row["ginhib_gpositive"]), + broad_spectrum=int(row["broad_spectrum"]) + ) + results_list.append(result) + + return results_list + + def predict_from_smiles(self, + smiles_list: List[str], + chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]: + """ + 从SMILES字符串列表预测广谱抗菌活性 + + Args: + smiles_list: SMILES字符串列表 + chem_ids: 化合物ID列表,如果为None则自动生成 + + Returns: + 广谱抗菌预测结果列表 + """ + if chem_ids is None: + chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))] + + if len(smiles_list) != len(chem_ids): + raise ValueError("smiles_list and chem_ids must have the same length") + + molecules = [ + MoleculeInput(smiles=smiles, chem_id=chem_id) + for smiles, chem_id in zip(smiles_list, chem_ids) + ] + + return self.predict_batch(molecules) + + def predict_from_file(self, + file_path: str, + smiles_column: str = "smiles", + id_column: str = "chem_id") -> List[BroadSpectrumResult]: + """ + 从文件预测广谱抗菌活性 + + Args: + file_path: 输入文件路径(支持CSV/TSV) + smiles_column: SMILES列名 + id_column: 化合物ID列名 + + Returns: + 广谱抗菌预测结果列表 + """ + # 读取文件 + if file_path.endswith('.tsv'): + df = pd.read_csv(file_path, sep='\t') + else: + df = pd.read_csv(file_path) + + # 验证列存在(大小写不敏感) + columns_lower = {col.lower(): col for col in df.columns} + + smiles_col_actual = columns_lower.get(smiles_column.lower()) + if smiles_col_actual is None: + raise ValueError(f"Column '{smiles_column}' not found in file. 
Available columns: {list(df.columns)}") + + # 处理ID列 + id_col_actual = columns_lower.get(id_column.lower()) + if id_col_actual is None: + df[id_column] = [f"mol{i+1}" for i in range(len(df))] + id_col_actual = id_column + + # 创建分子输入 + molecules = [ + MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual])) + for _, row in df.iterrows() + ] + + return self.predict_batch(molecules) + + +def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor: + """ + 创建并行广谱抗菌预测器实例 + + Args: + config: 预测配置参数 + + Returns: + 预测器实例 + """ + return ParallelBroadSpectrumPredictor(config) + + +# 便捷函数 +def predict_smiles(smiles_list: List[str], + chem_ids: Optional[List[str]] = None, + config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]: + """ + 便捷函数:直接从SMILES列表预测广谱抗菌活性 + + Args: + smiles_list: SMILES字符串列表 + chem_ids: 化合物ID列表 + config: 预测配置 + + Returns: + 预测结果列表 + """ + predictor = create_predictor(config) + return predictor.predict_from_smiles(smiles_list, chem_ids) + + +def predict_file(file_path: str, + smiles_column: str = "smiles", + id_column: str = "chem_id", + config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]: + """ + 便捷函数:从文件预测广谱抗菌活性 + + Args: + file_path: 输入文件路径 + smiles_column: SMILES列名 + id_column: ID列名 + config: 预测配置 + + Returns: + 预测结果列表 + """ + predictor = create_predictor(config) + return predictor.predict_from_file(file_path, smiles_column, id_column) + diff --git a/models/dataset_representation.py b/models/dataset_representation.py new file mode 100644 index 0000000..a0b92a1 --- /dev/null +++ b/models/dataset_representation.py @@ -0,0 +1,179 @@ +import os +import yaml +import numpy as np +import pandas as pd + +import torch +from torch_geometric.data import Data, Dataset, Batch + +from rdkit import Chem +from rdkit.Chem.rdchem import BondType as BT +from rdkit import RDLogger + +RDLogger.DisableLog('rdApp.*') + + +ATOM_LIST = list(range(1,119)) +CHIRALITY_LIST = [ + 
Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER +] +BOND_LIST = [ + BT.SINGLE, + BT.DOUBLE, + BT.TRIPLE, + BT.AROMATIC +] +BONDDIR_LIST = [ + Chem.rdchem.BondDir.NONE, + Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT +] + + +class MoleculeDataset(Dataset): + """ + Dataset class for creating molecular graphs. + + Attributes: + - smile_df (pandas.DataFrame): DataFrame containing SMILES data. + - smile_column (str): Name of the column containing SMILES strings. + - id_column (str): Name of the column containing molecule IDs. + """ + + def __init__(self, smile_df, smile_column, id_column): + super(Dataset, self).__init__() + + # Gather the SMILES and the corresponding IDs + self.smiles_data = smile_df[smile_column].tolist() + self.id_data = smile_df[id_column].tolist() + + def __getitem__(self, index): + # Get the molecule + mol = Chem.MolFromSmiles(self.smiles_data[index]) + mol = Chem.AddHs(mol) + + ######################### + # Get the molecule info # + ######################### + type_idx = [] + chirality_idx = [] + atomic_number = [] + + # Roberto: Might want to add more features later on. 
Such as atomic spin + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() == 0: + print(self.id_data[index]) + + type_idx.append(ATOM_LIST.index(atom.GetAtomicNum())) + chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag())) + atomic_number.append(atom.GetAtomicNum()) + + x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1) + x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1) + x = torch.cat([x1, x2], dim=-1) + + row, col, edge_feat = [], [], [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + row += [start, end] + col += [end, start] + edge_feat.append([ + BOND_LIST.index(bond.GetBondType()), + BONDDIR_LIST.index(bond.GetBondDir()) + ]) + edge_feat.append([ + BOND_LIST.index(bond.GetBondType()), + BONDDIR_LIST.index(bond.GetBondDir()) + ]) + + edge_index = torch.tensor([row, col], dtype=torch.long) + edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + chem_id=self.id_data[index]) + + return data + + def __len__(self): + return len(self.smiles_data) + + def get(self, index): + return self.__getitem__(index) + + def len(self): + return self.__len__() + + +def batch_representation(smile_df, dl_model, column_str, id_str, batch_size=10_000, id_is_str=True, device="cuda:0"): + """ + Generate molecular representations using a Deep Learning model. + + Parameters: + - smile_df (pandas.DataFrame): DataFrame containing SMILES data. + - dl_model: Deep Learning model for molecular representation. + - column_str (str): Name of the column containing SMILES strings. + - id_str (str): Name of the column containing molecule IDs. + - batch_size (int, optional): Batch size for processing (default is 10,000). + - id_is_str (bool, optional): Whether IDs are strings (default is True). + - device (str, optional): Device for computation (default is "cuda:0"). 
+ + Returns: + - chem_representation (pandas.DataFrame): DataFrame containing molecular representations. + """ + + # First we create a list of graphs + molecular_graph_dataset = MoleculeDataset(smile_df, column_str, id_str) + graph_list = [g for g in molecular_graph_dataset] + + # Determine number of loops to do given the batch size + n_batches = len(graph_list) // batch_size + + # Are all molecules accounted for? + remaining_molecules = len(graph_list) % batch_size + + # Starting indices + start, end = 0, batch_size + + # Determine number of iterations + if remaining_molecules == 0: + n_iter = n_batches + + elif remaining_molecules > 0: + n_iter = n_batches + 1 + + # A list to store the batch dataframes + batch_dataframes = [] + + # Iterate over the batches + for i in range(n_iter): + # Start batch object + batch_obj = Batch() + graph_batch = batch_obj.from_data_list(graph_list[start:end]) + graph_batch = graph_batch.to(device) + + # Gather the representation + with torch.no_grad(): + dl_model.eval() + h_representation, _ = dl_model(graph_batch) + chem_ids = graph_batch.chem_id + + batch_df = pd.DataFrame(h_representation.cpu().numpy(), index=chem_ids) + batch_dataframes.append(batch_df) + + # Get the next batch + ## In the final iteration we want to get all the remaining molecules + if i == n_iter - 2: + start = end + end = len(graph_list) + else: + start = end + end = end + batch_size + + # Concatenate the dataframes + chem_representation = pd.concat(batch_dataframes) + + return chem_representation + diff --git a/models/ginet_concat.py b/models/ginet_concat.py new file mode 100644 index 0000000..783648e --- /dev/null +++ b/models/ginet_concat.py @@ -0,0 +1,164 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import add_self_loops +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool + +num_atom_type = 119 # including the extra mask 
tokens +num_chirality_tag = 3 + +num_bond_type = 5 # including aromatic and self-loop edge +num_bond_direction = 3 + + +class GINEConv(MessagePassing): + def __init__(self, emb_dim): + super(GINEConv, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(emb_dim, 2*emb_dim), + nn.BatchNorm1d(2*emb_dim), + nn.ReLU(), + nn.Linear(2*emb_dim, emb_dim), + nn.ReLU() + ) + self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim) + self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim) + nn.init.xavier_uniform_(self.edge_embedding1.weight.data) + nn.init.xavier_uniform_(self.edge_embedding2.weight.data) + + def forward(self, x, edge_index, edge_attr): + # add self loops in the edge space + edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0] + + # add features corresponding to self-loop edges. + self_loop_attr = torch.zeros(x.size(0), 2) + self_loop_attr[:,0] = 4 #bond type for self-loop edge + self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) + edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) + + edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1]) + + return self.propagate(edge_index, x=x, edge_attr=edge_embeddings) + + def message(self, x_j, edge_attr): + return x_j + edge_attr + + def update(self, aggr_out): + return self.mlp(aggr_out) + + +class GINet(nn.Module): + + """ + GIN encoder from MolE. + + Args: + num_layer (int): Number of GNN layers. + emb_dim (int): Dimensionality of embeddings for each graph layer. + feat_dim (int): Dimensionality of embedding vector. + drop_ratio (float): Dropout rate. + pool (str): Pooling method for neighbor aggregation ('mean', 'max', or 'add'). 
+ + Output: + h_global_embedding: Graph-level representation + out: Final embedding vector + """ + def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'): + + super(GINet, self).__init__() + self.num_layer = num_layer + self.emb_dim = emb_dim + self.feat_dim = feat_dim + self.drop_ratio = drop_ratio + + self.concat_dim = num_layer * emb_dim + + if self.concat_dim != self.feat_dim: + print(f"Representation dimension ({self.concat_dim}) - Embedding dimension ({self.feat_dim})") + + self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim) + self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim) + nn.init.xavier_uniform_(self.x_embedding1.weight.data) + nn.init.xavier_uniform_(self.x_embedding2.weight.data) + + # List of MLPs + self.gnns = nn.ModuleList() + for layer in range(num_layer): + self.gnns.append(GINEConv(emb_dim)) + + # List of batchnorms + self.batch_norms = nn.ModuleList() + for layer in range(num_layer): + self.batch_norms.append(nn.BatchNorm1d(emb_dim)) + + if pool == 'mean': + self.pool = global_mean_pool + elif pool == 'max': + self.pool = global_max_pool + elif pool == 'add': + self.pool = global_add_pool + + self.feat_lin = nn.Linear(self.concat_dim, self.feat_dim) + + self.out_lin = nn.Sequential( + nn.Linear(self.feat_dim, self.feat_dim), + nn.BatchNorm1d(self.feat_dim), + nn.ReLU(inplace=True), + + nn.Linear(self.feat_dim, self.feat_dim), # Is not reduced to half size! 
+ nn.BatchNorm1d(self.feat_dim), + nn.ReLU(inplace=True), + + nn.Linear(self.feat_dim, self.feat_dim) + ) + def forward(self, data): + x = data.x + edge_index = data.edge_index + edge_attr = data.edge_attr + + h_init = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1]) + + # Perform the convolutions + h_dict = {} + + for layer in range(self.num_layer): + if layer == self.num_layer - 1: + tmp_h = self.gnns[layer](h_dict[f"h_{layer - 1}"], edge_index, edge_attr) + tmp_h = self.batch_norms[layer](tmp_h) + h_dict[f"h_{layer}"] = F.dropout(tmp_h, self.drop_ratio, training=self.training) + + else: + if layer == 0: + tmp_h = self.gnns[layer](h_init, edge_index, edge_attr) + tmp_h = self.batch_norms[layer](tmp_h) + h_dict[f"h_{layer}"] = F.dropout(F.relu(tmp_h), self.drop_ratio, training=self.training) + else: + tmp_h = self.gnns[layer](h_dict[f"h_{layer - 1}"], edge_index, edge_attr) + tmp_h = self.batch_norms[layer](tmp_h) + h_dict[f"h_{layer}"] = F.dropout(F.relu(tmp_h), self.drop_ratio, training=self.training) + + # Graph representation + h_list_pooled = [self.pool(h_dict[f"h_{layer}"], data.batch) for layer in range(self.num_layer)] + h_global_embedding = torch.cat(h_list_pooled, dim=1) + + assert h_global_embedding.shape[1] == self.concat_dim + + # Projection + h_expansion = self.feat_lin(h_global_embedding) + out = self.out_lin(h_expansion) + + return h_global_embedding, out + + def load_my_state_dict(self, state_dict): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + continue + if isinstance(param, nn.parameter.Parameter): + # backwards compatibility for serialized parameters + param = param.data + print(name) + own_state[name].copy_(param) + diff --git a/models/mole.yaml b/models/mole.yaml new file mode 100644 index 0000000..fc804f2 --- /dev/null +++ b/models/mole.yaml @@ -0,0 +1,26 @@ +name: mole +channels: + - pytorch + - nvidia + - rmg + - conda-forge + - rdkit + - defaults +dependencies: + - python=3.8 + 
- pytorch=2.2.1 + - pytorch-cuda=11.8 + - rdkit=2022.3.3 + - pip + - pip: + - xgboost==1.6.2 + - pandas==2.0.3 + - PyYAML==6.0.1 + - torch_geometric==2.5.0 + - openpyxl + - pubchempy==1.0.4 + - matplotlib==3.7.5 + - seaborn==0.13.2 + - tqdm + - scikit-learn==1.0.2 + - umap-learn==0.5.5 \ No newline at end of file diff --git a/models/mole_representation.py b/models/mole_representation.py new file mode 100644 index 0000000..254ee91 --- /dev/null +++ b/models/mole_representation.py @@ -0,0 +1,128 @@ +""" +MolE Representation Module + +This module provides functions to generate MolE molecular representations. +""" + +import os +import yaml +import torch +import pandas as pd +from rdkit import Chem +from rdkit import RDLogger + +from .dataset_representation import batch_representation +from .ginet_concat import GINet + +RDLogger.DisableLog('rdApp.*') + + +def read_smiles(data_path, smile_col="smiles", id_col="chem_id"): + """ + Read SMILES data from a file or DataFrame and remove invalid SMILES. + + Parameters: + - data_path (str or pd.DataFrame): Path to the file or a DataFrame containing SMILES data. + - smile_col (str, optional): Name of the column containing SMILES strings. + - id_col (str, optional): Name of the column containing molecule IDs. + + Returns: + - smile_df (pandas.DataFrame): DataFrame containing SMILES data with specified columns. + """ + + # Read the data + if isinstance(data_path, pd.DataFrame): + smile_df = data_path.copy() + else: + # Try to read with different separators + try: + smile_df = pd.read_csv(data_path, sep='\t') + except: + smile_df = pd.read_csv(data_path) + + # Check if columns exist, handle case-insensitive matching + columns_lower = {col.lower(): col for col in smile_df.columns} + + smile_col_actual = columns_lower.get(smile_col.lower(), smile_col) + id_col_actual = columns_lower.get(id_col.lower(), id_col) + + if smile_col_actual not in smile_df.columns: + raise ValueError(f"Column '{smile_col}' not found in data. 
Available columns: {list(smile_df.columns)}") + + # Select columns + if id_col_actual in smile_df.columns: + smile_df = smile_df[[smile_col_actual, id_col_actual]] + smile_df.columns = [smile_col, id_col] + else: + # Create ID column if not exists + smile_df = smile_df[[smile_col_actual]] + smile_df.columns = [smile_col] + smile_df[id_col] = [f"mol{i+1}" for i in range(len(smile_df))] + + # Make sure ID column is interpreted as str + smile_df[id_col] = smile_df[id_col].astype(str) + + # Remove NaN + smile_df = smile_df.dropna() + + # Remove invalid smiles + smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)] + + return smile_df + + +def load_pretrained_model(pretrained_model_dir, device="cuda:0"): + """ + Load a pre-trained MolE model. + + Parameters: + - pretrained_model_dir (str): Path to the pre-trained MolE model directory. + - device (str, optional): Device for computation (default is "cuda:0"). + + Returns: + - model: Loaded pre-trained model. + """ + + # Read model configuration + config = yaml.load(open(os.path.join(pretrained_model_dir, "config.yaml"), "r"), Loader=yaml.FullLoader) + model_config = config["model"] + + # Instantiate model + model = GINet(**model_config).to(device) + + # Load pre-trained weights + model_pth_path = os.path.join(pretrained_model_dir, "model.pth") + print(f"Loading model from: {model_pth_path}") + + state_dict = torch.load(model_pth_path, map_location=device) + model.load_my_state_dict(state_dict) + + return model + + +def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device): + """ + Process the dataset to generate molecular representations. + + Parameters: + - dataset_path (str or pd.DataFrame): Path to the dataset file or DataFrame. + - pretrained_dir (str): Path to the pre-trained model directory. + - smile_column_str (str): Name of the column containing SMILES strings. + - id_column_str (str): Name of the column containing molecule IDs. 
+ - device (str): Device to use for computation. Can be "cpu", "cuda:0", etc. + + Returns: + - udl_representation (pandas.DataFrame): DataFrame containing molecular representations. + """ + + # First we read the SMILES dataframe + smiles_df = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str) + + # Load the pre-trained model + pmodel = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device) + + # Gather pre-trained representation + udl_representation = batch_representation(smiles_df, pmodel, smile_column_str, id_column_str, device=device) + + return udl_representation + diff --git a/utils/mole_predictor.py b/utils/mole_predictor.py new file mode 100644 index 0000000..a83cc48 --- /dev/null +++ b/utils/mole_predictor.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +MolE 抗菌活性预测工具 + +这个脚本提供了使用 MolE 模型预测小分子 SMILES 抗菌活性的功能。 +支持命令行和 Python API 调用两种方式。 + +命令行示例: + python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id + +Python API 示例: + from utils.mole_predictor import predict_csv_file + + predict_csv_file( + input_path="input.csv", + output_path="output.csv", + smiles_column="smiles", + id_column="chem_id" + ) +""" + +import sys +import os +from pathlib import Path + +# 添加项目根目录到 Python 路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +import click +import pandas as pd +from typing import Optional, List +from datetime import datetime + +from models.broad_spectrum_predictor import ( + ParallelBroadSpectrumPredictor, + PredictionConfig, + MoleculeInput, + BroadSpectrumResult +) + + +def predict_csv_file( + input_path: str, + output_path: Optional[str] = None, + smiles_column: str = "smiles", + id_column: str = "chem_id", + batch_size: int = 100, + n_workers: Optional[int] = None, + device: str = "auto", + add_suffix: bool = True +) -> pd.DataFrame: + """ + 预测 CSV 文件中的分子抗菌活性 + + Args: + input_path: 输入 CSV 文件路径 + output_path: 输出 CSV 
文件路径,如果为 None 则自动生成 + smiles_column: SMILES 列名 + id_column: 化合物 ID 列名 + batch_size: 批处理大小 + n_workers: 工作进程数 + device: 计算设备 ("auto", "cpu", "cuda:0" 等) + add_suffix: 是否在输出文件名后添加预测后缀 + + Returns: + 包含预测结果的 DataFrame + """ + + print(f"开始处理文件: {input_path}") + + # 读取输入文件 + input_path_obj = Path(input_path) + if not input_path_obj.exists(): + raise FileNotFoundError(f"输入文件不存在: {input_path}") + + # 读取 CSV + try: + df_input = pd.read_csv(input_path) + except Exception as e: + raise RuntimeError(f"读取 CSV 文件失败: {e}") + + print(f"读取了 {len(df_input)} 条数据") + + # 检查列是否存在(大小写不敏感) + columns_lower = {col.lower(): col for col in df_input.columns} + + smiles_col_actual = columns_lower.get(smiles_column.lower()) + if smiles_col_actual is None: + raise ValueError( + f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}" + ) + + # 处理 ID 列 + id_col_actual = columns_lower.get(id_column.lower()) + if id_col_actual is None: + print(f"未找到 ID 列 '{id_column}',将自动生成 ID") + df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))] + id_col_actual = id_column + + # 创建预测器配置 + config = PredictionConfig( + batch_size=batch_size, + n_workers=n_workers, + device=device + ) + + # 初始化预测器 + print("初始化预测器...") + predictor = ParallelBroadSpectrumPredictor(config) + + # 准备分子输入 + molecules = [ + MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual])) + for _, row in df_input.iterrows() + ] + + # 执行预测 + print("开始预测...") + results = predictor.predict_batch(molecules) + + # 转换结果为 DataFrame + results_dicts = [r.to_dict() for r in results] + df_results = pd.DataFrame(results_dicts) + + # 合并原始数据和预测结果 + # 使用 chem_id 作为键进行合并 + df_input['_merge_id'] = df_input[id_col_actual].astype(str) + df_results['_merge_id'] = df_results['chem_id'].astype(str) + + df_output = df_input.merge( + df_results.drop(columns=['chem_id']), + on='_merge_id', + how='left' + ) + df_output = df_output.drop(columns=['_merge_id']) + + # 生成输出路径 + if output_path is None: + if add_suffix: + output_path = 
str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}") + else: + output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}") + elif add_suffix: + output_path_obj = Path(output_path) + output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}") + + # 保存结果 + print(f"保存结果到: {output_path}") + df_output.to_csv(output_path, index=False) + + print(f"完成! 预测了 {len(results)} 个分子") + print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌") + + return df_output + + +def predict_multiple_files( + input_paths: List[str], + output_dir: Optional[str] = None, + smiles_column: str = "smiles", + id_column: str = "chem_id", + batch_size: int = 100, + n_workers: Optional[int] = None, + device: str = "auto", + add_suffix: bool = True +) -> List[pd.DataFrame]: + """ + 批量预测多个 CSV 文件 + + Args: + input_paths: 输入 CSV 文件路径列表 + output_dir: 输出目录,如果为 None 则在原文件目录生成 + smiles_column: SMILES 列名 + id_column: 化合物 ID 列名 + batch_size: 批处理大小 + n_workers: 工作进程数 + device: 计算设备 + add_suffix: 是否在输出文件名后添加预测后缀 + + Returns: + 包含预测结果的 DataFrame 列表 + """ + + results = [] + + for input_path in input_paths: + input_path_obj = Path(input_path) + + # 确定输出路径 + if output_dir is not None: + output_dir_obj = Path(output_dir) + output_dir_obj.mkdir(parents=True, exist_ok=True) + + if add_suffix: + output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}") + else: + output_path = str(output_dir_obj / input_path_obj.name) + else: + output_path = None + + # 预测单个文件 + try: + df_result = predict_csv_file( + input_path=input_path, + output_path=output_path, + smiles_column=smiles_column, + id_column=id_column, + batch_size=batch_size, + n_workers=n_workers, + device=device, + add_suffix=add_suffix + ) + results.append(df_result) + except Exception as e: + print(f"处理文件 {input_path} 时出错: {e}") + continue + + return results + + +# 
# ============================================================================
# Command-line interface
# ============================================================================

@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
              help='SMILES 列名 (默认: smiles)')
@click.option('--id-column', '-i', default='chem_id',
              help='化合物 ID 列名 (默认: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
              help='批处理大小 (默认: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
              help='工作进程数 (默认: CPU 核心数)')
# Free-form device string: the previous click.Choice(['auto','cpu','cuda:0','cuda:1'])
# rejected valid devices such as "cuda:2" or "mps" that predict_csv_file accepts.
@click.option('--device', '-d', default='auto',
              help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
              help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
    """
    使用 MolE 模型预测小分子 SMILES 的抗菌活性

    INPUT_PATH: 输入 CSV 文件路径

    OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)

    示例:

        python mole_predictor.py input.csv output.csv

        python mole_predictor.py input.csv -s SMILES -i ID

        python mole_predictor.py input.csv --device cuda:0 --batch-size 200
    """

    # Broad catch is deliberate here: this is the process boundary, so any
    # failure is reported on stderr and mapped to a non-zero exit code.
    try:
        predict_csv_file(
            input_path=input_path,
            output_path=output_path,
            smiles_column=smiles_column,
            id_column=id_column,
            batch_size=batch_size,
            n_workers=n_workers,
            device=device,
            add_suffix=add_suffix
        )
    except Exception as e:
        click.echo(f"错误: {e}", err=True)
        sys.exit(1)


if __name__ == '__main__':
    cli()