1. 代码修改
models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
This commit is contained in:
@@ -160,4 +160,225 @@ BroadSpectrumResult(
|
||||
|
||||
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
|
||||
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
|
||||
|
||||
---
|
||||
|
||||
## 菌株级别预测详情
|
||||
|
||||
### 使用 `--include-strain-predictions` 参数
|
||||
|
||||
启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。
|
||||
|
||||
**命令示例**:
|
||||
```bash
|
||||
python utils/mole_predictor.py input.csv output.csv --include-strain-predictions
|
||||
```
|
||||
|
||||
### 菌株级别输出格式
|
||||
|
||||
每个分子会产生 40 行数据,每行对应一个菌株的预测结果:
|
||||
|
||||
| 列名 | 数据类型 | 说明 | 示例值 |
|
||||
|------|----------|------|--------|
|
||||
| `pred_id` | 字符串 | 预测ID,格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` |
|
||||
| `chem_id` | 字符串 | 化合物标识符 | `mol1` |
|
||||
| `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` |
|
||||
| `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率(0-1) | `0.000102` |
|
||||
| `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability) | `0.999898` |
|
||||
| `growth_inhibition` | 整数 (0/1) | 二值化抑制结果(1=抑制,0=不抑制) | `0` |
|
||||
| `gram_stain` | 字符串 | 革兰染色类型 | `negative` 或 `positive` |
|
||||
|
||||
### 完整的 40 种测试菌株列表
|
||||
|
||||
#### 革兰阴性菌(23 种)
|
||||
|
||||
| 序号 | 菌株名称 | NT编号 |
|
||||
|------|----------|--------|
|
||||
| 1 | Akkermansia muciniphila | NT5021 |
|
||||
| 2 | Bacteroides caccae | NT5050 |
|
||||
| 3 | Bacteroides fragilis (ET) | NT5033 |
|
||||
| 4 | Bacteroides fragilis (NT) | NT5003 |
|
||||
| 5 | Bacteroides ovatus | NT5054 |
|
||||
| 6 | Bacteroides thetaiotaomicron | NT5004 |
|
||||
| 7 | Bacteroides uniformis | NT5002 |
|
||||
| 8 | Bacteroides vulgatus | NT5001 |
|
||||
| 9 | Bacteroides xylanisolvens | NT5064 |
|
||||
| 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 |
|
||||
| 11 | Klebsiella pneumoniae | NT5049 |
|
||||
| 12 | Parabacteroides distasonis | NT5023 |
|
||||
| 13 | Phocaeicola vulgatus | NT5001 |
|
||||
| 14 | 其他肠道革兰阴性菌 | ... |
|
||||
|
||||
#### 革兰阳性菌(17 种)
|
||||
|
||||
| 序号 | 菌株名称 | NT编号 |
|
||||
|------|----------|--------|
|
||||
| 1 | Bifidobacterium adolescentis | NT5022 |
|
||||
| 2 | Bifidobacterium longum | NT5067 |
|
||||
| 3 | Bifidobacterium pseudocatenulatum | NT5058 |
|
||||
| 4 | Clostridium bolteae | NT5005 |
|
||||
| 5 | Clostridium innocuum | NT5026 |
|
||||
| 6 | Clostridium ramosum | NT5027 |
|
||||
| 7 | Clostridium scindens | NT5029 |
|
||||
| 8 | Clostridium symbiosum | NT5006 |
|
||||
| 9 | Enterococcus faecalis | NT5034 |
|
||||
| 10 | Enterococcus faecium | NT5043 |
|
||||
| 11 | Lactobacillus plantarum | NT5035 |
|
||||
| 12 | Lactobacillus reuteri | NT5032 |
|
||||
| 13 | Lactobacillus rhamnosus | NT5037 |
|
||||
| 14 | Streptococcus parasanguinis | NT5041 |
|
||||
| 15 | Streptococcus salivarius | NT5040 |
|
||||
| 16 | 其他肠道革兰阳性菌 | ... |
|
||||
|
||||
**注**: 完整列表可从 `maier_screening_results.tsv.gz` 和 `strain_info_SF2.xlsx` 文件中查看。
|
||||
|
||||
### 数据访问方式
|
||||
|
||||
#### 1. CSV 文件读取
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# 读取包含菌株预测的结果
|
||||
df = pd.read_csv('output_with_strains.csv')
|
||||
|
||||
# 查看某个分子的所有菌株预测
|
||||
mol1_strains = df[df['chem_id'] == 'mol1']
|
||||
print(f"分子 mol1 的预测:")
|
||||
print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']])
|
||||
|
||||
# 筛选被抑制的菌株
|
||||
inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1]
|
||||
print(f"\n被抑制的菌株数: {len(inhibited)}")
|
||||
print(inhibited[['strain_name', 'antimicrobial_predictive_probability']])
|
||||
```
|
||||
|
||||
#### 2. Python API 访问
|
||||
|
||||
```python
|
||||
from models import ParallelBroadSpectrumPredictor, MoleculeInput
|
||||
|
||||
predictor = ParallelBroadSpectrumPredictor()
|
||||
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
|
||||
|
||||
# 包含菌株级别预测
|
||||
result = predictor.predict_batch([molecule], include_strain_predictions=True)[0]
|
||||
|
||||
# 访问菌株预测 DataFrame
|
||||
strain_df = result.strain_predictions
|
||||
print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7)
|
||||
print(f"列名: {strain_df.columns.tolist()}")
|
||||
|
||||
# 提取预测概率向量(用于强化学习)
|
||||
probabilities = strain_df['antimicrobial_predictive_probability'].values
|
||||
print(f"预测概率向量形状: {probabilities.shape}") # (40,)
|
||||
|
||||
# 筛选特定革兰染色类型
|
||||
gram_negative = strain_df[strain_df['gram_stain'] == 'negative']
|
||||
print(f"革兰阴性菌预测数: {len(gram_negative)}")
|
||||
|
||||
# 转换为类型安全的列表(可选)
|
||||
strain_list = result.to_strain_predictions_list()
|
||||
for strain_pred in strain_list[:3]:
|
||||
print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}")
|
||||
```
|
||||
|
||||
### 强化学习场景应用
|
||||
|
||||
#### 状态表示
|
||||
|
||||
```python
|
||||
# 将 40 个菌株的预测概率作为状态向量
|
||||
state = result.strain_predictions['antimicrobial_predictive_probability'].values
|
||||
# state.shape = (40,)
|
||||
|
||||
# 或者包含更多特征
|
||||
state_features = result.strain_predictions[[
|
||||
'antimicrobial_predictive_probability',
|
||||
'growth_inhibition'
|
||||
]].values
|
||||
# state_features.shape = (40, 2)
|
||||
```
|
||||
|
||||
#### 奖励函数设计
|
||||
|
||||
```python
|
||||
def calculate_reward(strain_predictions_df):
|
||||
"""
|
||||
基于菌株级别预测计算奖励
|
||||
|
||||
Args:
|
||||
strain_predictions_df: 包含 40 个菌株预测的 DataFrame
|
||||
|
||||
Returns:
|
||||
reward: 标量奖励值
|
||||
"""
|
||||
# 方案1: 基于抑制菌株数
|
||||
reward = strain_predictions_df['growth_inhibition'].sum() / 40.0
|
||||
|
||||
# 方案2: 基于预测概率
|
||||
reward = strain_predictions_df['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
# 方案3: 加权奖励(考虑革兰染色)
|
||||
gram_negative_score = strain_predictions_df[
|
||||
strain_predictions_df['gram_stain'] == 'negative'
|
||||
]['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
gram_positive_score = strain_predictions_df[
|
||||
strain_predictions_df['gram_stain'] == 'positive'
|
||||
]['antimicrobial_predictive_probability'].mean()
|
||||
|
||||
reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score
|
||||
|
||||
return reward
|
||||
```
|
||||
|
||||
### 数据可视化
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
# 读取菌株预测数据
|
||||
strain_df = result.strain_predictions
|
||||
|
||||
# 按预测概率排序
|
||||
strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False)
|
||||
|
||||
# 绘制柱状图
|
||||
plt.figure(figsize=(15, 6))
|
||||
plt.bar(range(len(strain_df_sorted)),
|
||||
strain_df_sorted['antimicrobial_predictive_probability'],
|
||||
color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']])
|
||||
plt.xlabel('菌株索引')
|
||||
plt.ylabel('抗菌预测概率')
|
||||
plt.title('分子对 40 种菌株的抗菌活性预测')
|
||||
plt.xticks(rotation=90)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# 按革兰染色分组可视化
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
for idx, gram_type in enumerate(['negative', 'positive']):
|
||||
data = strain_df[strain_df['gram_stain'] == gram_type]
|
||||
axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability'])
|
||||
axes[idx].set_xlabel('预测概率')
|
||||
axes[idx].set_title(f'革兰{gram_type}菌')
|
||||
axes[idx].tick_params(axis='y', labelsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 性能和存储建议
|
||||
|
||||
- **聚合结果**: 每个分子 1 行,适合快速筛选
|
||||
- **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习
|
||||
- **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小
|
||||
- **推荐做法**:
|
||||
- 初筛时使用聚合结果
|
||||
- 对候选分子使用菌株级别预测进行深入分析
|
||||
Reference in New Issue
Block a user