first add

This commit is contained in:
mm644706215
2025-10-16 17:21:48 +08:00
commit a56e60e9a3
192 changed files with 32720 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
*.pth
__pycache__/
*.pyc

139
README.md Normal file
View File

@@ -0,0 +1,139 @@
## MolE 广谱抗菌预测 API
测试案例: example_usage.py
## 功能特性
1. **高性能并行处理** - 支持多进程并行计算,显著提高大批量分子预测速度
2. **多种使用方式** - 提供Python API、命令行工具和Web服务三种使用方式
3. **模块化设计** - 易于集成到其他项目中
4. **灵活配置** - 支持自定义模型路径、阈值等参数
## 安装
```bash
pip install -e .
```
## 使用方式
### 1. Python API
```python
from broad_spectrum_parallel import predict_smiles, MoleculeInput
# 预测单个或多个SMILES
results = predict_smiles(["CCO", "CCN"], ["ethanol", "ethylamine"])
for result in results:
print(f"{result.chem_id}: 广谱={result.broad_spectrum}, 抑制数={result.ginhib_total}")
```
### 2. 命令行工具
```bash
# 基本用法
predict_antimicrobial input.tsv output.tsv --smiles_input --smiles_colname smiles --chemid_colname chem_id
# 聚合预测结果
predict_antimicrobial input.tsv output.tsv --smiles_input --aggregate_scores
```
### 3. Web API服务
```bash
# 启动服务
uvicorn broad_spectrum_parallel.api:app --host 0.0.0.0 --port 8000
```
然后可以通过POST请求访问`http://localhost:8000/predict`端点:
```bash
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{"smiles": ["CCO", "CCN"]}'
```
## 结果解读
从运行结果可以看到每个化合物返回8个关键指标
1. 抗菌潜力分数(对数尺度):
apscore_total: -11.758 - 总体抗菌分数
apscore_gnegative: -11.648 - 革兰阴性菌抗菌分数
apscore_gpositive: -11.848 - 革兰阳性菌抗菌分数
2. 抑制菌株统计:
ginhib_total: 0 - 总抑制菌株数
ginhib_gnegative: 0 - 抑制的革兰阴性菌株数
ginhib_gpositive: 0 - 抑制的革兰阳性菌株数
3. 广谱判定:
broad_spectrum: 0 - 是否广谱抗菌(需抑制≥10个菌株)
🧪 结果解释示例
以乙醇(CCO)为例:
抗菌分数很低 (-11.758):表明预测的抗菌活性很弱
无菌株抑制 (0):在设定阈值下不能有效抑制任何测试菌株
非广谱抗菌 (0):不满足广谱抗菌的最低标准
这个结果符合预期,因为乙醇虽有杀菌作用,但在药物发现的标准下不被认为是有效的抗菌候选化合物。
## 可以运行的菌株信息
```shell
(mole) root@DESK4090:/srv/project/mole_antimicrobial_potential/broad_spectrum_parallel# micromamba run -n mole python -c "
> import pandas as pd
> import numpy as np
>
> # 加载菌株筛选数据
> maier_screen = pd.read_csv('data/01.prepare_training_data/maier_screening_results.tsv.gz', sep='\t', index_col=0)
> print(f'总菌株数量: {len(maier_screen.columns)}')
> print(f'总化合物数量: {len(maier_screen.index)}')
> print(f'菌株列表前10个:')
> for i, strain in enumerate(maier_screen.columns[:10]):
> print(f'{i+1}. {strain}')
>
> # 加载革兰染色信息
> gram_info = pd.read_excel('raw_data/maier_microbiome/strain_info_SF2.xlsx',
> skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
> index_col='NT data base')
> print(f'\n革兰染色信息:')
> print(gram_info['Gram stain'].value_counts())
> "
总菌株数量: 40
总化合物数量: 1197
菌株列表前10个:
1. Akkermansia muciniphila (NT5021)
2. Bacteroides caccae (NT5050)
3. Bacteroides fragilis (ET) (NT5033)
4. Bacteroides fragilis (NT) (NT5003)
5. Bacteroides ovatus (NT5054)
6. Bacteroides thetaiotaomicron (NT5004)
7. Bacteroides uniformis (NT5002)
8. Bacteroides vulgatus (NT5001)
9. Bacteroides xylanisolvens (NT5064)
10. Bifidobacterium adolescentis (NT5022)
/root/micromamba/envs/mole/lib/python3.10/site-packages/openpyxl/worksheet/_reader.py:329: UserWarning: Unknown extension is not supported and will be removed
warn(msg)
革兰染色信息:
Gram stain
positive 22
negative 18
Name: count, dtype: int64
```
## 权重下载
mole
https://www.alipan.com/s/DNuDo8iEn89
提取码: mh90
下载完成后放到 pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
## 原始论文与github仓库
https://www.nature.com/articles/s41467-025-58804-4
https://github.com/rolayoalarcon/mole_antimicrobial_potential

30
__init__.py Normal file
View File

@@ -0,0 +1,30 @@
"""
并行广谱抗菌预测模块
提供高性能的分子广谱抗菌活性预测功能。
"""
from .broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
BroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult,
create_predictor,
predict_smiles,
predict_file
)
__version__ = "1.0.0"
__author__ = "Your Name"
__all__ = [
"ParallelBroadSpectrumPredictor",
"BroadSpectrumPredictor",
"PredictionConfig",
"MoleculeInput",
"BroadSpectrumResult",
"create_predictor",
"predict_smiles",
"predict_file"
]

131
api.py Normal file
View File

@@ -0,0 +1,131 @@
"""
FastAPI服务基于并行广谱抗菌预测API
"""
from typing import List, Union, Optional
from fastapi import FastAPI
from pydantic import BaseModel, Field
import pandas as pd
from .broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult
)
# Request/response data models
class MoleculeInfo(BaseModel):
    """A single molecule entry in the request payload."""
    # SMILES string of the molecule.
    smiles: str
    # Optional compound identifier; auto-generated ("mol1", ...) when omitted.
    chem_id: Optional[str] = None
