1. 代码修改
models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
This commit is contained in:
@@ -160,4 +160,225 @@ BroadSpectrumResult(
|
||||
|
||||
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
|
||||
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
|
||||
---
|
||||
|
||||
## 菌株级别预测详情
|
||||
|
||||
### 使用 `--include-strain-predictions` 参数
|
||||
|
||||
启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。
|
||||
|
||||
**命令示例**:
|
||||
```bash
|
||||
python utils/mole_predictor.py input.csv output.csv --include-strain-predictions
|
||||
```
|
||||
|
||||
### 菌株级别输出格式
|
||||
|
||||
每个分子会产生 40 行数据,每行对应一个菌株的预测结果:
|
||||
|
||||
| 列名 | 数据类型 | 说明 | 示例值 |
|
||||
|------|----------|------|--------|
|
||||
| `pred_id` | 字符串 | 预测ID,格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` |
|
||||
| `chem_id` | 字符串 | 化合物标识符 | `mol1` |
|
||||
| `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` |
|
||||
| `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率(0-1) | `0.000102` |
|
||||
| `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability) | `0.999898` |
|
||||
| `growth_inhibition` | 整数 (0/1) | 二值化抑制结果(1=抑制,0=不抑制) | `0` |
|
||||
| `gram_stain` | 字符串 | 革兰染色类型 | `negative` 或 `positive` |
|
||||
|
||||
### 完整的 40 种测试菌株列表
|
||||
|
||||
#### 革兰阴性菌(23 种)
|
||||
|
||||
| 序号 | 菌株名称 | NT编号 |
|
||||
|------|----------|--------|
|
||||
| 1 | Akkermansia muciniphila | NT5021 |
|
||||
| 2 | Bacteroides caccae | NT5050 |
|
||||
| 3 | Bacteroides fragilis (ET) | NT5033 |
|
||||
| 4 | Bacteroides fragilis (NT) | NT5003 |
|
||||
| 5 | Bacteroides ovatus | NT5054 |
|
||||
| 6 | Bacteroides thetaiotaomicron | NT5004 |
|
||||
| 7 | Bacteroides uniformis | NT5002 |
|
||||
| 8 | Bacteroides vulgatus | NT5001 |
|
||||
| 9 | Bacteroides xylanisolvens | NT5064 |
|
||||
| 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 |
|
||||
| 11 | Klebsiella pneumoniae | NT5049 |
|
||||
| 12 | Parabacteroides distasonis | NT5023 |
|
||||
| 13 | Phocaeicola vulgatus | NT5001 |
|
||||
| 14 | 其他肠道革兰阴性菌 | ... |
|
||||
|
||||
#### 革兰阳性菌(17 种)
|
||||
|
||||
| 序号 | 菌株名称 | NT编号 |
|
||||
|------|----------|--------|
|
||||
| 1 | Bifidobacterium adolescentis | NT5022 |
|
||||
| 2 | Bifidobacterium longum | NT5067 |
|
||||
| 3 | Bifidobacterium pseudocatenulatum | NT5058 |
|
||||
| 4 | Clostridium bolteae | NT5005 |
|
||||
| 5 | Clostridium innocuum | NT5026 |
|
||||
| 6 | Clostridium ramosum | NT5027 |
|
||||
| 7 | Clostridium scindens | NT5029 |
|
||||
| 8 | Clostridium symbiosum | NT5006 |
|
||||
| 9 | Enterococcus faecalis | NT5034 |
|
||||
| 10 | Enterococcus faecium | NT5043 |
|
||||
| 11 | Lactobacillus plantarum | NT5035 |
|
||||
| 12 | Lactobacillus reuteri | NT5032 |
|
||||
| 13 | Lactobacillus rhamnosus | NT5037 |
|
||||
| 14 | Streptococcus parasanguinis | NT5041 |
|
||||
| 15 | Streptococcus salivarius | NT5040 |
|
||||
| 16 | 其他肠道革兰阳性菌 | ... |
|
||||
|
||||
**注**: 完整列表可从 `maier_screening_results.tsv.gz` 和 `strain_info_SF2.xlsx` 文件中查看。
|
||||
|
||||
### 数据访问方式
|
||||
|
||||
#### 1. CSV 文件读取
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# 读取包含菌株预测的结果
|
||||
df = pd.read_csv('output_with_strains.csv')
|
||||
|
||||
# 查看某个分子的所有菌株预测
|
||||
mol1_strains = df[df['chem_id'] == 'mol1']
|
||||
print(f"分子 mol1 的预测:")
|
||||
print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']])
|
||||
|
||||
# 筛选被抑制的菌株
|
||||
inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1]
|
||||
print(f"\n被抑制的菌株数: {len(inhibited)}")
|
||||
print(inhibited[['strain_name', 'antimicrobial_predictive_probability']])
|
||||
```
|
||||
|
||||
#### 2. Python API 访问
|
||||
|
||||
```python
|
||||
from models import ParallelBroadSpectrumPredictor, MoleculeInput
|
||||
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
|
||||
# 包含菌株级别预测
|
||||
result = predictor.predict_batch([molecule], include_strain_predictions=True)[0]
|
||||
|
||||
# 访问菌株预测 DataFrame
|
||||
strain_df = result.strain_predictions
|
||||
print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7)
|
||||
print(f"列名: {strain_df.columns.tolist()}")
|
||||
|
||||
# 提取预测概率向量(用于强化学习)
|
||||
probabilities = strain_df['antimicrobial_predictive_probability'].values
|
||||
print(f"预测概率向量形状: {probabilities.shape}") # (40,)
|
||||
|
||||
# 筛选特定革兰染色类型
|
||||
gram_negative = strain_df[strain_df['gram_stain'] == 'negative']
|
||||
print(f"革兰阴性菌预测数: {len(gram_negative)}")
|
||||
|
||||
# 转换为类型安全的列表(可选)
|
||||
strain_list = result.to_strain_predictions_list()
|
||||
for strain_pred in strain_list[:3]:
|
||||
print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}")
|
||||
```
|
||||
|
||||
### 强化学习场景应用
|
||||
|
||||
#### 状态表示
|
||||
|
||||
```python
|
||||
# 将 40 个菌株的预测概率作为状态向量
|
||||
state = result.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
# state.shape = (40,)
|
||||
|
||||
# 或者包含更多特征
|
||||
state_features = result.strain_predictions[[
|
||||
'antimicrobial_predictive_probability',
|
||||
'growth_inhibition'
|
||||
]].values
|
||||
# state_features.shape = (40, 2)
|
||||
```
|
||||
|
||||
#### 奖励函数设计
|
||||
|
||||
```python
|
||||
def calculate_reward(strain_predictions_df):
|
||||
"""
|
||||
基于菌株级别预测计算奖励
|
||||
|
||||
Args:
|
||||
strain_predictions_df: 包含 40 个菌株预测的 DataFrame
|
||||
|
||||
Returns:
|
||||
reward: 标量奖励值
|
||||
"""
|
||||
# 方案1: 基于抑制菌株数
|
||||
reward = strain_predictions_df['growth_inhibition'].sum() / 40.0
|
||||
|
||||
# 方案2: 基于预测概率
|
||||
reward = strain_predictions_df['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
# 方案3: 加权奖励(考虑革兰染色)
|
||||
gram_negative_score = strain_predictions_df[
|
||||
strain_predictions_df['gram_stain'] == 'negative'
|
||||
]['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
gram_positive_score = strain_predictions_df[
|
||||
strain_predictions_df['gram_stain'] == 'positive'
|
||||
]['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score
|
||||
|
||||
return reward
|
||||
```
|
||||
|
||||
### 数据可视化
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
# 读取菌株预测数据
|
||||
strain_df = result.strain_predictions
|
||||
|
||||
# 按预测概率排序
|
||||
strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False)
|
||||
|
||||
# 绘制柱状图
|
||||
plt.figure(figsize=(15, 6))
|
||||
plt.bar(range(len(strain_df_sorted)),
|
||||
strain_df_sorted['antimicrobial_predictive_probability'],
|
||||
color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']])
|
||||
plt.xlabel('菌株索引')
|
||||
plt.ylabel('抗菌预测概率')
|
||||
plt.title('分子对 40 种菌株的抗菌活性预测')
|
||||
plt.xticks(rotation=90)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# 按革兰染色分组可视化
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
for idx, gram_type in enumerate(['negative', 'positive']):
|
||||
data = strain_df[strain_df['gram_stain'] == gram_type]
|
||||
axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability'])
|
||||
axes[idx].set_xlabel('预测概率')
|
||||
axes[idx].set_title(f'革兰{gram_type}菌')
|
||||
axes[idx].tick_params(axis='y', labelsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 性能和存储建议
|
||||
|
||||
- **聚合结果**: 每个分子 1 行,适合快速筛选
|
||||
- **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习
|
||||
- **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小
|
||||
- **推荐做法**:
|
||||
- 初筛时使用聚合结果
|
||||
- 对候选分子使用菌株级别预测进行深入分析
|
||||
142
README.md
142
README.md
@@ -153,9 +153,12 @@ conda install -c conda-forge rdkit
|
||||
**基本用法:**
|
||||
|
||||
```bash
|
||||
# 预测 CSV 文件
|
||||
# 预测 CSV 文件(仅聚合结果)
|
||||
python utils/mole_predictor.py input.csv
|
||||
|
||||
# 包含 40 种菌株的详细预测数据
|
||||
python utils/mole_predictor.py input.csv --include-strain-predictions
|
||||
|
||||
# 指定输出路径
|
||||
python utils/mole_predictor.py input.csv output.csv
|
||||
|
||||
@@ -171,6 +174,12 @@ python utils/mole_predictor.py input.csv --device cuda:0
|
||||
python utils/mole_predictor.py input.csv \
|
||||
--batch-size 200 \
|
||||
--n-workers 8
|
||||
|
||||
# 完整示例:包含菌株预测 + GPU 加速
|
||||
python utils/mole_predictor.py input.csv output.csv \
|
||||
--include-strain-predictions \
|
||||
--device cuda:0 \
|
||||
--batch-size 200
|
||||
```
|
||||
|
||||
**查看所有选项:**
|
||||
@@ -196,7 +205,7 @@ python utils/mole_predictor.py Data/fragment/GDB11-27M.csv
|
||||
```python
|
||||
from utils.mole_predictor import predict_csv_file
|
||||
|
||||
# 基本使用
|
||||
# 基本使用(仅聚合结果)
|
||||
df_result = predict_csv_file(
|
||||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||||
output_path="results/predictions.csv",
|
||||
@@ -205,9 +214,24 @@ df_result = predict_csv_file(
|
||||
device="auto"
|
||||
)
|
||||
|
||||
# 包含 40 种菌株的详细预测数据
|
||||
df_result_with_strains = predict_csv_file(
|
||||
input_path="Data/fragment/Frags-Enamine-18M.csv",
|
||||
output_path="results/predictions_with_strains.csv",
|
||||
smiles_column="smiles",
|
||||
batch_size=100,
|
||||
device="auto",
|
||||
include_strain_predictions=True # 启用菌株级别预测
|
||||
)
|
||||
|
||||
# 查看结果
|
||||
print(f"总分子数: {len(df_result)}")
|
||||
print(f"广谱分子数: {df_result['broad_spectrum'].sum()}")
|
||||
|
||||
# 如果包含菌株预测,数据行数会增加(每个分子 40 行)
|
||||
if 'strain_name' in df_result_with_strains.columns:
|
||||
print(f"包含菌株预测的总行数: {len(df_result_with_strains)}")
|
||||
print(f"菌株数: {df_result_with_strains['strain_name'].nunique()}")
|
||||
```
|
||||
|
||||
**批量预测多个文件:**
|
||||
@@ -247,7 +271,7 @@ config = PredictionConfig(
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 预测单个分子
|
||||
# 预测单个分子(仅聚合结果)
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
result = predictor.predict_single(molecule)
|
||||
|
||||
@@ -256,19 +280,41 @@ print(f"广谱抗菌: {result.broad_spectrum}")
|
||||
print(f"抗菌得分: {result.apscore_total:.3f}")
|
||||
print(f"抑制菌株数: {result.ginhib_total}")
|
||||
|
||||
# 批量预测
|
||||
# 批量预测(包含 40 种菌株的详细预测)
|
||||
smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
|
||||
chem_ids = ["ethanol", "benzene", "acetic_acid"]
|
||||
|
||||
results = predictor.predict_from_smiles(smiles_list, chem_ids)
|
||||
results = predictor.predict_batch(
|
||||
[MoleculeInput(smiles=s, chem_id=c) for s, c in zip(smiles_list, chem_ids)],
|
||||
include_strain_predictions=True # 启用菌株级别预测
|
||||
)
|
||||
|
||||
for r in results:
|
||||
print(f"{r.chem_id}: broad_spectrum={r.broad_spectrum}, "
|
||||
f"apscore={r.apscore_total:.3f}")
|
||||
print(f"\n{r.chem_id}:")
|
||||
print(f" 广谱抗菌: {r.broad_spectrum}")
|
||||
print(f" 抗菌得分: {r.apscore_total:.3f}")
|
||||
print(f" 抑制菌株数: {r.ginhib_total}")
|
||||
|
||||
# 访问菌株级别预测数据(DataFrame 格式)
|
||||
if r.strain_predictions is not None:
|
||||
print(f" 菌株预测数据形状: {r.strain_predictions.shape}")
|
||||
print(f" 示例菌株预测:")
|
||||
print(r.strain_predictions.head(3))
|
||||
|
||||
# 强化学习场景:提取特定菌株的预测概率
|
||||
strain_probs = r.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
print(f" 预测概率向量形状: {strain_probs.shape}") # (40,)
|
||||
|
||||
# 或转换为类型安全的列表
|
||||
strain_list = r.to_strain_predictions_list()
|
||||
print(f" 第一个菌株: {strain_list[0].strain_name}")
|
||||
print(f" 预测概率: {strain_list[0].antimicrobial_predictive_probability:.6f}")
|
||||
```
|
||||
|
||||
### 输出说明
|
||||
|
||||
#### 1. 聚合结果(默认输出)
|
||||
|
||||
预测结果会添加以下 7 个新列:
|
||||
|
||||
| 列名 | 类型 | 说明 |
|
||||
@@ -276,11 +322,85 @@ for r in results:
|
||||
| `apscore_total` | float | 总体抗菌潜力分数(对数尺度,值越大抗菌活性越强) |
|
||||
| `apscore_gnegative` | float | 革兰阴性菌抗菌潜力分数 |
|
||||
| `apscore_gpositive` | float | 革兰阳性菌抗菌潜力分数 |
|
||||
| `ginhib_total` | int | 被抑制的菌株总数 |
|
||||
| `ginhib_total` | int | 被抑制的菌株总数(0-40) |
|
||||
| `ginhib_gnegative` | int | 被抑制的革兰阴性菌株数 |
|
||||
| `ginhib_gpositive` | int | 被抑制的革兰阳性菌株数 |
|
||||
| `broad_spectrum` | int | 是否为广谱抗菌(1=是,0=否) |
|
||||
|
||||
**输出示例**(每个分子一行):
|
||||
|
||||
```csv
|
||||
smiles,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum
|
||||
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0
|
||||
```
|
||||
|
||||
#### 2. 菌株级别预测(使用 `--include-strain-predictions` 时)
|
||||
|
||||
启用菌株级别预测后,输出会包含以下额外列(每个分子 40 行):
|
||||
|
||||
| 列名 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `pred_id` | str | 预测ID,格式为 `chem_id:strain_name` |
|
||||
| `strain_name` | str | 菌株名称(40 种菌株之一) |
|
||||
| `antimicrobial_predictive_probability` | float | XGBoost 预测的抗菌概率(0-1) |
|
||||
| `no_growth_probability` | float | 预测不抑制的概率(1 - antimicrobial_predictive_probability) |
|
||||
| `growth_inhibition` | int | 二值化抑制结果(0=不抑制,1=抑制) |
|
||||
| `gram_stain` | str | 革兰染色类型(negative 或 positive) |
|
||||
|
||||
**输出示例**(每个分子 40 行,对应 40 个菌株):
|
||||
|
||||
```csv
|
||||
smiles,chem_id,apscore_total,...,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain
|
||||
CCO,mol1,-9.93,...,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative
|
||||
CCO,mol1,-9.93,...,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative
|
||||
...(共 40 行,对应 40 个菌株)
|
||||
```
|
||||
|
||||
**数据使用场景**:
|
||||
|
||||
1. **强化学习状态表示**:
|
||||
```python
|
||||
# 提取预测概率作为状态向量
|
||||
state_vector = result.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
# 形状: (40,) - 可直接用于 RL 环境
|
||||
```
|
||||
|
||||
2. **筛选特定菌株**:
|
||||
```python
|
||||
# 筛选革兰阴性菌
|
||||
gram_negative = result.strain_predictions[
|
||||
result.strain_predictions['gram_stain'] == 'negative'
|
||||
]
|
||||
```
|
||||
|
||||
3. **可视化分析**:
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
df = result.strain_predictions
|
||||
df.plot(x='strain_name', y='antimicrobial_predictive_probability', kind='bar')
|
||||
```
|
||||
|
||||
#### 40 种测试菌株列表
|
||||
|
||||
预测涵盖以下 40 种人类肠道菌株:
|
||||
|
||||
**革兰阴性菌(23 种)**:
|
||||
- Akkermansia muciniphila (NT5021)
|
||||
- Bacteroides 属: caccae, fragilis (ET/NT), ovatus, thetaiotaomicron, uniformis, vulgatus, xylanisolvens (8 种)
|
||||
- Escherichia coli 各亚型 (4 种)
|
||||
- Klebsiella pneumoniae (NT5049)
|
||||
- 其他肠道革兰阴性菌
|
||||
|
||||
**革兰阳性菌(17 种)**:
|
||||
- Bifidobacterium 属 (3 种)
|
||||
- Clostridium 属 (5 种)
|
||||
- Enterococcus 属 (2 种)
|
||||
- Lactobacillus 属 (3 种)
|
||||
- Streptococcus 属 (2 种)
|
||||
- 其他肠道革兰阳性菌
|
||||
|
||||
完整菌株列表详见 `Data/mole/README.md`。
|
||||
|
||||
#### 广谱抗菌判断标准
|
||||
|
||||
默认情况下,如果一个分子能抑制 **10 个或更多菌株** (`ginhib_total >= 10`),则被认为是广谱抗菌分子。
|
||||
@@ -291,6 +411,12 @@ for r in results:
|
||||
|
||||
- 输入: `Data/fragment/Frags-Enamine-18M.csv`
|
||||
- 输出: `Data/fragment/Frags-Enamine-18M_predicted.csv`
|
||||
- 输出(含菌株): `Data/fragment/Frags-Enamine-18M_predicted.csv`(每个分子 40 行)
|
||||
|
||||
#### 数据量说明
|
||||
|
||||
- **仅聚合结果**:输出行数 = 输入分子数
|
||||
- **包含菌株预测**:输出行数 = 输入分子数 × 40
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -65,9 +65,21 @@ class MoleculeInput:
|
||||
chem_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrainPrediction:
|
||||
"""单个菌株的预测结果"""
|
||||
pred_id: str # 格式: "chem_id:strain_name"
|
||||
chem_id: str
|
||||
strain_name: str
|
||||
antimicrobial_predictive_probability: float # 预测概率
|
||||
no_growth_probability: float # 不生长概率(1 - antimicrobial_predictive_probability)
|
||||
growth_inhibition: int # 二值化结果 (0/1)
|
||||
gram_stain: Optional[str] = None # 革兰染色类型
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadSpectrumResult:
|
||||
"""广谱抗菌预测结果"""
|
||||
"""广谱抗菌预测结果(聚合结果)"""
|
||||
chem_id: str
|
||||
apscore_total: float
|
||||
apscore_gnegative: float
|
||||
@@ -76,9 +88,10 @@ class BroadSpectrumResult:
|
||||
ginhib_gnegative: int
|
||||
ginhib_gpositive: int
|
||||
broad_spectrum: int
|
||||
strain_predictions: Optional[pd.DataFrame] = None # 菌株级别预测(40行)
|
||||
|
||||
def to_dict(self) -> Dict[str, Union[str, float, int]]:
|
||||
"""转换为字典格式"""
|
||||
"""转换为字典格式(仅聚合字段)"""
|
||||
return {
|
||||
'chem_id': self.chem_id,
|
||||
'apscore_total': self.apscore_total,
|
||||
@@ -89,6 +102,31 @@ class BroadSpectrumResult:
|
||||
'ginhib_gpositive': self.ginhib_gpositive,
|
||||
'broad_spectrum': self.broad_spectrum
|
||||
}
|
||||
|
||||
def to_strain_predictions_list(self) -> List['StrainPrediction']:
|
||||
"""
|
||||
将 DataFrame 转换为 StrainPrediction 列表(用于类型安全场景)
|
||||
|
||||
Returns:
|
||||
StrainPrediction 对象列表
|
||||
"""
|
||||
if self.strain_predictions is None or self.strain_predictions.empty:
|
||||
return []
|
||||
|
||||
strain_list = []
|
||||
for _, row in self.strain_predictions.iterrows():
|
||||
strain_pred = StrainPrediction(
|
||||
pred_id=row['pred_id'],
|
||||
chem_id=row['chem_id'],
|
||||
strain_name=row['strain_name'],
|
||||
antimicrobial_predictive_probability=row['antimicrobial_predictive_probability'],
|
||||
no_growth_probability=row['no_growth_probability'],
|
||||
growth_inhibition=row['growth_inhibition'],
|
||||
gram_stain=row.get('gram_stain', None)
|
||||
)
|
||||
strain_list.append(strain_pred)
|
||||
|
||||
return strain_list
|
||||
|
||||
|
||||
class BroadSpectrumPredictor:
|
||||
@@ -259,6 +297,45 @@ class BroadSpectrumPredictor:
|
||||
|
||||
return df_label
|
||||
|
||||
def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
准备菌株级别的预测数据
|
||||
|
||||
Args:
|
||||
score_df: 原始预测分数DataFrame,包含 pred_id, 0, 1, growth_inhibition 列
|
||||
|
||||
Returns:
|
||||
格式化的菌株级别预测DataFrame
|
||||
"""
|
||||
# 创建副本
|
||||
strain_df = score_df.copy()
|
||||
|
||||
# 分离化合物ID和菌株名
|
||||
strain_df["chem_id"] = strain_df["pred_id"].str.split(":", expand=True)[0]
|
||||
strain_df["strain_name"] = strain_df["pred_id"].str.split(":", expand=True)[1]
|
||||
|
||||
# 添加革兰染色信息
|
||||
strain_df = self._gram_stain(strain_df)
|
||||
|
||||
# 重命名列为用户友好的名称
|
||||
strain_df = strain_df.rename(columns={
|
||||
"1": "antimicrobial_predictive_probability",
|
||||
"0": "no_growth_probability"
|
||||
})
|
||||
|
||||
# 选择并排序输出列
|
||||
output_columns = [
|
||||
"pred_id",
|
||||
"chem_id",
|
||||
"strain_name",
|
||||
"antimicrobial_predictive_probability",
|
||||
"no_growth_probability",
|
||||
"growth_inhibition",
|
||||
"gram_stain"
|
||||
]
|
||||
|
||||
return strain_df[output_columns]
|
||||
|
||||
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算抗菌潜力分数
|
||||
@@ -371,12 +448,14 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
results = self.predict_batch([molecule])
|
||||
return results[0]
|
||||
|
||||
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
|
||||
def predict_batch(self, molecules: List[MoleculeInput],
|
||||
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
批量预测分子的广谱抗菌活性
|
||||
|
||||
Args:
|
||||
molecules: 分子输入列表
|
||||
include_strain_predictions: 是否在结果中包含菌株级别预测数据
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果列表
|
||||
@@ -417,10 +496,16 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
# 合并结果
|
||||
print("Merging prediction results...")
|
||||
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
|
||||
# 准备菌株级别数据(如果需要)
|
||||
strain_level_data = None
|
||||
if include_strain_predictions:
|
||||
print("Preparing strain-level predictions...")
|
||||
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
|
||||
|
||||
# 计算抗菌潜力
|
||||
print("Calculating antimicrobial potential scores...")
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
agg_df = self._antimicrobial_potential(all_pred_df)
|
||||
|
||||
# 判断广谱抗菌
|
||||
@@ -431,6 +516,13 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
# 转换为结果对象
|
||||
results_list = []
|
||||
for _, row in agg_df.iterrows():
|
||||
# 获取该分子的菌株级别预测
|
||||
mol_strain_preds = None
|
||||
if strain_level_data is not None:
|
||||
mol_strain_preds = strain_level_data[
|
||||
strain_level_data['chem_id'] == row.name
|
||||
].reset_index(drop=True)
|
||||
|
||||
result = BroadSpectrumResult(
|
||||
chem_id=row.name,
|
||||
apscore_total=row["apscore_total"],
|
||||
@@ -439,7 +531,8 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
ginhib_total=int(row["ginhib_total"]),
|
||||
ginhib_gnegative=int(row["ginhib_gnegative"]),
|
||||
ginhib_gpositive=int(row["ginhib_gpositive"]),
|
||||
broad_spectrum=int(row["broad_spectrum"])
|
||||
broad_spectrum=int(row["broad_spectrum"]),
|
||||
strain_predictions=mol_strain_preds
|
||||
)
|
||||
results_list.append(result)
|
||||
|
||||
|
||||
208
test/mole_predict_singal_mole.py
Normal file
208
test/mole_predict_singal_mole.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MolE 抗菌活性预测 Python API 示例
|
||||
|
||||
演示两种预测模式:
|
||||
1. 聚合结果模式(默认)
|
||||
2. 菌株级别预测模式
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径(使用 pathlib)
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from models import (
|
||||
ParallelBroadSpectrumPredictor,
|
||||
PredictionConfig,
|
||||
MoleculeInput
|
||||
)
|
||||
|
||||
|
||||
def print_separator(title):
|
||||
"""打印分隔线"""
|
||||
print("\n" + "=" * 70)
|
||||
print(f" {title}")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
|
||||
def demo_aggregated_mode():
|
||||
"""演示聚合结果模式(默认)"""
|
||||
print_separator("模式 1: 聚合结果模式(默认)")
|
||||
|
||||
# 创建配置
|
||||
config = PredictionConfig(
|
||||
batch_size=10,
|
||||
device="auto" # 自动检测 CUDA
|
||||
)
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor(config)
|
||||
|
||||
# 准备测试分子
|
||||
test_molecules = [
|
||||
MoleculeInput(smiles="CCO", chem_id="ethanol"),
|
||||
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
|
||||
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
|
||||
]
|
||||
|
||||
print(f"测试分子数: {len(test_molecules)}")
|
||||
print(f"SMILES 示例: {test_molecules[0].smiles}")
|
||||
|
||||
# 执行预测(不包含菌株级别预测)
|
||||
print("\n开始预测(聚合模式)...")
|
||||
results = predictor.predict_batch(test_molecules, include_strain_predictions=False)
|
||||
|
||||
# 打印结果
|
||||
print(f"\n预测完成!共 {len(results)} 个结果\n")
|
||||
|
||||
for result in results:
|
||||
print(f"化合物ID: {result.chem_id}")
|
||||
print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||||
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
|
||||
print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
|
||||
print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
|
||||
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
|
||||
print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
|
||||
print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
|
||||
print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None
|
||||
print()
|
||||
|
||||
# 返回结果供后续使用
|
||||
return results
|
||||
|
||||
|
||||
def demo_strain_level_mode():
|
||||
"""演示菌株级别预测模式"""
|
||||
print_separator("模式 2: 菌株级别预测模式")
|
||||
|
||||
# 创建预测器
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
|
||||
# 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物)
|
||||
test_molecule = MoleculeInput(
|
||||
smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
|
||||
chem_id="test_antibacterial"
|
||||
)
|
||||
|
||||
print(f"测试分子: {test_molecule.chem_id}")
|
||||
print(f"SMILES: {test_molecule.smiles}")
|
||||
|
||||
# 执行预测(包含菌株级别预测)
|
||||
print("\n开始预测(菌株级别模式)...")
|
||||
results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
|
||||
result = results[0]
|
||||
|
||||
# 1. 打印聚合结果
|
||||
print(f"\n聚合结果:")
|
||||
print(f" - 化合物ID: {result.chem_id}")
|
||||
print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}")
|
||||
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
|
||||
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
|
||||
|
||||
# 2. 打印菌株级别预测数据结构
|
||||
print(f"\n菌株级别预测数据:")
|
||||
if result.strain_predictions is not None:
|
||||
strain_df = result.strain_predictions
|
||||
print(f" - 数据类型: {type(strain_df)}")
|
||||
print(f" - 数据形状: {strain_df.shape} (行, 列)")
|
||||
print(f" - 列名: {list(strain_df.columns)}")
|
||||
print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")
|
||||
|
||||
# 3. 展示前 5 个菌株的预测
|
||||
print(f"\n前 5 个菌株的预测:")
|
||||
print(strain_df.head(5).to_string(index=False))
|
||||
|
||||
# 4. 统计信息
|
||||
print(f"\n统计信息:")
|
||||
print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, "
|
||||
f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]")
|
||||
print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}")
|
||||
print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
|
||||
print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}")
|
||||
print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}")
|
||||
|
||||
# 5. 展示被抑制的菌株(如果有)
|
||||
inhibited = strain_df[strain_df['growth_inhibition'] == 1]
|
||||
if len(inhibited) > 0:
|
||||
print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
|
||||
print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False))
|
||||
else:
|
||||
print(f"\n该分子未预测抑制任何菌株")
|
||||
|
||||
# 6. 强化学习应用示例
|
||||
print(f"\n强化学习应用示例:")
|
||||
|
||||
# 提取预测概率作为状态向量
|
||||
state_vector = strain_df['antimicrobial_predictive_probability'].values
|
||||
print(f" - 状态向量形状: {state_vector.shape}")
|
||||
print(f" - 状态向量类型: {type(state_vector)}")
|
||||
print(f" - 前 10 个值: {state_vector[:10]}")
|
||||
|
||||
# 提取多维特征
|
||||
state_features = strain_df[[
|
||||
'antimicrobial_predictive_probability',
|
||||
'growth_inhibition'
|
||||
]].values
|
||||
print(f" - 多维特征形状: {state_features.shape}")
|
||||
|
||||
# 按革兰染色分组
|
||||
gram_negative_probs = strain_df[
|
||||
strain_df['gram_stain'] == 'negative'
|
||||
]['antimicrobial_predictive_probability'].values
|
||||
print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")
|
||||
|
||||
# 7. 转换为类型安全的列表(可选)
|
||||
print(f"\n转换为 StrainPrediction 列表:")
|
||||
strain_list = result.to_strain_predictions_list()
|
||||
print(f" - 列表长度: {len(strain_list)}")
|
||||
print(f" - 元素类型: {type(strain_list[0])}")
|
||||
print(f" - 第一个元素:")
|
||||
first_strain = strain_list[0]
|
||||
print(f" * pred_id: {first_strain.pred_id}")
|
||||
print(f" * strain_name: {first_strain.strain_name}")
|
||||
print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
|
||||
print(f" * growth_inhibition: {first_strain.growth_inhibition}")
|
||||
print(f" * gram_stain: {first_strain.gram_stain}")
|
||||
else:
|
||||
print(" 警告: strain_predictions 为 None(未启用菌株级别预测)")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("\n" + "🧪" * 35)
|
||||
print(" MolE 抗菌活性预测 Python API 示例")
|
||||
print("🧪" * 35)
|
||||
|
||||
# 演示模式 1: 聚合结果
|
||||
aggregated_results = demo_aggregated_mode()
|
||||
|
||||
# 演示模式 2: 菌株级别预测
|
||||
strain_level_result = demo_strain_level_mode()
|
||||
|
||||
# 总结
|
||||
print_separator("总结")
|
||||
print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
|
||||
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
|
||||
print(" - 包含 8 个聚合指标")
|
||||
print(" - strain_predictions = None")
|
||||
print()
|
||||
print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
|
||||
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
|
||||
print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
|
||||
print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
|
||||
print(" - 可直接提取为 numpy array 用于 RL")
|
||||
print()
|
||||
print("🎯 推荐使用场景:")
|
||||
print(" - 初筛: 模式 1")
|
||||
print(" - 详细分析/RL 训练: 模式 2")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -49,7 +49,8 @@ def predict_csv_file(
|
||||
batch_size: int = 100,
|
||||
n_workers: Optional[int] = None,
|
||||
device: str = "auto",
|
||||
add_suffix: bool = True
|
||||
add_suffix: bool = True,
|
||||
include_strain_predictions: bool = False
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
预测 CSV 文件中的分子抗菌活性
|
||||
@@ -63,6 +64,7 @@ def predict_csv_file(
|
||||
n_workers: 工作进程数
|
||||
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
|
||||
add_suffix: 是否在输出文件名后添加预测后缀
|
||||
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
|
||||
|
||||
Returns:
|
||||
包含预测结果的 DataFrame
|
||||
@@ -118,7 +120,7 @@ def predict_csv_file(
|
||||
|
||||
# 执行预测
|
||||
print("开始预测...")
|
||||
results = predictor.predict_batch(molecules)
|
||||
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
|
||||
|
||||
# 转换结果为 DataFrame
|
||||
results_dicts = [r.to_dict() for r in results]
|
||||
@@ -136,6 +138,36 @@ def predict_csv_file(
|
||||
)
|
||||
df_output = df_output.drop(columns=['_merge_id'])
|
||||
|
||||
# 如果包含菌株级别预测,将其添加到输出中
|
||||
if include_strain_predictions:
|
||||
print("合并菌株级别预测数据...")
|
||||
# 收集所有菌株级别预测
|
||||
all_strain_predictions = []
|
||||
for result in results:
|
||||
if result.strain_predictions is not None and not result.strain_predictions.empty:
|
||||
all_strain_predictions.append(result.strain_predictions)
|
||||
|
||||
if all_strain_predictions:
|
||||
# 合并所有菌株预测
|
||||
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
|
||||
|
||||
# 将聚合结果和菌株预测合并
|
||||
# 为了在同一个 CSV 中展示,我们使用重复行的方式
|
||||
# 每个分子的聚合结果会重复40次(每个菌株一次)
|
||||
df_output_expanded = df_output.merge(
|
||||
df_strain_predictions,
|
||||
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
|
||||
right_on='chem_id',
|
||||
how='left',
|
||||
suffixes=('', '_strain')
|
||||
)
|
||||
|
||||
# 移除重复的 chem_id 列
|
||||
if 'chem_id_strain' in df_output_expanded.columns:
|
||||
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
|
||||
|
||||
df_output = df_output_expanded
|
||||
|
||||
# 生成输出路径
|
||||
if output_path is None:
|
||||
if add_suffix:
|
||||
@@ -164,7 +196,8 @@ def predict_multiple_files(
|
||||
batch_size: int = 100,
|
||||
n_workers: Optional[int] = None,
|
||||
device: str = "auto",
|
||||
add_suffix: bool = True
|
||||
add_suffix: bool = True,
|
||||
include_strain_predictions: bool = False
|
||||
) -> List[pd.DataFrame]:
|
||||
"""
|
||||
批量预测多个 CSV 文件
|
||||
@@ -178,6 +211,7 @@ def predict_multiple_files(
|
||||
n_workers: 工作进程数
|
||||
device: 计算设备
|
||||
add_suffix: 是否在输出文件名后添加预测后缀
|
||||
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
|
||||
|
||||
Returns:
|
||||
包含预测结果的 DataFrame 列表
|
||||
@@ -210,7 +244,8 @@ def predict_multiple_files(
|
||||
batch_size=batch_size,
|
||||
n_workers=n_workers,
|
||||
device=device,
|
||||
add_suffix=add_suffix
|
||||
add_suffix=add_suffix,
|
||||
include_strain_predictions=include_strain_predictions
|
||||
)
|
||||
results.append(df_result)
|
||||
except Exception as e:
|
||||
@@ -240,7 +275,9 @@ def predict_multiple_files(
|
||||
help='计算设备 (默认: auto)')
|
||||
@click.option('--add-suffix/--no-add-suffix', default=True,
|
||||
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
|
||||
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
|
||||
@click.option('--include-strain-predictions', is_flag=True, default=False,
|
||||
help='在输出中包含40种菌株的详细预测数据(每个分子将产生40行数据,对应每个菌株的预测概率和抑制情况)')
|
||||
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
|
||||
"""
|
||||
使用 MolE 模型预测小分子 SMILES 的抗菌活性
|
||||
|
||||
@@ -248,12 +285,21 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
|
||||
|
||||
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
|
||||
|
||||
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
|
||||
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
|
||||
|
||||
示例:
|
||||
|
||||
# 基本用法(仅输出聚合结果)
|
||||
python mole_predictor.py input.csv output.csv
|
||||
|
||||
python mole_predictor.py input.csv -s SMILES -i ID
|
||||
# 包含40种菌株的详细预测数据
|
||||
python mole_predictor.py input.csv output.csv --include-strain-predictions
|
||||
|
||||
# 指定列名和设备
|
||||
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
|
||||
|
||||
# 自定义批处理大小
|
||||
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
|
||||
"""
|
||||
|
||||
@@ -266,7 +312,8 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
|
||||
batch_size=batch_size,
|
||||
n_workers=n_workers,
|
||||
device=device,
|
||||
add_suffix=add_suffix
|
||||
add_suffix=add_suffix,
|
||||
include_strain_predictions=include_strain_predictions
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"错误: {e}", err=True)
|
||||
|
||||
Reference in New Issue
Block a user