first add
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.pth
|
||||
__pychche__/
|
||||
*.pyc
|
||||
139
README.md
Normal file
139
README.md
Normal file
@@ -0,0 +1,139 @@
|
||||
## MolE 广谱抗菌预测 API
|
||||
|
||||
测试案例: example_usage.py
|
||||
|
||||
## 功能特性
|
||||
|
||||
1. **高性能并行处理** - 支持多进程并行计算,显著提高大批量分子预测速度
|
||||
2. **多种使用方式** - 提供Python API、命令行工具和Web服务三种使用方式
|
||||
3. **模块化设计** - 易于集成到其他项目中
|
||||
4. **灵活配置** - 支持自定义模型路径、阈值等参数
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 1. Python API
|
||||
|
||||
```python
|
||||
from broad_spectrum_parallel import predict_smiles, MoleculeInput
|
||||
|
||||
# 预测单个或多个SMILES
|
||||
results = predict_smiles(["CCO", "CCN"], ["ethanol", "ethylamine"])
|
||||
|
||||
for result in results:
|
||||
print(f"{result.chem_id}: 广谱={result.broad_spectrum}, 抑制数={result.ginhib_total}")
|
||||
```
|
||||
|
||||
### 2. 命令行工具
|
||||
|
||||
```bash
|
||||
# 基本用法
|
||||
predict_antimicrobial input.tsv output.tsv --smiles_input --smiles_colname smiles --chemid_colname chem_id
|
||||
|
||||
# 聚合预测结果
|
||||
predict_antimicrobial input.tsv output.tsv --smiles_input --aggregate_scores
|
||||
```
|
||||
|
||||
### 3. Web API服务
|
||||
|
||||
```bash
|
||||
# 启动服务
|
||||
uvicorn broad_spectrum_parallel.api:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
然后可以通过POST请求访问`http://localhost:8000/predict`端点:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/predict" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"smiles": ["CCO", "CCN"]}'
|
||||
```
|
||||
|
||||
## 结果解读
|
||||
|
||||
从运行结果可以看到,每个化合物返回8个关键指标:
|
||||
|
||||
1. 抗菌潜力分数(对数尺度):
|
||||
|
||||
apscore_total: -11.758 - 总体抗菌分数
|
||||
apscore_gnegative: -11.648 - 革兰阴性菌抗菌分数
|
||||
apscore_gpositive: -11.848 - 革兰阳性菌抗菌分数
|
||||
2. 抑制菌株统计:
|
||||
|
||||
ginhib_total: 0 - 总抑制菌株数
|
||||
ginhib_gnegative: 0 - 抑制的革兰阴性菌株数
|
||||
ginhib_gpositive: 0 - 抑制的革兰阳性菌株数
|
||||
3. 广谱判定:
|
||||
|
||||
broad_spectrum: 0 - 是否广谱抗菌(需抑制≥10个菌株)
|
||||
🧪 结果解释示例
|
||||
以乙醇(CCO)为例:
|
||||
|
||||
抗菌分数很低 (-11.758):表明预测的抗菌活性很弱
|
||||
无菌株抑制 (0):在设定阈值下不能有效抑制任何测试菌株
|
||||
非广谱抗菌 (0):不满足广谱抗菌的最低标准
|
||||
这个结果符合预期,因为乙醇虽有杀菌作用,但在药物发现的标准下不被认为是有效的抗菌候选化合物。
|
||||
|
||||
## 可以运行的菌株信息
|
||||
|
||||
```shell
|
||||
(mole) root@DESK4090:/srv/project/mole_antimicrobial_potential/broad_spectrum_parallel# micromamba run -n mole python -c "
|
||||
> import pandas as pd
|
||||
> import numpy as np
|
||||
>
|
||||
> # 加载菌株筛选数据
|
||||
> maier_screen = pd.read_csv('data/01.prepare_training_data/maier_screening_results.tsv.gz', sep='\t', index_col=0)
|
||||
> print(f'总菌株数量: {len(maier_screen.columns)}')
|
||||
> print(f'总化合物数量: {len(maier_screen.index)}')
|
||||
> print(f'菌株列表前10个:')
|
||||
> for i, strain in enumerate(maier_screen.columns[:10]):
|
||||
> print(f'{i+1}. {strain}')
|
||||
>
|
||||
> # 加载革兰染色信息
|
||||
> gram_info = pd.read_excel('raw_data/maier_microbiome/strain_info_SF2.xlsx',
|
||||
> skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
|
||||
> index_col='NT data base')
|
||||
> print(f'\n革兰染色信息:')
|
||||
> print(gram_info['Gram stain'].value_counts())
|
||||
> "
|
||||
总菌株数量: 40
|
||||
总化合物数量: 1197
|
||||
菌株列表前10个:
|
||||
1. Akkermansia muciniphila (NT5021)
|
||||
2. Bacteroides caccae (NT5050)
|
||||
3. Bacteroides fragilis (ET) (NT5033)
|
||||
4. Bacteroides fragilis (NT) (NT5003)
|
||||
5. Bacteroides ovatus (NT5054)
|
||||
6. Bacteroides thetaiotaomicron (NT5004)
|
||||
7. Bacteroides uniformis (NT5002)
|
||||
8. Bacteroides vulgatus (NT5001)
|
||||
9. Bacteroides xylanisolvens (NT5064)
|
||||
10. Bifidobacterium adolescentis (NT5022)
|
||||
/root/micromamba/envs/mole/lib/python3.10/site-packages/openpyxl/worksheet/_reader.py:329: UserWarning: Unknown extension is not supported and will be removed
|
||||
warn(msg)
|
||||
|
||||
革兰染色信息:
|
||||
Gram stain
|
||||
positive 22
|
||||
negative 18
|
||||
Name: count, dtype: int64
|
||||
```
|
||||
|
||||
## 权重下载
|
||||
|
||||
mole
|
||||
https://www.alipan.com/s/DNuDo8iEn89
|
||||
提取码: mh90
|
||||
|
||||
下载完成放到:pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001
|
||||
|
||||
## 原始论文与github仓库
|
||||
|
||||
https://www.nature.com/articles/s41467-025-58804-4
|
||||
|
||||
https://github.com/rolayoalarcon/mole_antimicrobial_potential
|
||||
30
__init__.py
Normal file
30
__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
并行广谱抗菌预测模块
|
||||
|
||||
提供高性能的分子广谱抗菌活性预测功能。
|
||||
"""
|
||||
|
||||
from .broad_spectrum_api import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
BroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput,
|
||||
BroadSpectrumResult,
|
||||
create_predictor,
|
||||
predict_smiles,
|
||||
predict_file
|
||||
)
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Your Name"
|
||||
|
||||
__all__ = [
|
||||
"ParallelBroadSpectrumPredictor",
|
||||
"BroadSpectrumPredictor",
|
||||
"PredictionConfig",
|
||||
"MoleculeInput",
|
||||
"BroadSpectrumResult",
|
||||
"create_predictor",
|
||||
"predict_smiles",
|
||||
"predict_file"
|
||||
]
|
||||
131
api.py
Normal file
131
api.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
FastAPI服务,基于并行广谱抗菌预测API
|
||||
"""
|
||||
|
||||
from typing import List, Union, Optional
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field
|
||||
import pandas as pd
|
||||
from .broad_spectrum_api import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput,
|
||||
BroadSpectrumResult
|
||||
)
|
||||
|
||||
|
||||
# 数据模型
|
||||
class MoleculeInfo(BaseModel):
|
||||
smiles: str
|
||||
chem_id: Optional[str] = None
|
||||
|
||||
|
||||
class MoleculeInputRequest(BaseModel):
|
||||
molecules: Optional[List[MoleculeInfo]] = None
|
||||
smiles: Optional[Union[str, List[str]]] = None
|
||||
chem_id: Optional[Union[str, List[str]]] = None
|
||||
aggregate_scores: bool = Field(False, description="Whether to aggregate predictions")
|
||||
app_threshold: float = Field(0.04374140128493309, description="Threshold for growth inhibition")
|
||||
min_nkill: int = Field(10, description="Minimum strains for broad spectrum")
|
||||
batch_size: int = Field(100, description="Batch size for processing")
|
||||
n_workers: Optional[int] = Field(None, description="Number of worker processes")
|
||||
|
||||
|
||||
class PredictionResponse(BaseModel):
|
||||
results: List[BroadSpectrumResult]
|
||||
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(title="Antimicrobial Prediction API",
|
||||
description="API for predicting antimicrobial activity of compounds using MolE and XGBoost",
|
||||
version="1.0.0")
|
||||
|
||||
|
||||
# 初始化预测器(在实际应用中可能需要更复杂的初始化)
|
||||
predictor = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动时初始化预测器"""
|
||||
global predictor
|
||||
config = PredictionConfig()
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
|
||||
@app.post("/predict", response_model=PredictionResponse)
|
||||
async def predict_antimicrobial(input_data: MoleculeInputRequest):
|
||||
"""
|
||||
预测化合物的抗菌活性
|
||||
|
||||
Args:
|
||||
input_data: 包含分子信息的请求数据
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
global predictor
|
||||
|
||||
if not predictor:
|
||||
# 如果预测器未初始化,则创建一个
|
||||
config = PredictionConfig(
|
||||
batch_size=input_data.batch_size,
|
||||
n_workers=input_data.n_workers
|
||||
)
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 处理输入数据
|
||||
if input_data.molecules:
|
||||
# 直接使用MoleculeInput对象列表
|
||||
molecules = [
|
||||
MoleculeInput(smiles=m.smiles, chem_id=m.chem_id or f"mol{i+1}")
|
||||
for i, m in enumerate(input_data.molecules)
|
||||
]
|
||||
elif input_data.smiles:
|
||||
# 从SMILES字符串创建分子列表
|
||||
if isinstance(input_data.smiles, str):
|
||||
smiles_list = [input_data.smiles]
|
||||
else:
|
||||
smiles_list = input_data.smiles
|
||||
|
||||
if input_data.chem_id:
|
||||
if isinstance(input_data.chem_id, str):
|
||||
chem_ids = [input_data.chem_id]
|
||||
else:
|
||||
chem_ids = input_data.chem_id
|
||||
else:
|
||||
chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
|
||||
|
||||
molecules = [
|
||||
MoleculeInput(smiles=smiles, chem_id=chem_id)
|
||||
for smiles, chem_id in zip(smiles_list, chem_ids)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Either 'molecules' or 'smiles' must be provided")
|
||||
|
||||
# 执行预测
|
||||
try:
|
||||
results = predictor.predict_batch(molecules)
|
||||
return PredictionResponse(results=results)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Prediction failed: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""
|
||||
健康检查端点
|
||||
"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""
|
||||
根路径,提供API信息
|
||||
"""
|
||||
return {
|
||||
"message": "Antimicrobial Prediction API",
|
||||
"version": "1.0.0",
|
||||
"description": "API for predicting antimicrobial activity of compounds"
|
||||
}
|
||||
524
broad_spectrum_api.py
Normal file
524
broad_spectrum_api.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""
|
||||
并行广谱抗菌预测API模块
|
||||
|
||||
提供高性能的分子广谱抗菌活性预测功能,支持批量处理和多进程并行计算。
|
||||
基于MolE分子表示和XGBoost模型进行预测。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import pickle
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import multiprocessing as mp
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import List, Dict, Union, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from scipy.stats.mstats import gmean
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
try:
|
||||
from mole_representation import process_representation
|
||||
except ImportError:
|
||||
print("Warning: mole_representation module not found. Please ensure it's in your Python path.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionConfig:
|
||||
"""预测配置参数"""
|
||||
xgboost_model_path: str = "data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl"
|
||||
mole_model_path: str = "pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001"
|
||||
strain_categories_path: str = "data/01.prepare_training_data/maier_screening_results.tsv.gz"
|
||||
gram_info_path: str = "raw_data/maier_microbiome/strain_info_SF2.xlsx"
|
||||
app_threshold: float = 0.04374140128493309
|
||||
min_nkill: int = 10
|
||||
batch_size: int = 100
|
||||
n_workers: Optional[int] = None
|
||||
device: str = "auto"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoleculeInput:
|
||||
"""分子输入数据结构"""
|
||||
smiles: str
|
||||
chem_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadSpectrumResult:
|
||||
"""广谱抗菌预测结果"""
|
||||
chem_id: str
|
||||
apscore_total: float
|
||||
apscore_gnegative: float
|
||||
apscore_gpositive: float
|
||||
ginhib_total: int
|
||||
ginhib_gnegative: int
|
||||
ginhib_gpositive: int
|
||||
broad_spectrum: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Union[str, float, int]]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
'chem_id': self.chem_id,
|
||||
'apscore_total': self.apscore_total,
|
||||
'apscore_gnegative': self.apscore_gnegative,
|
||||
'apscore_gpositive': self.apscore_gpositive,
|
||||
'ginhib_total': self.ginhib_total,
|
||||
'ginhib_gnegative': self.ginhib_gnegative,
|
||||
'ginhib_gpositive': self.ginhib_gpositive,
|
||||
'broad_spectrum': self.broad_spectrum
|
||||
}
|
||||
|
||||
|
||||
class BroadSpectrumPredictor:
|
||||
"""
|
||||
广谱抗菌预测器
|
||||
|
||||
基于MolE分子表示和XGBoost模型预测分子的广谱抗菌活性。
|
||||
支持单分子和批量预测,提供详细的抗菌潜力分析。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PredictionConfig] = None) -> None:
|
||||
"""
|
||||
初始化预测器
|
||||
|
||||
Args:
|
||||
config: 预测配置参数,如果为None则使用默认配置
|
||||
"""
|
||||
self.config = config or PredictionConfig()
|
||||
self.n_workers = self.config.n_workers or mp.cpu_count()
|
||||
|
||||
# 验证文件路径
|
||||
self._validate_paths()
|
||||
|
||||
# 预加载共享数据
|
||||
self._load_shared_data()
|
||||
|
||||
def _validate_paths(self) -> None:
|
||||
"""验证必要文件路径是否存在"""
|
||||
required_files = [
|
||||
self.config.xgboost_model_path,
|
||||
self.config.strain_categories_path,
|
||||
self.config.gram_info_path
|
||||
]
|
||||
|
||||
for file_path in required_files:
|
||||
if not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"Required file not found: {file_path}")
|
||||
|
||||
def _load_shared_data(self) -> None:
|
||||
"""加载共享数据(菌株信息、革兰染色信息等)"""
|
||||
try:
|
||||
# 加载菌株筛选数据
|
||||
self.maier_screen: pd.DataFrame = pd.read_csv(
|
||||
self.config.strain_categories_path, sep='\t', index_col=0
|
||||
)
|
||||
|
||||
# 准备菌株独热编码
|
||||
self.strain_ohe: pd.DataFrame = self._prep_ohe(self.maier_screen.columns)
|
||||
|
||||
# 加载革兰染色信息
|
||||
self.maier_strains: pd.DataFrame = pd.read_excel(
|
||||
self.config.gram_info_path,
|
||||
skiprows=[0, 1, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
|
||||
index_col="NT data base"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load shared data: {str(e)}")
|
||||
|
||||
def _prep_ohe(self, categories: pd.Index) -> pd.DataFrame:
|
||||
"""
|
||||
准备菌株的独热编码
|
||||
|
||||
Args:
|
||||
categories: 菌株类别索引
|
||||
|
||||
Returns:
|
||||
独热编码后的DataFrame
|
||||
"""
|
||||
ohe = OneHotEncoder(sparse=False)
|
||||
ohe.fit(pd.DataFrame(categories))
|
||||
cat_ohe = pd.DataFrame(
|
||||
ohe.transform(pd.DataFrame(categories)),
|
||||
columns=categories,
|
||||
index=categories
|
||||
)
|
||||
return cat_ohe
|
||||
|
||||
def _get_mole_representation(self, molecules: List[MoleculeInput]) -> pd.DataFrame:
|
||||
"""
|
||||
获取分子的MolE表示
|
||||
|
||||
Args:
|
||||
molecules: 分子输入列表
|
||||
|
||||
Returns:
|
||||
MolE特征表示DataFrame
|
||||
"""
|
||||
# 准备输入数据
|
||||
df_data = []
|
||||
for i, mol in enumerate(molecules):
|
||||
chem_id = mol.chem_id or f"mol{i+1}"
|
||||
df_data.append({"smiles": mol.smiles, "chem_id": chem_id})
|
||||
|
||||
df = pd.DataFrame(df_data)
|
||||
|
||||
# 确定设备
|
||||
device = self.config.device
|
||||
if device == "auto":
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# 获取MolE表示
|
||||
return process_representation(
|
||||
dataset_path=df,
|
||||
smile_column_str="smiles",
|
||||
id_column_str="chem_id",
|
||||
pretrained_dir=self.config.mole_model_path,
|
||||
device=device
|
||||
)
|
||||
|
||||
def _add_strains(self, chemfeats_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
添加菌株信息到化学特征(笛卡尔积)
|
||||
|
||||
Args:
|
||||
chemfeats_df: 化学特征DataFrame
|
||||
|
||||
Returns:
|
||||
包含菌株信息的特征DataFrame
|
||||
"""
|
||||
# 准备化学特征
|
||||
chemfe = chemfeats_df.reset_index().rename(columns={"index": "chem_id"})
|
||||
chemfe["chem_id"] = chemfe["chem_id"].astype(str)
|
||||
|
||||
# 准备独热编码
|
||||
sohe = self.strain_ohe.reset_index().rename(columns={"index": "strain_name"})
|
||||
|
||||
# 笛卡尔积合并
|
||||
xpred = chemfe.merge(sohe, how="cross")
|
||||
xpred["pred_id"] = xpred["chem_id"].str.cat(xpred["strain_name"], sep=":")
|
||||
|
||||
xpred = xpred.set_index("pred_id")
|
||||
xpred = xpred.drop(columns=["chem_id", "strain_name"])
|
||||
|
||||
return xpred
|
||||
|
||||
def _gram_stain(self, label_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
添加革兰染色信息
|
||||
|
||||
Args:
|
||||
label_df: 包含菌株名称的DataFrame
|
||||
|
||||
Returns:
|
||||
添加革兰染色信息后的DataFrame
|
||||
"""
|
||||
df_label = label_df.copy()
|
||||
|
||||
# 提取NT编号
|
||||
df_label["nt_number"] = df_label["strain_name"].apply(
|
||||
lambda x: re.search(r".*?\((NT\d+)\)", x).group(1) if re.search(r".*?\((NT\d+)\)", x) else None
|
||||
)
|
||||
|
||||
# 创建革兰染色字典
|
||||
gram_dict = self.maier_strains[["Gram stain"]].to_dict()["Gram stain"]
|
||||
|
||||
# 添加染色信息
|
||||
df_label["gram_stain"] = df_label["nt_number"].apply(gram_dict.get)
|
||||
|
||||
return df_label
|
||||
|
||||
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算抗菌潜力分数
|
||||
|
||||
Args:
|
||||
score_df: 预测分数DataFrame
|
||||
|
||||
Returns:
|
||||
聚合后的抗菌潜力DataFrame
|
||||
"""
|
||||
# 分离化合物ID和菌株名
|
||||
score_df["chem_id"] = score_df["pred_id"].str.split(":", expand=True)[0]
|
||||
score_df["strain_name"] = score_df["pred_id"].str.split(":", expand=True)[1]
|
||||
|
||||
# 添加革兰染色信息
|
||||
pred_df = self._gram_stain(score_df)
|
||||
|
||||
# 计算抗菌潜力分数(几何平均数的对数)
|
||||
apscore_total = pred_df.groupby("chem_id")["1"].apply(gmean).to_frame().rename(
|
||||
columns={"1": "apscore_total"}
|
||||
)
|
||||
apscore_total["apscore_total"] = np.log(apscore_total["apscore_total"])
|
||||
|
||||
# 按革兰染色分组的抗菌分数
|
||||
apscore_gram = pred_df.groupby(["chem_id", "gram_stain"])["1"].apply(gmean).unstack().rename(
|
||||
columns={"negative": "apscore_gnegative", "positive": "apscore_gpositive"}
|
||||
)
|
||||
apscore_gram["apscore_gnegative"] = np.log(apscore_gram["apscore_gnegative"])
|
||||
apscore_gram["apscore_gpositive"] = np.log(apscore_gram["apscore_gpositive"])
|
||||
|
||||
# 被抑制菌株数统计
|
||||
inhibted_total = pred_df.groupby("chem_id")["growth_inhibition"].sum().to_frame().rename(
|
||||
columns={"growth_inhibition": "ginhib_total"}
|
||||
)
|
||||
|
||||
# 按革兰染色分组的被抑制菌株数
|
||||
inhibted_gram = pred_df.groupby(["chem_id", "gram_stain"])["growth_inhibition"].sum().unstack().rename(
|
||||
columns={"negative": "ginhib_gnegative", "positive": "ginhib_gpositive"}
|
||||
)
|
||||
|
||||
# 合并所有结果
|
||||
agg_pred = apscore_total.join(apscore_gram).join(inhibted_total).join(inhibted_gram)
|
||||
|
||||
# 填充NaN值
|
||||
agg_pred = agg_pred.fillna(0)
|
||||
|
||||
return agg_pred
|
||||
|
||||
|
||||
def _predict_batch_worker(batch_data: Tuple[pd.DataFrame, int],
|
||||
model_path: str,
|
||||
app_threshold: float) -> Tuple[int, pd.DataFrame]:
|
||||
"""
|
||||
批次预测工作函数(用于多进程)
|
||||
|
||||
Args:
|
||||
batch_data: (特征数据, 批次ID)
|
||||
model_path: XGBoost模型路径
|
||||
app_threshold: 抑制阈值
|
||||
|
||||
Returns:
|
||||
(批次ID, 预测结果DataFrame)
|
||||
"""
|
||||
X_input, batch_id = batch_data
|
||||
|
||||
# 加载模型
|
||||
with open(model_path, "rb") as file:
|
||||
model = pickle.load(file)
|
||||
|
||||
# 进行预测
|
||||
y_pred = model.predict_proba(X_input)
|
||||
pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
|
||||
|
||||
# 二值化预测结果
|
||||
pred_df["growth_inhibition"] = pred_df["1"].apply(
|
||||
lambda x: 1 if x >= app_threshold else 0
|
||||
)
|
||||
|
||||
return batch_id, pred_df
|
||||
|
||||
|
||||
class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
"""
|
||||
并行广谱抗菌预测器
|
||||
|
||||
继承自BroadSpectrumPredictor,添加了多进程并行处理能力,
|
||||
适用于大规模分子批量预测。
|
||||
"""
|
||||
|
||||
def predict_single(self, molecule: MoleculeInput) -> BroadSpectrumResult:
|
||||
"""
|
||||
预测单个分子的广谱抗菌活性
|
||||
|
||||
Args:
|
||||
molecule: 分子输入数据
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果
|
||||
"""
|
||||
results = self.predict_batch([molecule])
|
||||
return results[0]
|
||||
|
||||
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
批量预测分子的广谱抗菌活性
|
||||
|
||||
Args:
|
||||
molecules: 分子输入列表
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果列表
|
||||
"""
|
||||
if not molecules:
|
||||
return []
|
||||
|
||||
# 获取MolE表示
|
||||
print(f"Processing {len(molecules)} molecules...")
|
||||
mole_representation = self._get_mole_representation(molecules)
|
||||
|
||||
# 添加菌株信息
|
||||
print("Preparing strain-level features...")
|
||||
X_input = self._add_strains(mole_representation)
|
||||
|
||||
# 分批处理
|
||||
print(f"Starting parallel prediction with {self.n_workers} workers...")
|
||||
batches = []
|
||||
for i in range(0, len(X_input), self.config.batch_size):
|
||||
batch = X_input.iloc[i:i+self.config.batch_size]
|
||||
batches.append((batch, i // self.config.batch_size))
|
||||
|
||||
# 并行预测
|
||||
results = {}
|
||||
with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(_predict_batch_worker, (batch_data, batch_id),
|
||||
self.config.xgboost_model_path,
|
||||
self.config.app_threshold): batch_id
|
||||
for batch_data, batch_id in batches
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
batch_id, pred_df = future.result()
|
||||
results[batch_id] = pred_df
|
||||
print(f"Batch {batch_id} completed")
|
||||
|
||||
# 合并结果
|
||||
print("Merging prediction results...")
|
||||
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
|
||||
|
||||
# 计算抗菌潜力
|
||||
print("Calculating antimicrobial potential scores...")
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
agg_df = self._antimicrobial_potential(all_pred_df)
|
||||
|
||||
# 判断广谱抗菌
|
||||
agg_df["broad_spectrum"] = agg_df["ginhib_total"].apply(
|
||||
lambda x: 1 if x >= self.config.min_nkill else 0
|
||||
)
|
||||
|
||||
# 转换为结果对象
|
||||
results_list = []
|
||||
for _, row in agg_df.iterrows():
|
||||
result = BroadSpectrumResult(
|
||||
chem_id=row.name,
|
||||
apscore_total=row["apscore_total"],
|
||||
apscore_gnegative=row["apscore_gnegative"],
|
||||
apscore_gpositive=row["apscore_gpositive"],
|
||||
ginhib_total=int(row["ginhib_total"]),
|
||||
ginhib_gnegative=int(row["ginhib_gnegative"]),
|
||||
ginhib_gpositive=int(row["ginhib_gpositive"]),
|
||||
broad_spectrum=int(row["broad_spectrum"])
|
||||
)
|
||||
results_list.append(result)
|
||||
|
||||
return results_list
|
||||
|
||||
def predict_from_smiles(self,
|
||||
smiles_list: List[str],
|
||||
chem_ids: Optional[List[str]] = None) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
从SMILES字符串列表预测广谱抗菌活性
|
||||
|
||||
Args:
|
||||
smiles_list: SMILES字符串列表
|
||||
chem_ids: 化合物ID列表,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果列表
|
||||
"""
|
||||
if chem_ids is None:
|
||||
chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))]
|
||||
|
||||
if len(smiles_list) != len(chem_ids):
|
||||
raise ValueError("smiles_list and chem_ids must have the same length")
|
||||
|
||||
molecules = [
|
||||
MoleculeInput(smiles=smiles, chem_id=chem_id)
|
||||
for smiles, chem_id in zip(smiles_list, chem_ids)
|
||||
]
|
||||
|
||||
return self.predict_batch(molecules)
|
||||
|
||||
def predict_from_file(self,
|
||||
file_path: str,
|
||||
smiles_column: str = "smiles",
|
||||
id_column: str = "chem_id") -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
从文件预测广谱抗菌活性
|
||||
|
||||
Args:
|
||||
file_path: 输入文件路径(支持CSV/TSV)
|
||||
smiles_column: SMILES列名
|
||||
id_column: 化合物ID列名
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果列表
|
||||
"""
|
||||
# 读取文件
|
||||
if file_path.endswith('.tsv'):
|
||||
df = pd.read_csv(file_path, sep='\t')
|
||||
else:
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# 验证列存在
|
||||
if smiles_column not in df.columns:
|
||||
raise ValueError(f"Column '{smiles_column}' not found in file")
|
||||
|
||||
# 处理ID列
|
||||
if id_column not in df.columns:
|
||||
df[id_column] = [f"mol{i+1}" for i in range(len(df))]
|
||||
|
||||
# 创建分子输入
|
||||
molecules = [
|
||||
MoleculeInput(smiles=row[smiles_column], chem_id=row[id_column])
|
||||
for _, row in df.iterrows()
|
||||
]
|
||||
|
||||
return self.predict_batch(molecules)
|
||||
|
||||
|
||||
def create_predictor(config: Optional[PredictionConfig] = None) -> ParallelBroadSpectrumPredictor:
|
||||
"""
|
||||
创建并行广谱抗菌预测器实例
|
||||
|
||||
Args:
|
||||
config: 预测配置参数
|
||||
|
||||
Returns:
|
||||
预测器实例
|
||||
"""
|
||||
return ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def predict_smiles(smiles_list: List[str],
|
||||
chem_ids: Optional[List[str]] = None,
|
||||
config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
便捷函数:直接从SMILES列表预测广谱抗菌活性
|
||||
|
||||
Args:
|
||||
smiles_list: SMILES字符串列表
|
||||
chem_ids: 化合物ID列表
|
||||
config: 预测配置
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
predictor = create_predictor(config)
|
||||
return predictor.predict_from_smiles(smiles_list, chem_ids)
|
||||
|
||||
|
||||
def predict_file(file_path: str,
|
||||
smiles_column: str = "smiles",
|
||||
id_column: str = "chem_id",
|
||||
config: Optional[PredictionConfig] = None) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
便捷函数:从文件预测广谱抗菌活性
|
||||
|
||||
Args:
|
||||
file_path: 输入文件路径
|
||||
smiles_column: SMILES列名
|
||||
id_column: ID列名
|
||||
config: 预测配置
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
predictor = create_predictor(config)
|
||||
return predictor.predict_from_file(file_path, smiles_column, id_column)
|
||||
199
cli.py
Normal file
199
cli.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
命令行接口,基于并行广谱抗菌预测API
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from typing import List
|
||||
from .broad_spectrum_api import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""
|
||||
解析命令行参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="Prediction of antimicrobial activity.",
|
||||
description="This program receives a collection of molecules as input. "
|
||||
"If it receives SMILES, it first featurizes the molecules using MolE, "
|
||||
"then makes predictions of antimicrobial activity. "
|
||||
"By default, the program returns the antimicrobial predictive probabilities "
|
||||
"for each compound-strain pair. "
|
||||
"If the --aggregate_scores flag is set, then the program aggregates the predictions "
|
||||
"into an antimicrobial potential score and reports the number of strains inhibited by each compound.",
|
||||
usage="python cli.py input_filepath output_filepath [options]",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
# 输入文件
|
||||
parser.add_argument("input_filepath", help="Complete path to input file. Can be a file with SMILES "
|
||||
"(make sure to set the --smiles_input flag) or a file with MolE representation.")
|
||||
|
||||
# 输出文件
|
||||
parser.add_argument("output_filepath", help="Complete path for output file")
|
||||
|
||||
# 输入类型参数组
|
||||
inputargs = parser.add_argument_group("Input arguments", "Arguments related to the input file")
|
||||
|
||||
# 如果是SMILES输入
|
||||
inputargs.add_argument("-s", "--smiles_input",
|
||||
help="Flag variable. Indicates if the input_filepath contains SMILES "
|
||||
"that have to be first represented using a MolE pre-trained model.",
|
||||
action="store_true")
|
||||
|
||||
# SMILES列名
|
||||
inputargs.add_argument("-c", "--smiles_colname",
|
||||
help="Column name in input_filepath that contains the SMILES. Only used if --smiles_input is set.",
|
||||
default="smiles")
|
||||
|
||||
# 化合物ID列名
|
||||
inputargs.add_argument("-i", "--chemid_colname",
|
||||
help="Column name in smiles_filepath that contains the ID string of each chemical. "
|
||||
"Only used if --smiles_input is set",
|
||||
default="chem_id")
|
||||
|
||||
# 模型参数组
|
||||
modelargs = parser.add_argument_group("Model arguments", "Arguments related to the models used for prediction")
|
||||
|
||||
# XGBoost模型路径
|
||||
modelargs.add_argument("-x", "--xgboost_model",
|
||||
help="Path to the pickled XGBoost model that makes predictions (.pkl).",
|
||||
default="data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl")
|
||||
|
||||
# MolE模型路径
|
||||
modelargs.add_argument("-m", "--mole_model",
|
||||
help="Path to the directory containing the config.yaml and model.pth files "
|
||||
"of the pre-trained MolE chemical representation. Only used if smiles_input is set.",
|
||||
default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")
|
||||
|
||||
# 预测参数组
|
||||
predargs = parser.add_argument_group("Prediction arguments", "Arguments related to the prediction process.")
|
||||
|
||||
# 聚合预测结果
|
||||
predargs.add_argument("-a", "--aggregate_scores",
|
||||
help="Flag variable. If not set, then the prediction for each compound-strain pair is reported. "
|
||||
"If set, then prediction scores of each compound is aggregated into the antimicrobial "
|
||||
"potential score and the total number of strains predicted to be inhibited is reported. "
|
||||
"Additionally, the broad spectrum antibiotic prediction is reported.",
|
||||
action="store_true")
|
||||
|
||||
# 抗菌评分阈值
|
||||
predargs.add_argument("-t", "--app_threshold",
|
||||
help="Threshold score applied to the antimicrobial predictive probabilities "
|
||||
"in order to binarize compound-microbe predictions of growth inhibition. "
|
||||
"Default from original publication.",
|
||||
default=0.04374140128493309, type=float)
|
||||
|
||||
# 广谱抗菌阈值
|
||||
predargs.add_argument("-k", "--min_nkill",
|
||||
help="Minimum number of microbes predicted to be inhibited "
|
||||
"in order to consider the compound a broad spectrum antibiotic.",
|
||||
default=10, type=int)
|
||||
|
||||
# 批次大小
|
||||
predargs.add_argument("--batch_size",
|
||||
help="Batch size for processing molecules.",
|
||||
default=100, type=int)
|
||||
|
||||
# 工作进程数
|
||||
predargs.add_argument("--n_workers",
|
||||
help="Number of worker processes for parallel processing. "
|
||||
"If not set, the number of CPU cores will be used.",
|
||||
default=None, type=int)
|
||||
|
||||
# 元数据参数组
|
||||
metadataargs = parser.add_argument_group("Metadata arguments", "Arguments related to the metadata used for prediction.")
|
||||
|
||||
# Maier菌株信息
|
||||
metadataargs.add_argument("-b", "--strain_categories",
|
||||
help="Path to the Maier et.al. screening results.",
|
||||
default="data/01.prepare_training_data/maier_screening_results.tsv.gz")
|
||||
|
||||
# 细菌信息
|
||||
metadataargs.add_argument("-g", "--gram_information",
|
||||
help="Path to strain metadata.",
|
||||
default="raw_data/maier_microbiome/strain_info_SF2.xlsx")
|
||||
|
||||
# 设备
|
||||
parser.add_argument("-d", "--device",
|
||||
help="Device where the pre-trained model is loaded. "
|
||||
"Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) "
|
||||
"then cuda:0 device is selected if a GPU is detected.",
|
||||
default="auto")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 给出返回信息的提示
|
||||
if args.aggregate_scores:
|
||||
print("Aggregating predictions of antimicrobial activity.")
|
||||
else:
|
||||
print("Returning predictions of antimicrobial activity for each compound-strain pair.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
主函数
|
||||
"""
|
||||
# 解析命令行参数
|
||||
args = parse_arguments()
|
||||
|
||||
# 创建配置对象
|
||||
config = PredictionConfig(
|
||||
xgboost_model_path=args.xgboost_model,
|
||||
mole_model_path=args.mole_model,
|
||||
strain_categories_path=args.strain_categories,
|
||||
gram_info_path=args.gram_information,
|
||||
app_threshold=args.app_threshold,
|
||||
min_nkill=args.min_nkill,
|
||||
batch_size=args.batch_size,
|
||||
n_workers=args.n_workers,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 根据输入类型处理
|
||||
if args.smiles_input:
|
||||
# 从文件中读取SMILES
|
||||
results = predictor.predict_from_file(
|
||||
args.input_filepath,
|
||||
smiles_column=args.smiles_colname,
|
||||
id_column=args.chemid_colname
|
||||
)
|
||||
else:
|
||||
# 从已有表示中读取 (MolE 特征向量)
|
||||
# TODO: 实现 predict_from_mole_representation 方法或调用相应API
|
||||
raise NotImplementedError("Processing from pre-computed representations is not yet implemented in the new API")
|
||||
|
||||
# 根据是否聚合结果进行输出
|
||||
# 根据是否聚合结果进行输出
|
||||
if args.aggregate_scores:
|
||||
print("Aggregating Antimicrobial potential")
|
||||
# 聚合模式:每行一个化合物,包含汇总统计
|
||||
results_df = pd.DataFrame([r.to_dict() for r in results])
|
||||
results_df.set_index('chem_id', inplace=True)
|
||||
results_df.to_csv(args.output_filepath, sep='\t')
|
||||
else:
|
||||
# 非聚合模式:每行一个化合物-菌株对,输出预测概率
|
||||
print("Generating non-aggregated predictions (compound-strain pairs)")
|
||||
rows = []
|
||||
for result in results:
|
||||
base_row = {'chem_id': result.chem_id}
|
||||
for strain, prob in result.predictions.items():
|
||||
row = base_row.copy()
|
||||
row['strain'] = strain
|
||||
row['prediction_probability'] = prob
|
||||
rows.append(row)
|
||||
results_df = pd.DataFrame(rows)
|
||||
results_df.to_csv(args.output_filepath, sep='\t', index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
BIN
data/01.prepare_training_data/maier_ecfp4_representation.tsv.gz
Normal file
BIN
data/01.prepare_training_data/maier_ecfp4_representation.tsv.gz
Normal file
Binary file not shown.
BIN
data/01.prepare_training_data/maier_mole_representation.tsv.gz
Normal file
BIN
data/01.prepare_training_data/maier_mole_representation.tsv.gz
Normal file
Binary file not shown.
BIN
data/01.prepare_training_data/maier_scaffold_split.tsv.gz
Normal file
BIN
data/01.prepare_training_data/maier_scaffold_split.tsv.gz
Normal file
Binary file not shown.
BIN
data/01.prepare_training_data/maier_screening_results.tsv.gz
Normal file
BIN
data/01.prepare_training_data/maier_screening_results.tsv.gz
Normal file
Binary file not shown.
BIN
data/01.prepare_training_data/prestwick_library.tsv.gz
Normal file
BIN
data/01.prepare_training_data/prestwick_library.tsv.gz
Normal file
Binary file not shown.
BIN
data/01.prepare_training_data/prestwick_library_screened.tsv.gz
Normal file
BIN
data/01.prepare_training_data/prestwick_library_screened.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/02.model_training/strain_performance.tsv.gz
Normal file
BIN
data/02.model_training/strain_performance.tsv.gz
Normal file
Binary file not shown.
BIN
data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl
Normal file
BIN
data/03.model_evaluation/MolE-XGBoost-08.03.2024_14.20.pkl
Normal file
Binary file not shown.
BIN
data/03.model_evaluation/chemDesc-XGBoost-08.03.2024_14.20.pkl
Normal file
BIN
data/03.model_evaluation/chemDesc-XGBoost-08.03.2024_14.20.pkl
Normal file
Binary file not shown.
BIN
data/03.model_evaluation/complete_test_predictions.tsv.gz
Normal file
BIN
data/03.model_evaluation/complete_test_predictions.tsv.gz
Normal file
Binary file not shown.
BIN
data/03.model_evaluation/ecfp4-XGBoost-08.03.2024_14.20.pkl
Normal file
BIN
data/03.model_evaluation/ecfp4-XGBoost-08.03.2024_14.20.pkl
Normal file
Binary file not shown.
BIN
data/03.model_evaluation/optimal_thresholds.tsv.gz
Normal file
BIN
data/03.model_evaluation/optimal_thresholds.tsv.gz
Normal file
Binary file not shown.
BIN
data/04.new_predictions/MolE_novel_abx.tsv.gz
Normal file
BIN
data/04.new_predictions/MolE_novel_abx.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
BIN
data/04.new_predictions/chemDesc_novel_abx.tsv.gz
Normal file
BIN
data/04.new_predictions/chemDesc_novel_abx.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
BIN
data/04.new_predictions/ecfp4_mce_predictions.xlsx
Normal file
BIN
data/04.new_predictions/ecfp4_mce_predictions.xlsx
Normal file
Binary file not shown.
BIN
data/04.new_predictions/ecfp4_mce_predictions_litsearch.xlsx
Normal file
BIN
data/04.new_predictions/ecfp4_mce_predictions_litsearch.xlsx
Normal file
Binary file not shown.
BIN
data/04.new_predictions/ecfp4_novel_abx.tsv.gz
Normal file
BIN
data/04.new_predictions/ecfp4_novel_abx.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
BIN
data/04.new_predictions/medchemexpress_filtered.tsv.gz
Normal file
BIN
data/04.new_predictions/medchemexpress_filtered.tsv.gz
Normal file
Binary file not shown.
BIN
data/04.new_predictions/mole_mce_predictions.xlsx
Normal file
BIN
data/04.new_predictions/mole_mce_predictions.xlsx
Normal file
Binary file not shown.
BIN
data/04.new_predictions/mole_mce_predictions_litsearch.xlsx
Normal file
BIN
data/04.new_predictions/mole_mce_predictions_litsearch.xlsx
Normal file
Binary file not shown.
BIN
data/04.new_predictions/novel_abx_smiles.tsv.gz
Normal file
BIN
data/04.new_predictions/novel_abx_smiles.tsv.gz
Normal file
Binary file not shown.
BIN
data/05.analyze_mce_predictions/ecfp4_mce_overview.pdf
Normal file
BIN
data/05.analyze_mce_predictions/ecfp4_mce_overview.pdf
Normal file
Binary file not shown.
BIN
data/05.analyze_mce_predictions/mce_pred_comparison.pdf
Normal file
BIN
data/05.analyze_mce_predictions/mce_pred_comparison.pdf
Normal file
Binary file not shown.
BIN
data/05.analyze_mce_predictions/mole_mce_overview.pdf
Normal file
BIN
data/05.analyze_mce_predictions/mole_mce_overview.pdf
Normal file
Binary file not shown.
BIN
data/06.experimental_validation/complete_od.tsv.gz
Normal file
BIN
data/06.experimental_validation/complete_od.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/06.experimental_validation/growth_parameters.tsv.gz
Normal file
BIN
data/06.experimental_validation/growth_parameters.tsv.gz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/06.experimental_validation/sa_ebastine_gparam_boxplots.pdf
Normal file
BIN
data/06.experimental_validation/sa_ebastine_gparam_boxplots.pdf
Normal file
Binary file not shown.
BIN
data/06.experimental_validation/sa_opicapone_gparam_boxplots.pdf
Normal file
BIN
data/06.experimental_validation/sa_opicapone_gparam_boxplots.pdf
Normal file
Binary file not shown.
BIN
data/06.experimental_validation/taxtree_circle.pdf
Normal file
BIN
data/06.experimental_validation/taxtree_circle.pdf
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/ecfp4_umap.tsv.gz
Normal file
BIN
data/07.pubchem_exploration/ecfp4_umap.tsv.gz
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/fps_dictionary.pkl
Normal file
BIN
data/07.pubchem_exploration/fps_dictionary.pkl
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/mole_umap.tsv.gz
Normal file
BIN
data/07.pubchem_exploration/mole_umap.tsv.gz
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/ranking_100000.tsv.gz
Normal file
BIN
data/07.pubchem_exploration/ranking_100000.tsv.gz
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/ranking_100003.tsv.gz
Normal file
BIN
data/07.pubchem_exploration/ranking_100003.tsv.gz
Normal file
Binary file not shown.
BIN
data/07.pubchem_exploration/ranking_28465.tsv.gz
Normal file
BIN
data/07.pubchem_exploration/ranking_28465.tsv.gz
Normal file
Binary file not shown.
1403
data/07.pubchem_exploration/similar_molecules_chem100000.svg
Normal file
1403
data/07.pubchem_exploration/similar_molecules_chem100000.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 390 KiB |
1419
data/07.pubchem_exploration/similar_molecules_chem100003.svg
Normal file
1419
data/07.pubchem_exploration/similar_molecules_chem100003.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 403 KiB |
1241
data/07.pubchem_exploration/similar_molecules_chem28465.svg
Normal file
1241
data/07.pubchem_exploration/similar_molecules_chem28465.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 364 KiB |
BIN
data/08.compare_mce_maier/mole_joint_umap.pdf
Normal file
BIN
data/08.compare_mce_maier/mole_joint_umap.pdf
Normal file
Binary file not shown.
BIN
data/08.compare_mce_maier/mole_joint_umap.tsv.gz
Normal file
BIN
data/08.compare_mce_maier/mole_joint_umap.tsv.gz
Normal file
Binary file not shown.
5743
data/08.compare_mce_maier/selected_prestwick_tanimoto.tsv
Normal file
5743
data/08.compare_mce_maier/selected_prestwick_tanimoto.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
data/08.compare_mce_maier/similarity_to_maier.png
Normal file
BIN
data/08.compare_mce_maier/similarity_to_maier.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 230 KiB |
3277
data/08.compare_mce_maier/umap_library_comparison.tsv
Normal file
3277
data/08.compare_mce_maier/umap_library_comparison.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
data/Figures Source Data.zip
Normal file
BIN
data/Figures Source Data.zip
Normal file
Binary file not shown.
BIN
data/Source Data.zip
Normal file
BIN
data/Source Data.zip
Normal file
Binary file not shown.
131
detailed_analysis.py
Normal file
131
detailed_analysis.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
详细分析广谱抗菌预测的计算过程
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats.mstats import gmean
|
||||
from broad_spectrum_api import ParallelBroadSpectrumPredictor, MoleculeInput
|
||||
|
||||
def analyze_prediction_process():
|
||||
"""详细分析预测过程"""
|
||||
print("=== 广谱抗菌预测详细分析 ===\n")
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
|
||||
# 测试分子
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
|
||||
print("1. 基本信息:")
|
||||
print(f" - 总菌株数量: {len(predictor.maier_screen.columns)}")
|
||||
print(f" - 革兰阳性菌: {(predictor.maier_strains['Gram stain'] == 'positive').sum()}")
|
||||
print(f" - 革兰阴性菌: {(predictor.maier_strains['Gram stain'] == 'negative').sum()}")
|
||||
print(f" - 抑制阈值: {predictor.config.app_threshold}")
|
||||
print(f" - 广谱标准: 抑制≥{predictor.config.min_nkill}个菌株")
|
||||
|
||||
# 获取MolE表示
|
||||
print("\n2. 获取分子表示...")
|
||||
mole_representation = predictor._get_mole_representation([molecule])
|
||||
print(f" - MolE特征维度: {mole_representation.shape}")
|
||||
|
||||
# 添加菌株信息
|
||||
print("\n3. 构建预测特征...")
|
||||
X_input = predictor._add_strains(mole_representation)
|
||||
print(f" - 预测样本数: {len(X_input)} (1个分子 × {len(predictor.maier_screen.columns)}个菌株)")
|
||||
print(f" - 特征维度: {X_input.shape[1]}")
|
||||
|
||||
# 进行预测
|
||||
print("\n4. 模型预测...")
|
||||
import pickle
|
||||
with open(predictor.config.xgboost_model_path, "rb") as file:
|
||||
model = pickle.load(file)
|
||||
|
||||
y_pred = model.predict_proba(X_input)
|
||||
pred_df = pd.DataFrame(y_pred, columns=["0", "1"], index=X_input.index)
|
||||
|
||||
# 显示预测概率统计
|
||||
print(f" - 抑制概率范围: {pred_df['1'].min():.6f} - {pred_df['1'].max():.6f}")
|
||||
print(f" - 抑制概率均值: {pred_df['1'].mean():.6f}")
|
||||
print(f" - 抑制概率中位数: {pred_df['1'].median():.6f}")
|
||||
|
||||
# 二值化预测
|
||||
pred_df["growth_inhibition"] = pred_df["1"].apply(
|
||||
lambda x: 1 if x >= predictor.config.app_threshold else 0
|
||||
)
|
||||
|
||||
inhibited_count = pred_df["growth_inhibition"].sum()
|
||||
print(f" - 超过阈值的菌株数: {inhibited_count}")
|
||||
|
||||
# 分析抗菌分数计算
|
||||
print("\n5. 抗菌分数计算:")
|
||||
|
||||
# 计算几何平均数
|
||||
geometric_mean = gmean(pred_df["1"])
|
||||
log_geometric_mean = np.log(geometric_mean)
|
||||
|
||||
print(f" - 所有概率的几何平均数: {geometric_mean:.10f}")
|
||||
print(f" - 几何平均数的对数: {log_geometric_mean:.6f}")
|
||||
|
||||
# 显示概率分布
|
||||
print(f"\n6. 概率分布分析:")
|
||||
print(f" - 概率 < 0.001: {(pred_df['1'] < 0.001).sum()} 个菌株")
|
||||
print(f" - 概率 0.001-0.01: {((pred_df['1'] >= 0.001) & (pred_df['1'] < 0.01)).sum()} 个菌株")
|
||||
print(f" - 概率 0.01-0.1: {((pred_df['1'] >= 0.01) & (pred_df['1'] < 0.1)).sum()} 个菌株")
|
||||
print(f" - 概率 ≥ 0.1: {(pred_df['1'] >= 0.1).sum()} 个菌株")
|
||||
|
||||
# 显示最高和最低的几个预测
|
||||
print(f"\n7. 预测详情 (前5高和前5低):")
|
||||
sorted_pred = pred_df.sort_values("1", ascending=False)
|
||||
|
||||
print(" 最高抑制概率:")
|
||||
for i, (idx, row) in enumerate(sorted_pred.head().iterrows()):
|
||||
strain_name = idx.split(":")[1]
|
||||
print(f" {i+1}. {strain_name}: {row['1']:.6f}")
|
||||
|
||||
print(" 最低抑制概率:")
|
||||
for i, (idx, row) in enumerate(sorted_pred.tail().iterrows()):
|
||||
strain_name = idx.split(":")[1]
|
||||
print(f" {i+1}. {strain_name}: {row['1']:.6f}")
|
||||
|
||||
# 完整预测结果
|
||||
print(f"\n8. 最终结果:")
|
||||
result = predictor.predict_single(molecule)
|
||||
print(f" - 总抗菌分数: {result.apscore_total:.6f}")
|
||||
print(f" - 革兰阴性菌分数: {result.apscore_gnegative:.6f}")
|
||||
print(f" - 革兰阳性菌分数: {result.apscore_gpositive:.6f}")
|
||||
print(f" - 抑制菌株总数: {result.ginhib_total}")
|
||||
print(f" - 抑制革兰阴性菌: {result.ginhib_gnegative}")
|
||||
print(f" - 抑制革兰阳性菌: {result.ginhib_gpositive}")
|
||||
print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||||
|
||||
def compare_different_molecules():
|
||||
"""比较不同分子的预测结果"""
|
||||
print("\n\n=== 不同分子对比分析 ===\n")
|
||||
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
|
||||
# 测试不同类型的分子
|
||||
molecules = [
|
||||
MoleculeInput(smiles="CCO", chem_id="ethanol"),
|
||||
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
|
||||
MoleculeInput(smiles="CCN", chem_id="ethylamine"),
|
||||
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
|
||||
MoleculeInput(smiles="CC(C)O", chem_id="isopropanol"),
|
||||
]
|
||||
|
||||
results = predictor.predict_batch(molecules)
|
||||
|
||||
print("分子对比结果:")
|
||||
print("-" * 80)
|
||||
print(f"{'分子':<15} {'SMILES':<12} {'抗菌分数':<10} {'抑制数':<8} {'广谱':<6}")
|
||||
print("-" * 80)
|
||||
|
||||
for result in results:
|
||||
mol_info = next(m for m in molecules if m.chem_id == result.chem_id)
|
||||
print(f"{result.chem_id:<15} {mol_info.smiles:<12} {result.apscore_total:<10.3f} "
|
||||
f"{result.ginhib_total:<8} {'是' if result.broad_spectrum else '否':<6}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
analyze_prediction_process()
|
||||
compare_different_molecules()
|
||||
120
example_usage.py
Normal file
120
example_usage.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
广谱抗菌预测API使用示例
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from broad_spectrum_api import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput,
|
||||
BroadSpectrumResult,
|
||||
predict_smiles,
|
||||
predict_file
|
||||
)
|
||||
|
||||
|
||||
def example_single_prediction():
|
||||
"""单分子预测示例"""
|
||||
print("=== 单分子预测示例 ===")
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
|
||||
# 预测单个分子
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
result = predictor.predict_single(molecule)
|
||||
|
||||
print(f"化合物: {result.chem_id}")
|
||||
print(f"广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||||
print(f"抑制菌株数: {result.ginhib_total}")
|
||||
print(f"抗菌分数: {result.apscore_total:.3f}")
|
||||
|
||||
|
||||
def example_batch_prediction():
|
||||
"""批量预测示例"""
|
||||
print("\n=== 批量预测示例 ===")
|
||||
|
||||
# 创建预测器
|
||||
config = PredictionConfig(n_workers=4, batch_size=50)
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 准备多个分子
|
||||
molecules = [
|
||||
MoleculeInput(smiles="CCO", chem_id="ethanol"),
|
||||
MoleculeInput(smiles="CCN", chem_id="ethylamine"),
|
||||
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
|
||||
]
|
||||
|
||||
# 批量预测
|
||||
results = predictor.predict_batch(molecules)
|
||||
|
||||
# 输出结果
|
||||
for result in results:
|
||||
print(f"{result.chem_id}: 广谱={result.broad_spectrum}, 抑制数={result.ginhib_total}")
|
||||
|
||||
|
||||
def example_smiles_list_prediction():
|
||||
"""SMILES列表预测示例"""
|
||||
print("\n=== SMILES列表预测示例 ===")
|
||||
|
||||
smiles_list = ["CCO", "CCN", "CC(=O)O"]
|
||||
chem_ids = ["ethanol", "ethylamine", "acetic_acid"]
|
||||
|
||||
# 使用便捷函数
|
||||
results = predict_smiles(smiles_list, chem_ids)
|
||||
|
||||
# 统计广谱抗菌化合物
|
||||
broad_spectrum_count = sum(1 for r in results if r.broad_spectrum)
|
||||
print(f"广谱抗菌化合物: {broad_spectrum_count}/{len(results)}")
|
||||
|
||||
|
||||
def example_file_prediction():
|
||||
"""文件预测示例"""
|
||||
print("\n=== 文件预测示例 ===")
|
||||
|
||||
# 假设有输入文件 molecules.tsv
|
||||
try:
|
||||
results = predict_file(
|
||||
"molecules.tsv",
|
||||
smiles_column="smiles",
|
||||
id_column="compound_id"
|
||||
)
|
||||
|
||||
# 保存结果
|
||||
import pandas as pd
|
||||
results_df = pd.DataFrame([r.to_dict() for r in results])
|
||||
results_df.to_csv("broad_spectrum_results.csv", index=False)
|
||||
print(f"预测完成,结果保存到 broad_spectrum_results.csv")
|
||||
|
||||
except FileNotFoundError:
|
||||
print("输入文件不存在,跳过文件预测示例")
|
||||
|
||||
|
||||
def example_custom_config():
|
||||
"""自定义配置示例"""
|
||||
print("\n=== 自定义配置示例 ===")
|
||||
|
||||
# 自定义配置
|
||||
config = PredictionConfig(
|
||||
app_threshold=0.1, # 更严格的抑制阈值
|
||||
min_nkill=15, # 更高的广谱标准
|
||||
n_workers=8, # 更多并行进程
|
||||
batch_size=200 # 更大的批次
|
||||
)
|
||||
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 预测
|
||||
molecules = [MoleculeInput(smiles="CCO", chem_id="ethanol")]
|
||||
results = predictor.predict_batch(molecules)
|
||||
|
||||
print(f"使用自定义配置预测结果: {results[0].to_dict()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行所有示例
|
||||
example_single_prediction()
|
||||
example_batch_prediction()
|
||||
example_smiles_list_prediction()
|
||||
example_file_prediction()
|
||||
example_custom_config()
|
||||
181
mole_representation.py
Normal file
181
mole_representation.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import os
|
||||
import yaml
|
||||
import argparse
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit import RDLogger
|
||||
|
||||
from workflow.dataset.dataset_representation import batch_representation
|
||||
from workflow.models.ginet_concat import GINet
|
||||
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
|
||||
# Function to read command line arguments
|
||||
def parse_arguments():
|
||||
|
||||
"""
|
||||
This function returns parsed command line arguments.
|
||||
|
||||
"""
|
||||
|
||||
# Instantiate parser
|
||||
parser = argparse.ArgumentParser(prog="Represent molecular structures as using MolE.",
|
||||
description="This program recieves a file with SMILES and represents them using the MolE representation.",
|
||||
usage="python mole_representation.py smiles_filepath output_filepath [options]",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
# Input SMILES
|
||||
parser.add_argument("smiles_filepath", help="Complete path to the smiles filepath. Expects a TSV file with a column containing SMILES strings.")
|
||||
|
||||
# Output filepath
|
||||
parser.add_argument("output_filepath", help="Complete path for the output.")
|
||||
|
||||
# Column name for smiles
|
||||
parser.add_argument("-c", "--smiles_colname", help="Column name in smiles_filepath that contains the SMILES.",
|
||||
default="smiles")
|
||||
|
||||
# Column name for id
|
||||
parser.add_argument("-i", "--chemid_colname", help="Column name in smiles_filepath that contains the ID string of each chemical.",
|
||||
default="chem_id")
|
||||
|
||||
# MolE model
|
||||
parser.add_argument("-m", "--mole_model", help="Path to the directory containing the config.yaml and model.pth files of the pre-trained MolE chemical representation.",
|
||||
default="pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001")
|
||||
|
||||
# Device
|
||||
parser.add_argument("-d", "--device", help="Device where the pre-trained model is loaded. Can be one of ['cpu', 'cuda', 'auto']. If 'auto' (default) then cuda:0 device is selected if a GPU is detected.",
|
||||
default="auto")
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine device for MolE model
|
||||
if args.device == "auto":
|
||||
args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
print(f"Using {args.device}")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
# A FUNCTION TO READ SMILES from file
|
||||
def read_smiles(data_path, smile_col="rdkit_no_salt", id_col="prestwick_ID"):
|
||||
|
||||
"""
|
||||
Read SMILES data from a file or DataFrame and remove invalid SMILES.
|
||||
|
||||
Parameters:
|
||||
- data_path (str or pd.DataFrame): Path to the file or a DataFrame containing SMILES data.
|
||||
- smile_col (str, optional): Name of the column containing SMILES strings.
|
||||
- id_col (str, optional): Name of the column containing molecule IDs.
|
||||
|
||||
Returns:
|
||||
- smile_df (pandas.DataFrame): DataFrame containing SMILES data with specified columns.
|
||||
"""
|
||||
|
||||
# Read the data
|
||||
if isinstance(data_path, pd.DataFrame):
|
||||
smile_df = data_path.copy()
|
||||
else:
|
||||
smile_df = pd.read_csv(data_path, sep='\t')
|
||||
smile_df = smile_df[[smile_col, id_col]]
|
||||
|
||||
# Make sure ID column is interpreted as str
|
||||
smile_df[id_col] = smile_df[id_col].astype(str)
|
||||
|
||||
# Remove NaN
|
||||
smile_df = smile_df.dropna()
|
||||
|
||||
# Remove invalid smiles
|
||||
smile_df = smile_df[smile_df[smile_col].apply(lambda x: Chem.MolFromSmiles(x) is not None)]
|
||||
|
||||
return smile_df
|
||||
|
||||
|
||||
# Function to load a pre-trained model
|
||||
def load_pretrained_model(pretrained_model_dir, device="cuda:0"):
|
||||
|
||||
"""
|
||||
Load a pre-trained MolE model.
|
||||
|
||||
Parameters:
|
||||
- pretrained_model_dir (str): Name of the pre-trained MolE model.
|
||||
- device (str, optional): Device for computation (default is "cuda:0").
|
||||
|
||||
Returns:
|
||||
- model: Loaded pre-trained model.
|
||||
"""
|
||||
|
||||
# Read model configuration
|
||||
config = yaml.load(open(os.path.join(pretrained_model_dir, "config.yaml"), "r"), Loader=yaml.FullLoader)
|
||||
model_config = config["model"]
|
||||
|
||||
# Instantiate model
|
||||
model = GINet(**model_config).to(device)
|
||||
|
||||
# Load pre-trained weights
|
||||
model_pth_path = os.path.join(pretrained_model_dir, "model.pth")
|
||||
print(model_pth_path)
|
||||
|
||||
state_dict = torch.load(model_pth_path, map_location=device)
|
||||
model.load_my_state_dict(state_dict)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def process_representation(dataset_path, smile_column_str, id_column_str, pretrained_dir, device):
|
||||
|
||||
"""
|
||||
Process the dataset to generate molecular representations.
|
||||
|
||||
Parameters:
|
||||
- dataset_path (str): Path to the dataset file.
|
||||
- pretrained_dir (str): Name of the pre-trained model.
|
||||
- smile_column_str (str, optional): Name of the column containing SMILES strings.
|
||||
- id_column_str (str, optional): Name of the column containing molecule IDs.
|
||||
- device (str): Device to use for computation (default is "cuda:0"). Can also be "cpu".
|
||||
|
||||
Returns:
|
||||
- udl_representation (pandas.DataFrame): DataFrame containing molecular representations if split_data=False.
|
||||
"""
|
||||
|
||||
# First we read the SMILES dataframe
|
||||
smiles_df = read_smiles(dataset_path, smile_col=smile_column_str, id_col=id_column_str)
|
||||
|
||||
# Load the pre-trained model
|
||||
pmodel = load_pretrained_model(pretrained_model_dir=pretrained_dir, device=device)
|
||||
|
||||
# Gather pre-trained representation
|
||||
udl_representation = batch_representation(smiles_df, pmodel, smile_column_str, id_column_str, device=device)
|
||||
|
||||
return udl_representation
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Parse arguments
|
||||
args = parse_arguments()
|
||||
|
||||
# Obtain MolE pre-trained representation
|
||||
mole_representation = process_representation(dataset_path = args.smiles_filepath,
|
||||
smile_column_str = args.smiles_colname,
|
||||
id_column_str = args.chemid_colname,
|
||||
|
||||
|
||||
pretrained_dir = args.mole_model,
|
||||
device=args.device)
|
||||
|
||||
# Write MolE representation to output
|
||||
mole_representation.to_csv(args.output_filepath, sep='\t')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
batch_size: 1000 # batch size
|
||||
warm_up: 10 # warm-up epochs
|
||||
epochs: 1000 # total number of epochs
|
||||
|
||||
load_model: None # resume training
|
||||
eval_every_n_epochs: 1 # validation frequency
|
||||
save_every_n_epochs: 5 # automatic model saving frequecy
|
||||
|
||||
fp16_precision: False # float precision 16 (i.e. True/False)
|
||||
init_lr: 0.0005 # initial learning rate for Adam
|
||||
weight_decay: 1e-5 # weight decay for Adam
|
||||
gpu: cuda:0 # training GPU
|
||||
|
||||
model_type: gin_concat # GNN backbone (i.e., gin/gcn)
|
||||
model:
|
||||
num_layer: 5 # number of graph conv layers
|
||||
emb_dim: 200 # embedding dimension in graph conv layers
|
||||
feat_dim: 8000 # output feature dimention
|
||||
drop_ratio: 0.0 # dropout ratio
|
||||
pool: add # readout pooling (i.e., mean/max/add)
|
||||
|
||||
dataset:
|
||||
num_workers: 50 # dataloader number of workers
|
||||
valid_size: 0.1 # ratio of validation data
|
||||
data_path: data/pubchem_data/pubchem_100k_random.txt # path of pre-training data
|
||||
|
||||
loss:
|
||||
l: 0.0001 # Lambda parameter
|
||||
BIN
raw_data/experimental_validation/lib_map.xlsx
Normal file
BIN
raw_data/experimental_validation/lib_map.xlsx
Normal file
Binary file not shown.
BIN
raw_data/experimental_validation/od_files/EC_UTI_1_OD.tsv.gz
Normal file
BIN
raw_data/experimental_validation/od_files/EC_UTI_1_OD.tsv.gz
Normal file
Binary file not shown.
BIN
raw_data/experimental_validation/od_files/EC_UTI_2_OD.tsv.gz
Normal file
BIN
raw_data/experimental_validation/od_files/EC_UTI_2_OD.tsv.gz
Normal file
Binary file not shown.
BIN
raw_data/experimental_validation/od_files/EC_UTI_3_OD.tsv.gz
Normal file
BIN
raw_data/experimental_validation/od_files/EC_UTI_3_OD.tsv.gz
Normal file
Binary file not shown.
BIN
raw_data/experimental_validation/od_files/EC_iAi_1_OD.tsv.gz
Normal file
BIN
raw_data/experimental_validation/od_files/EC_iAi_1_OD.tsv.gz
Normal file
Binary file not shown.
BIN
raw_data/experimental_validation/od_files/EC_iAi_2_OD.tsv.gz
Normal file
BIN
raw_data/experimental_validation/od_files/EC_iAi_2_OD.tsv.gz
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user