class MoleculeInputRequest(BaseModel):
    """Prediction request: either explicit molecule records or raw SMILES."""
    # Explicit molecule records; takes precedence over `smiles` when present.
    molecules: Optional[List[MoleculeInfo]] = None
    # One SMILES string, or a list of them.
    smiles: Optional[Union[str, List[str]]] = None
    # Optional compound id(s), parallel to `smiles`.
    chem_id: Optional[Union[str, List[str]]] = None
    aggregate_scores: bool = Field(False, description="Whether to aggregate predictions")
    app_threshold: float = Field(0.04374140128493309, description="Threshold for growth inhibition")
    min_nkill: int = Field(10, description="Minimum strains for broad spectrum")
    batch_size: int = Field(100, description="Batch size for processing")
    n_workers: Optional[int] = Field(None, description="Number of worker processes")
class PredictionResponse(BaseModel):
    """Response payload: one aggregated result per submitted molecule."""
    results: List[BroadSpectrumResult]
# Create the FastAPI application
app = FastAPI(title="Antimicrobial Prediction API",
              description="API for predicting antimicrobial activity of compounds using MolE and XGBoost",
              version="1.0.0")
# Shared predictor instance; populated lazily at startup (full initialization
# loads model and metadata files, so it is kept out of import time).
predictor = None
@app.on_event("startup")
async def startup_event():
    """Initialize the shared predictor when the application starts."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
    # favor of lifespan handlers — consider migrating.
    global predictor
    config = PredictionConfig()
    predictor = ParallelBroadSpectrumPredictor(config)
@app.post("/predict", response_model=PredictionResponse)
async def predict_antimicrobial(input_data: MoleculeInputRequest):
    """
    Predict antimicrobial activity for the submitted compounds.

    Args:
        input_data: Request payload with molecule definitions and options.

    Returns:
        PredictionResponse wrapping one result per molecule.
    """
    global predictor
    if not predictor:
        # Create a predictor on demand if startup initialization did not run.
        config = PredictionConfig(
            batch_size=input_data.batch_size,
            n_workers=input_data.n_workers
        )
        predictor = ParallelBroadSpectrumPredictor(config)
    # Normalize the payload into MoleculeInput objects.
    if input_data.molecules:
        # Explicit molecule records take precedence.
        molecules = [
            MoleculeInput(smiles=m.smiles, chem_id=m.chem_id or f"mol{i+1}")
            for i, m in enumerate(input_data.molecules)
        ]
    elif input_data.smiles:
        # Build the molecule list from bare SMILES strings.
        if isinstance(input_data.smiles, str):
            smiles_list = [input_data.smiles]
        else:
            smiles_list = input_data.smiles
        if input_data.chem_id:
            if isinstance(input_data.chem_id, str):
                chem_ids = [input_data.chem_id]
            else:
                chem_ids = input_data.chem_id
        else:
            # Auto-generate ids mol1, mol2, ...
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
    else:
        # NOTE(review): surfaces to clients as HTTP 500; an HTTPException(400)
        # would be the conventional FastAPI response for bad input.
        raise ValueError("Either 'molecules' or 'smiles' must be provided")
    # Run the prediction.
    try:
        results = predictor.predict_batch(molecules)
        return PredictionResponse(results=results)
    except Exception as e:
        raise RuntimeError(f"Prediction failed: {str(e)}")
@app.get("/health")
async def health_check():
    """
    Health-check endpoint for liveness probes.
    """
    return {"status": "healthy"}
@app.get("/")
async def root():
    """
    Root path: basic API metadata.
    """
    return {
        "message": "Antimicrobial Prediction API",
        "version": "1.0.0",
        "description": "API for predicting antimicrobial activity of compounds"
    }

524
broad_spectrum_api.py Normal file
View File

@@ -0,0 +1,524 @@
"""
并行广谱抗菌预测API模块
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
基于MolE分子表示和XGBoost模型进行预测。
"""
import os
import re
import pickle
import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
from scipy.stats.mstats import gmean
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
try:
from mole_representation import process_representation
except ImportError:
print("Warning: mole_representation module not found. Please ensure it's in your Python path.")
@dataclass
class PredictionConfig:
    """Configuration parameters for broad-spectrum antimicrobial prediction."""
    # Path to the pickled XGBoost classifier used for per-strain predictions.
    xgboost_model_path: str = "data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl"
    # Directory holding the pre-trained MolE model (config.yaml + model.pth).
    mole_model_path: str = "pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
    # Maier et al. screening results (gzipped TSV) providing the strain list.
    strain_categories_path: str = "data/01.prepare_training_data/maier_screening_results.tsv.gz"
    # Excel sheet with strain metadata, including Gram stain labels.
    gram_info_path: str = "raw_data/maier_microbiome/strain_info_SF2.xlsx"
    # Probability threshold used to binarize growth inhibition
    # (default taken from the original publication).
    app_threshold: float = 0.04374140128493309
    # Minimum number of inhibited strains to call a compound broad spectrum.
    min_nkill: int = 10
    # Number of compound-strain rows per parallel prediction batch.
    batch_size: int = 100
    # Worker process count; None means use all available CPU cores.
    n_workers: Optional[int] = None
    # Torch device: "cpu", "cuda", or "auto" (picks cuda:0 when available).
    device: str = "auto"
@dataclass
class MoleculeInput:
    """Input record for a single molecule."""
    # SMILES string describing the molecular structure.
    smiles: str
    # Optional compound identifier; auto-generated ("mol1", ...) when None.
    chem_id: Optional[str] = None
@dataclass
class BroadSpectrumResult:
    """Aggregated broad-spectrum antimicrobial prediction for one compound."""
    # Compound identifier.
    chem_id: str
    # Log geometric-mean inhibition probability, overall and per Gram group.
    apscore_total: float
    apscore_gnegative: float
    apscore_gpositive: float
    # Counts of strains predicted inhibited, overall and per Gram group.
    ginhib_total: int
    ginhib_gnegative: int
    ginhib_gpositive: int
    # 1 when the compound meets the broad-spectrum criterion, else 0.
    broad_spectrum: int

    def to_dict(self) -> Dict[str, Union[str, float, int]]:
        """Return the result as a plain dict, preserving field order."""
        field_names = (
            'chem_id',
            'apscore_total',
            'apscore_gnegative',
            'apscore_gpositive',
            'ginhib_total',
            'ginhib_gnegative',
            'ginhib_gpositive',
            'broad_spectrum',
        )
        return {name: getattr(self, name) for name in field_names}
class BroadSpectrumPredictor:
    """
    Broad-spectrum antimicrobial activity predictor.

    Combines MolE molecular representations with an XGBoost classifier to
    score compound-strain pairs for growth inhibition, and aggregates those
    scores into antimicrobial potential metrics.
    """

    # Strain names look like "Genus species (NT1234)"; this captures the NT id.
    _NT_PATTERN = re.compile(r".*?\((NT\d+)\)")

    def __init__(self, config: Optional[PredictionConfig] = None) -> None:
        """
        Initialize the predictor.

        Args:
            config: Prediction configuration; defaults are used when None.

        Raises:
            FileNotFoundError: If a required model/metadata file is missing.
            RuntimeError: If the shared metadata cannot be loaded.
        """
        self.config = config or PredictionConfig()
        self.n_workers = self.config.n_workers or mp.cpu_count()
        # Fail fast on missing files before any heavy loading.
        self._validate_paths()
        # Load strain metadata shared across all predictions.
        self._load_shared_data()

    def _validate_paths(self) -> None:
        """Raise FileNotFoundError if any required input file is missing."""
        required_files = [
            self.config.xgboost_model_path,
            self.config.strain_categories_path,
            self.config.gram_info_path
        ]
        for file_path in required_files:
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Required file not found: {file_path}")

    def _load_shared_data(self) -> None:
        """Load shared data: strain screening matrix and Gram stain metadata."""
        try:
            # Maier et al. screening matrix (compounds x strains).
            self.maier_screen: pd.DataFrame = pd.read_csv(
                self.config.strain_categories_path, sep='\t', index_col=0
            )
            # One-hot encoding of strain identity, cross-joined per compound.
            self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
            # Strain metadata (Gram stain labels), indexed by NT number.
            self.maier_strains: pd.DataFrame = pd.read_excel(
                self.config.gram_info_path,
                skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
                index_col="NT data base"
            )
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"Failed to load shared data: {str(e)}") from e

    def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
        """
        Build a dense one-hot encoding of the strain categories.

        Args:
            categories: Index of strain names.

        Returns:
            Square DataFrame (strain x strain) of indicator columns.
        """
        # scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed
        # the old keyword in 1.4; support both so either version works.
        try:
            ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            ohe = OneHotEncoder(sparse=False)
        ohe.fit(pd.DataFrame(categories))
        cat_ohe = pd.DataFrame(
            ohe.transform(pd.DataFrame(categories)),
            columns=categories,
            index=categories
        )
        return cat_ohe

    def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame:
        """
        Compute MolE representations for the given molecules.

        Args:
            molecules: Input molecules.

        Returns:
            DataFrame of MolE features indexed by chem_id.
        """
        # Build the (smiles, chem_id) table, generating ids when missing.
        df_data = []
        for i, mol in enumerate(molecules):
            chem_id = mol.chem_id or f"mol{i+1}"
            df_data.append({"smiles": mol.smiles, "chem_id": chem_id})
        df = pd.DataFrame(df_data)
        # Resolve "auto" to cuda:0 when a GPU is available.
        device = self.config.device
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Featurize with the pre-trained MolE model.
        return process_representation(
            dataset_path=df,
            smile_column_str="smiles",
            id_column_str="chem_id",
            pretrained_dir=self.config.mole_model_path,
            device=device
        )

    def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
        """
        Cross-join chemical features with the strain one-hot encoding.

        Args:
            chemfeats_df: Chemical features indexed by chem_id.

        Returns:
            One row per compound-strain pair, indexed by "<chem_id>:<strain>".
        """
        # Chemical features with chem_id as a string column.
        chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
        chemfe["chem_id"] = chemfe["chem_id"].astype(str)
        # Strain one-hot columns.
        sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})
        # Cartesian product: every compound paired with every strain.
        xpred = chemfe.merge(sohe, how="cross")
        xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
        xpred = xpred.set_index("pred_id")
        xpred = xpred.drop(columns=["chem_id", "strain_name"])
        return xpred

    def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
        """
        Annotate rows with the Gram stain of their strain.

        Args:
            label_df: DataFrame with a "strain_name" column.

        Returns:
            Copy of the input with "nt_number" and "gram_stain" columns added.
        """
        df_label = label_df.copy()

        def extract_nt(name: str) -> Optional[str]:
            # Run the (precompiled) regex once per row; the original searched
            # twice per row.
            match = self._NT_PATTERN.search(name)
            return match.group(1) if match else None

        df_label["nt_number"] = df_label["strain_name"].apply(extract_nt)
        # Map NT number -> Gram stain label; unknown NT numbers become NaN.
        gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
        df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)
        return df_label

    def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregate per-pair probabilities into antimicrobial potential metrics.

        Args:
            score_df: Prediction scores with a "pred_id" column formatted as
                "<chem_id>:<strain_name>", a "1" probability column, and a
                binary "growth_inhibition" column.

        Returns:
            One row per compound with log geometric-mean scores and inhibited
            strain counts, overall and split by Gram stain.
        """
        # Work on a copy so the caller's frame is not mutated, and split the
        # composite id only once.
        score_df = score_df.copy()
        split_ids = score_df["pred_id"].str.split(":", expand=True)
        score_df["chem_id"] = split_ids[0]
        score_df["strain_name"] = split_ids[1]
        # Attach Gram stain labels.
        pred_df = self._gram_stain(score_df)
        # Antimicrobial potential score: log of the geometric mean probability.
        apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
            columns={"1": "apscore_total"}
        )
        apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])
        # Same score split by Gram stain.
        apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
            columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
        )
        apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
        apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])
        # Counts of strains predicted to be inhibited.
        inhibited_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
            columns={"growth_inhibition": "ginhib_total"}
        )
        # Same counts split by Gram stain.
        inhibited_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
            columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
        )
        # Join all metrics; compounds missing a Gram group get 0.
        agg_pred = apscore_total.join(apscore_gram).join(inhibited_total).join(inhibited_gram)
        agg_pred = agg_pred.fillna(0)
        return agg_pred
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
                          model_path: str,
                          app_threshold: float) -> Tuple[int, pd.DataFrame]:
    """
    Worker function executed in a child process for one prediction batch.

    Args:
        batch_data: Tuple of (feature DataFrame, batch id).
        model_path: Path to the pickled XGBoost model.
        app_threshold: Probability threshold for growth inhibition.

    Returns:
        Tuple of (batch id, prediction DataFrame with probability columns
        "0"/"1" and a binary "growth_inhibition" column).
    """
    features, batch_id = batch_data
    # Each worker loads its own copy of the model; pickled models are not
    # shared across processes.
    with open(model_path, "rb") as model_file:
        classifier = pickle.load(model_file)
    # Class probabilities for every compound-strain row in the batch.
    probabilities = classifier.predict_proba(features)
    predictions = pd.DataFrame(probabilities, columns=["0", "1"], index=features.index)
    # Binarize the positive-class probability at the configured threshold.
    predictions["growth_inhibition"] = predictions["1"].apply(
        lambda p: 1 if p >= app_threshold else 0
    )
    return batch_id, predictions
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
    """
    Parallel broad-spectrum antimicrobial predictor.

    Extends BroadSpectrumPredictor with multi-process batch prediction,
    intended for large-scale molecule screening.
    """
    def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
        """
        Predict broad-spectrum antimicrobial activity for one molecule.

        Args:
            molecule: Input molecule.

        Returns:
            Broad-spectrum prediction result.
        """
        results = self.predict_batch([molecule])
        return results[0]

    def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
        """
        Predict broad-spectrum antimicrobial activity for a batch of molecules.

        Args:
            molecules: Input molecules.

        Returns:
            One BroadSpectrumResult per input molecule (empty list for empty
            input).
        """
        if not molecules:
            return []
        # Compute MolE representations for all molecules.
        print(f"Processing {len(molecules)} molecules...")
        mole_representation = self._get_mole_representation(molecules)
        # Cross-join with the strain one-hot encoding.
        print("Preparing strain-level features...")
        X_input = self._add_strains(mole_representation)
        # Split the compound-strain rows into batches of config.batch_size.
        print(f"Starting parallel prediction with {self.n_workers} workers...")
        batches = []
        for i in range(0, len(X_input), self.config.batch_size):
            batch = X_input.iloc[i:i+self.config.batch_size]
            batches.append((batch, i // self.config.batch_size))
        # Score each batch in a separate worker process.
        results = {}
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = {
                executor.submit(_predict_batch_worker, (batch_data, batch_id),
                                self.config.xgboost_model_path,
                                self.config.app_threshold): batch_id
                for batch_data, batch_id in batches
            }
            for future in as_completed(futures):
                batch_id, pred_df = future.result()
                results[batch_id] = pred_df
                print(f"Batch {batch_id} completed")
        # Re-assemble batches in their original order.
        print("Merging prediction results...")
        all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
        # Aggregate per-strain probabilities into potential scores.
        print("Calculating antimicrobial potential scores...")
        all_pred_df = all_pred_df.reset_index()
        agg_df = self._antimicrobial_potential(all_pred_df)
        # Flag compounds inhibiting at least min_nkill strains.
        agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
            lambda x: 1 if x >= self.config.min_nkill else 0
        )
        # Convert aggregated rows into result objects.
        results_list = []
        for _, row in agg_df.iterrows():
            result = BroadSpectrumResult(
                chem_id=row.name,
                apscore_total=row["apscore_total"],
                apscore_gnegative=row["apscore_gnegative"],
                apscore_gpositive=row["apscore_gpositive"],
                ginhib_total=int(row["ginhib_total"]),
                ginhib_gnegative=int(row["ginhib_gnegative"]),
                ginhib_gpositive=int(row["ginhib_gpositive"]),
                broad_spectrum=int(row["broad_spectrum"])
            )
            results_list.append(result)
        return results_list

    def predict_from_smiles(self,
                            smiles_list: List[str],
                            chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
        """
        Predict broad-spectrum antimicrobial activity from SMILES strings.

        Args:
            smiles_list: SMILES strings to score.
            chem_ids: Compound ids; auto-generated ("mol1", ...) when None.

        Returns:
            List of prediction results.

        Raises:
            ValueError: If smiles_list and chem_ids differ in length.
        """
        if chem_ids is None:
            chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
        if len(smiles_list) != len(chem_ids):
            raise ValueError("smiles_list and chem_ids must have the same length")
        molecules = [
            MoleculeInput(smiles=smiles, chem_id=chem_id)
            for smiles, chem_id in zip(smiles_list, chem_ids)
        ]
        return self.predict_batch(molecules)

    def predict_from_file(self,
                          file_path: str,
                          smiles_column: str = "smiles",
                          id_column: str = "chem_id") -> List[BroadSpectrumResult]:
        """
        Predict broad-spectrum antimicrobial activity from a CSV/TSV file.

        Args:
            file_path: Input file path (read as TSV only when it ends in .tsv).
            smiles_column: Name of the SMILES column.
            id_column: Name of the compound-id column; generated when absent.

        Returns:
            List of prediction results.

        Raises:
            ValueError: If the SMILES column is missing from the file.
        """
        # Read the file; only a .tsv suffix switches to tab separation.
        if file_path.endswith('.tsv'):
            df = pd.read_csv(file_path, sep='\t')
        else:
            df = pd.read_csv(file_path)
        # Validate that the SMILES column exists.
        if smiles_column not in df.columns:
            raise ValueError(f"Column '{smiles_column}' not found in file")
        # Generate ids when the id column is absent.
        if id_column not in df.columns:
            df[id_column] = [f"mol{i+1}" for i in range(len(df))]
        # Build molecule inputs.
        molecules = [
            MoleculeInput(smiles=row[smiles_column], chem_id=row[id_column])
            for _, row in df.iterrows()
        ]
        return self.predict_batch(molecules)
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
    """
    Build a parallel broad-spectrum predictor.

    Args:
        config: Optional prediction configuration; defaults apply when None.

    Returns:
        A ready-to-use ParallelBroadSpectrumPredictor instance.
    """
    predictor = ParallelBroadSpectrumPredictor(config)
    return predictor
# Convenience functions
def predict_smiles(smiles_list: List[str],
                   chem_ids: Optional[List[str]] = None,
                   config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    Convenience function: predict broad-spectrum antimicrobial activity
    directly from a list of SMILES strings.

    Args:
        smiles_list: SMILES strings to score.
        chem_ids: Optional compound identifiers, parallel to smiles_list.
        config: Optional prediction configuration.

    Returns:
        List of prediction results, one per input SMILES.
    """
    return create_predictor(config).predict_from_smiles(smiles_list, chem_ids)
def predict_file(file_path: str,
                 smiles_column: str = "smiles",
                 id_column: str = "chem_id",
                 config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
    """
    Convenience function: predict broad-spectrum antimicrobial activity for
    every molecule listed in a CSV/TSV file.

    Args:
        file_path: Path to the input file.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound-ID column.
        config: Optional prediction configuration.

    Returns:
        List of prediction results.
    """
    return create_predictor(config).predict_from_file(file_path, smiles_column, id_column)

199
cli.py Normal file
View File

@@ -0,0 +1,199 @@
"""
命令行接口基于并行广谱抗菌预测API
"""
import argparse
import pandas as pd
from typing import List
from .broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def parse_arguments():
    """
    Parse command line arguments for the antimicrobial prediction CLI.

    Returns:
        argparse.Namespace with all CLI options.
    """
    # Instantiate the parser.
    parser = argparse.ArgumentParser(
        prog="Prediction of antimicrobial activity.",
        description="This program receives a collection of molecules as input. "
                    "If it receives SMILES, it first featurizes the molecules using MolE, "
                    "then makes predictions of antimicrobial activity. "
                    "By default, the program returns the antimicrobial predictive probabilities "
                    "for each compound-strain pair. "
                    "If the --aggregate_scores flag is set, then the program aggregates the predictions "
                    "into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
        usage="python cli.py input_filepath output_filepath [options]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Input file.
    parser.add_argument("input_filepath", help="Complete path to input file. Can be a file with SMILES "
                        "(make sure to set the --smiles_input flag) or a file with MolE representation.")
    # Output file.
    parser.add_argument("output_filepath", help="Complete path for output file")
    # Input-related argument group.
    inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")
    # Flag set when the input file contains SMILES.
    inputargs.add_argument("-s", "--smiles_input",
                           help="Flag variable. Indicates if the input_filepath contains SMILES "
                                "that have to be first represented using a MolE pre-trained model.",
                           action="store_true")
    # SMILES column name.
    inputargs.add_argument("-c", "--smiles_colname",
                           help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
                           default="smiles")
    # Compound ID column name.
    inputargs.add_argument("-i", "--chemid_colname",
                           help="Column name in smiles_filepath that contains the ID string of each chemical. "
                                "Only used if --smiles_input is set",
                           default="chem_id")
    # Model-related argument group.
    modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")
    # XGBoost model path.
    modelargs.add_argument("-x", "--xgboost_model",
                           help="Path to the pickled XGBoost model that makes predictions (.pkl).",
                           default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")
    # MolE model path.
    modelargs.add_argument("-m", "--mole_model",
                           help="Path to the directory containing the config.yaml and model.pth files "
                                "of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
                           default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")
    # Prediction-related argument group.
    predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")
    # Whether to aggregate per-pair predictions.
    predargs.add_argument("-a", "--aggregate_scores",
                          help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
                               "If set, then prediction scores of each compound is aggregated into the antimicrobial "
                               "potential score and the total number of strains predicted to be inhibited is reported. "
                               "Additionally, the broad spectrum antibiotic prediction is reported.",
                          action="store_true")
    # Growth-inhibition probability threshold.
    predargs.add_argument("-t", "--app_threshold",
                          help="Threshold score applied to the antimicrobial predictive probabilities "
                               "in order to binarize compound-microbe predictions of growth inhibition. "
                               "Default from original publication.",
                          default=0.04374140128493309, type=float)
    # Broad-spectrum strain-count threshold.
    predargs.add_argument("-k", "--min_nkill",
                          help="Minimum number of microbes predicted to be inhibited "
                               "in order to consider the compound a broad spectrum antibiotic.",
                          default=10, type=int)
    # Batch size.
    predargs.add_argument("--batch_size",
                          help="Batch size for processing molecules.",
                          default=100, type=int)
    # Worker process count.
    predargs.add_argument("--n_workers",
                          help="Number of worker processes for parallel processing. "
                               "If not set, the number of CPU cores will be used.",
                          default=None, type=int)
    # Metadata-related argument group.
    metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")
    # Maier strain screening results.
    metadataargs.add_argument("-b", "--strain_categories",
                              help="Path to the Maier et.al. screening results.",
                              default="data/01.prepare_training_data/maier_screening_results.tsv.gz")
    # Strain metadata (Gram stain information).
    metadataargs.add_argument("-g", "--gram_information",
                              help="Path to strain metadata.",
                              default="raw_data/maier_microbiome/strain_info_SF2.xlsx")
    # Device selection.
    parser.add_argument("-d", "--device",
                        help="Device where the pre-trained model is loaded. "
                             "Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
                             "then cuda:0 device is selected if a GPU is detected.",
                        default="auto")
    args = parser.parse_args()
    # Tell the user which kind of output will be produced.
    if args.aggregate_scores:
        print("Aggregating predictions of antimicrobial activity.")
    else:
        print("Returning predictions of antimicrobial activity for each compound-strain pair.")
    return args
def main():
    """
    CLI entry point: parse arguments, run predictions, write the output file.
    """
    # Parse command line arguments.
    args = parse_arguments()
    # Build the prediction configuration from the CLI arguments.
    config = PredictionConfig(
        xgboost_model_path=args.xgboost_model,
        mole_model_path=args.mole_model,
        strain_categories_path=args.strain_categories,
        gram_info_path=args.gram_information,
        app_threshold=args.app_threshold,
        min_nkill=args.min_nkill,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
        device=args.device
    )
    # Create the predictor.
    predictor = ParallelBroadSpectrumPredictor(config)
    # Dispatch on input type.
    if args.smiles_input:
        # Featurize SMILES with MolE, then predict.
        results = predictor.predict_from_file(
            args.input_filepath,
            smiles_column=args.smiles_colname,
            id_column=args.chemid_colname
        )
    else:
        # Pre-computed MolE representations are not supported by the new API.
        # TODO: implement predict_from_mole_representation or call the matching API.
        raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")
    # Write the output, aggregated or per compound-strain pair.
    if args.aggregate_scores:
        print("Aggregating Antimicrobial potential")
        # Aggregated mode: one row per compound with summary statistics.
        results_df = pd.DataFrame([r.to_dict() for r in results])
        results_df.set_index('chem_id', inplace=True)
        results_df.to_csv(args.output_filepath, sep='\t')
    else:
        # BUGFIX: the previous implementation iterated `result.predictions`,
        # a field BroadSpectrumResult does not have, and always crashed with
        # AttributeError. The current API only returns aggregated results, so
        # fail with a clear, actionable message instead.
        raise NotImplementedError(
            "Per compound-strain output is not supported by the current API; "
            "re-run with --aggregate_scores."
        )
if __name__ == "__main__":
    main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 390 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 403 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 364 KiB

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

File diff suppressed because it is too large Load Diff

Binary file not shown.

BIN
data/Source Data.zip Normal file

Binary file not shown.

131
detailed_analysis.py Normal file
View File

@@ -0,0 +1,131 @@
"""
详细分析广谱抗菌预测的计算过程
"""
import numpy as np
import pandas as pd
from scipy.stats.mstats import gmean
from broad_spectrum_api import ParallelBroadSpectrumPredictor, MoleculeInput
def analyze_prediction_process():
    """Walk through the prediction pipeline step by step for one molecule."""
    print("=== 广谱抗菌预测详细分析 ===\n")
    # Create the predictor.
    predictor = ParallelBroadSpectrumPredictor()
    # Test molecule: ethanol.
    molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
    print("1. 基本信息:")
    print(f" - 总菌株数量: {len(predictor.maier_screen.columns)}")
    print(f" - 革兰阳性菌: {(predictor.maier_strains['Gram stain'] == 'positive').sum()}")
    print(f" - 革兰阴性菌: {(predictor.maier_strains['Gram stain'] == 'negative').sum()}")
    print(f" - 抑制阈值: {predictor.config.app_threshold}")
    print(f" - 广谱标准: 抑制≥{predictor.config.min_nkill}个菌株")
    # Compute the MolE representation.
    print("\n2. 获取分子表示...")
    mole_representation = predictor._get_mole_representation([molecule])
    print(f" - MolE特征维度: {mole_representation.shape}")
    # Cross-join with strain information.
    print("\n3. 构建预测特征...")
    X_input = predictor._add_strains(mole_representation)
    print(f" - 预测样本数: {len(X_input)} (1个分子 × {len(predictor.maier_screen.columns)}个菌株)")
    print(f" - 特征维度: {X_input.shape[1]}")
    # Run the XGBoost model directly.
    print("\n4. 模型预测...")
    import pickle
    with open(predictor.config.xgboost_model_path, "rb") as file:
        model = pickle.load(file)
    y_pred = model.predict_proba(X_input)
    pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
    # Probability summary statistics.
    print(f" - 抑制概率范围: {pred_df['1'].min():.6f} - {pred_df['1'].max():.6f}")
    print(f" - 抑制概率均值: {pred_df['1'].mean():.6f}")
    print(f" - 抑制概率中位数: {pred_df['1'].median():.6f}")
    # Binarize predictions at the configured threshold.
    pred_df["growth_inhibition"] = pred_df["1"].apply(
        lambda x: 1 if x >= predictor.config.app_threshold else 0
    )
    inhibited_count = pred_df["growth_inhibition"].sum()
    print(f" - 超过阈值的菌株数: {inhibited_count}")
    # How the antimicrobial score is computed.
    print("\n5. 抗菌分数计算:")
    # Geometric mean of the probabilities, then its log.
    geometric_mean = gmean(pred_df["1"])
    log_geometric_mean = np.log(geometric_mean)
    print(f" - 所有概率的几何平均数: {geometric_mean:.10f}")
    print(f" - 几何平均数的对数: {log_geometric_mean:.6f}")
    # Probability distribution across the strains.
    print(f"\n6. 概率分布分析:")
    print(f" - 概率 < 0.001: {(pred_df['1'] < 0.001).sum()} 个菌株")
    print(f" - 概率 0.001-0.01: {((pred_df['1'] >= 0.001) & (pred_df['1'] < 0.01)).sum()} 个菌株")
    print(f" - 概率 0.01-0.1: {((pred_df['1'] >= 0.01) & (pred_df['1'] < 0.1)).sum()} 个菌株")
    print(f" - 概率 ≥ 0.1: {(pred_df['1'] >= 0.1).sum()} 个菌株")
    # Top and bottom predictions.
    print(f"\n7. 预测详情 (前5高和前5低):")
    sorted_pred = pred_df.sort_values("1", ascending=False)
    print(" 最高抑制概率:")
    for i, (idx, row) in enumerate(sorted_pred.head().iterrows()):
        strain_name = idx.split(":")[1]
        print(f" {i+1}. {strain_name}: {row['1']:.6f}")
    print(" 最低抑制概率:")
    for i, (idx, row) in enumerate(sorted_pred.tail().iterrows()):
        strain_name = idx.split(":")[1]
        print(f" {i+1}. {strain_name}: {row['1']:.6f}")
    # Full pipeline result via the public API.
    print(f"\n8. 最终结果:")
    result = predictor.predict_single(molecule)
    print(f" - 总抗菌分数: {result.apscore_total:.6f}")
    print(f" - 革兰阴性菌分数: {result.apscore_gnegative:.6f}")
    print(f" - 革兰阳性菌分数: {result.apscore_gpositive:.6f}")
    print(f" - 抑制菌株总数: {result.ginhib_total}")
    print(f" - 抑制革兰阴性菌: {result.ginhib_gnegative}")
    print(f" - 抑制革兰阳性菌: {result.ginhib_gpositive}")
    # NOTE(review): both branches of this conditional are empty strings —
    # likely yes/no markers lost in transit; confirm against the original.
    print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
def compare_different_molecules():
    """Compare predictions across a few simple molecules."""
    print("\n\n=== 不同分子对比分析 ===\n")
    predictor = ParallelBroadSpectrumPredictor()
    # Molecules of different chemistry.
    molecules = [
        MoleculeInput(smiles="CCO", chem_id="ethanol"),
        MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
        MoleculeInput(smiles="CCN", chem_id="ethylamine"),
        MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
        MoleculeInput(smiles="CC(C)O", chem_id="isopropanol"),
    ]
    results = predictor.predict_batch(molecules)
    print("分子对比结果:")
    print("-" * 80)
    print(f"{'分子':<15} {'SMILES':<12} {'抗菌分数':<10} {'抑制数':<8} {'广谱':<6}")
    print("-" * 80)
    for result in results:
        # Look up the input record matching this result's id.
        mol_info = next(m for m in molecules if m.chem_id == result.chem_id)
        # NOTE(review): the trailing conditional prints empty strings in both
        # branches — likely yes/no markers lost in transit; confirm.
        print(f"{result.chem_id:<15} {mol_info.smiles:<12} {result.apscore_total:<10.3f} "
              f"{result.ginhib_total:<8} {'' if result.broad_spectrum else '':<6}")
if __name__ == "__main__":
    # Run both analyses when executed as a script.
    analyze_prediction_process()
    compare_different_molecules()

120
example_usage.py Normal file
View File

@@ -0,0 +1,120 @@
"""
广谱抗菌预测API使用示例
"""
from typing import List
from broad_spectrum_api import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult,
predict_smiles,
predict_file
)
def example_single_prediction():
    """Example: predict a single molecule."""
    print("=== 单分子预测示例 ===")
    # Create the predictor with default configuration.
    predictor = ParallelBroadSpectrumPredictor()
    # Predict a single molecule (ethanol).
    molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
    result = predictor.predict_single(molecule)
    print(f"化合物: {result.chem_id}")
    # NOTE(review): both branches print empty strings — likely yes/no markers
    # lost in transit; confirm against the original source.
    print(f"广谱抗菌: {'' if result.broad_spectrum else ''}")
    print(f"抑制菌株数: {result.ginhib_total}")
    print(f"抗菌分数: {result.apscore_total:.3f}")
def example_batch_prediction():
    """Example: batch prediction with a custom worker/batch configuration."""
    print("\n=== 批量预测示例 ===")
    # Create a predictor with 4 workers and batches of 50 rows.
    config = PredictionConfig(n_workers=4, batch_size=50)
    predictor = ParallelBroadSpectrumPredictor(config)
    # Prepare several molecules.
    molecules = [
        MoleculeInput(smiles="CCO", chem_id="ethanol"),
        MoleculeInput(smiles="CCN", chem_id="ethylamine"),
        MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
    ]
    # Run the batch prediction.
    results = predictor.predict_batch(molecules)
    # Print the results.
    for result in results:
        print(f"{result.chem_id}: 广谱={result.broad_spectrum}, 抑制数={result.ginhib_total}")
def example_smiles_list_prediction():
    """Example: predict from parallel SMILES and id lists."""
    print("\n=== SMILES列表预测示例 ===")
    smiles_list = ["CCO", "CCN", "CC(=O)O"]
    chem_ids = ["ethanol", "ethylamine", "acetic_acid"]
    # Use the module-level convenience function.
    results = predict_smiles(smiles_list, chem_ids)
    # Count compounds predicted to be broad spectrum.
    broad_spectrum_count = sum(1 for r in results if r.broad_spectrum)
    print(f"广谱抗菌化合物: {broad_spectrum_count}/{len(results)}")
def example_file_prediction():
    """Example: predict from a TSV file and save the results to CSV."""
    print("\n=== 文件预测示例 ===")
    # Assumes an input file molecules.tsv exists in the working directory.
    try:
        results = predict_file(
            "molecules.tsv",
            smiles_column="smiles",
            id_column="compound_id"
        )
        # Save the aggregated results.
        import pandas as pd
        results_df = pd.DataFrame([r.to_dict() for r in results])
        results_df.to_csv("broad_spectrum_results.csv", index=False)
        print(f"预测完成,结果保存到 broad_spectrum_results.csv")
    except FileNotFoundError:
        print("输入文件不存在,跳过文件预测示例")
def example_custom_config():
    """Demonstrate overriding the default prediction configuration."""
    print("\n=== 自定义配置示例 ===")
    # Tighten the inhibition threshold, raise the broad-spectrum bar, and
    # scale up parallelism relative to the defaults.
    custom = PredictionConfig(
        app_threshold=0.1,
        min_nkill=15,
        n_workers=8,
        batch_size=200,
    )
    engine = ParallelBroadSpectrumPredictor(custom)
    # Single-compound batch just to show the configuration taking effect.
    outcome = engine.predict_batch([MoleculeInput(smiles="CCO", chem_id="ethanol")])
    print(f"使用自定义配置预测结果: {outcome[0].to_dict()}")
if __name__ == "__main__":
    # Run every demo in sequence.
    for demo in (
        example_single_prediction,
        example_batch_prediction,
        example_smiles_list_prediction,
        example_file_prediction,
        example_custom_config,
    ):
        demo()

181
mole_representation.py Normal file
View File

@@ -0,0 +1,181 @@
import os
import yaml
import argparse
import torch
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
from rdkit import Chem
from rdkit import RDLogger
from workflow.dataset.dataset_representation import batch_representation
from workflow.models.ginet_concat import GINet
RDLogger.DisableLog('rdApp.*')
# Function to read command line arguments
def parse_arguments():
    """Parse command line arguments for the MolE representation script.

    Returns:
    - args (argparse.Namespace): Parsed arguments with ``device`` resolved:
      when 'auto' is requested, 'cuda:0' is used if a GPU is available,
      otherwise 'cpu'.
    """
    # Instantiate parser (typos fixed: "as using" -> "using", "recieves" -> "receives")
    parser = argparse.ArgumentParser(prog="Represent molecular structures using MolE.",
                                     description="This program receives a file with SMILES and represents them using the MolE representation.",
                                     usage="python mole_representation.py smiles_filepath output_filepath [options]",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Input SMILES
    parser.add_argument("smiles_filepath", help="Complete path to the smiles filepath. Expects a TSV file with a column containing SMILES strings.")
    # Output filepath
    parser.add_argument("output_filepath", help="Complete path for the output.")
    # Column name for smiles
    parser.add_argument("-c", "--smiles_colname", help="Column name in smiles_filepath that contains the SMILES.",
                        default="smiles")
    # Column name for id
    parser.add_argument("-i", "--chemid_colname", help="Column name in smiles_filepath that contains the ID string of each chemical.",
                        default="chem_id")
    # MolE model
    parser.add_argument("-m", "--mole_model", help="Path to the directory containing the config.yaml and model.pth files of the pre-trained MolE chemical representation.",
                        default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")
    # Device
    parser.add_argument("-d", "--device", help="Device where the pre-trained model is loaded. Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) then cuda:0 device is selected if a GPU is detected.",
                        default="auto")
    # Parse arguments
    args = parser.parse_args()
    # Resolve 'auto' to a concrete device based on GPU availability.
    if args.device == "auto":
        args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"Using {args.device}")
    return args
# A FUNCTION TO READ SMILES from file
def read_smiles(data_path, smile_col="rdkit_no_salt", id_col="prestwick_ID"):
    """
    Read SMILES data from a file or DataFrame and remove invalid SMILES.

    Parameters:
    - data_path (str or pd.DataFrame): Path to a TSV file or a DataFrame containing SMILES data.
    - smile_col (str, optional): Name of the column containing SMILES strings.
    - id_col (str, optional): Name of the column containing molecule IDs.

    Returns:
    - smile_df (pandas.DataFrame): DataFrame with only [smile_col, id_col], missing
      rows removed, IDs cast to str, and RDKit-unparseable SMILES dropped.
    """
    # Read the data (accept an in-memory DataFrame as-is).
    if isinstance(data_path, pd.DataFrame):
        smile_df = data_path.copy()
    else:
        smile_df = pd.read_csv(data_path, sep='\t')
    smile_df = smile_df[[smile_col, id_col]]
    # Drop missing values BEFORE casting IDs to str: casting first would turn
    # NaN IDs into the literal string "nan", which dropna() cannot remove.
    smile_df = smile_df.dropna()
    # Make sure ID column is interpreted as str
    smile_df[id_col] = smile_df[id_col].astype(str)
    # Remove SMILES that RDKit cannot parse.
    smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)]
    return smile_df
# Function to load a pre-trained model
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
    """
    Load a pre-trained MolE model.

    Parameters:
    - pretrained_model_dir (str): Path to the directory holding config.yaml and model.pth.
    - device (str, optional): Device for computation (default is "cuda:0").

    Returns:
    - model: GINet model with the pre-trained weights loaded.
    """
    # Read model configuration; use a context manager so the file handle is
    # closed (the previous bare open() call leaked it).
    with open(os.path.join(pretrained_model_dir, "config.yaml"), "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model_config = config["model"]
    # Instantiate model on the requested device
    model = GINet(**model_config).to(device)
    # Load pre-trained weights
    model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
    print(model_pth_path)
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(model_pth_path, map_location=device)
    model.load_my_state_dict(state_dict)
    return model
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
    """
    Generate MolE molecular representations for a set of SMILES.

    Parameters:
    - dataset_path (str or pd.DataFrame): Path to the dataset file (or a DataFrame).
    - smile_column_str (str): Name of the column containing SMILES strings.
    - id_column_str (str): Name of the column containing molecule IDs.
    - pretrained_dir (str): Directory of the pre-trained MolE model.
    - device (str): Device used for computation, e.g. "cuda:0" or "cpu".

    Returns:
    - pandas.DataFrame: One row of MolE features per valid input molecule.
    """
    # Clean and subset the SMILES table.
    valid_smiles = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)
    # Load the frozen pre-trained model and embed every molecule with it.
    pretrained_model = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)
    return batch_representation(valid_smiles, pretrained_model, smile_column_str, id_column_str, device=device)
def main():
    """Script entry point: parse CLI arguments, compute MolE representations, write TSV."""
    args = parse_arguments()
    # Embed every valid SMILES from the input file with the pre-trained model.
    representation_df = process_representation(
        dataset_path=args.smiles_filepath,
        smile_column_str=args.smiles_colname,
        id_column_str=args.chemid_colname,
        pretrained_dir=args.mole_model,
        device=args.device,
    )
    # Persist the representation as tab-separated values.
    representation_df.to_csv(args.output_filepath, sep='\t')


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,28 @@
batch_size: 1000 # batch size
warm_up: 10 # warm-up epochs
epochs: 1000 # total number of epochs
load_model: None # resume training
eval_every_n_epochs: 1 # validation frequency
save_every_n_epochs: 5 # automatic model saving frequency
fp16_precision: False # float precision 16 (i.e. True/False)
init_lr: 0.0005 # initial learning rate for Adam
weight_decay: 1e-5 # weight decay for Adam
gpu: cuda:0 # training GPU
model_type: gin_concat # GNN backbone (i.e., gin/gcn)
model:
num_layer: 5 # number of graph conv layers
emb_dim: 200 # embedding dimension in graph conv layers
feat_dim: 8000 # output feature dimension
drop_ratio: 0.0 # dropout ratio
pool: add # readout pooling (i.e., mean/max/add)
dataset:
num_workers: 50 # dataloader number of workers
valid_size: 0.1 # ratio of validation data
data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data
loss:
l: 0.0001 # Lambda parameter

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More