1. 代码修改

models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
This commit is contained in:
2025-10-17 16:46:04 +08:00
parent 62e0f3d6aa
commit 34102cf459
5 changed files with 716 additions and 21 deletions

View File

@@ -160,4 +160,225 @@ BroadSpectrumResult(
- **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度
- **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性
---
## 菌株级别预测详情
### 使用 `--include-strain-predictions` 参数
启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。
**命令示例**
```bash
python utils/mole_predictor.py input.csv output.csv --include-strain-predictions
```
### 菌株级别输出格式
每个分子会产生 40 行数据,每行对应一个菌株的预测结果:
| 列名 | 数据类型 | 说明 | 示例值 |
|------|----------|------|--------|
| `pred_id` | 字符串 | 预测ID格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` |
| `chem_id` | 字符串 | 化合物标识符 | `mol1` |
| `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` |
| `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率0-1 | `0.000102` |
| `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability | `0.999898` |
| `growth_inhibition` | 整数 (0/1) | 二值化抑制结果1=抑制0=不抑制) | `0` |
| `gram_stain` | 字符串 | 革兰染色类型 | `negative``positive` |
### 完整的 40 种测试菌株列表
#### 革兰阴性菌23 种)
| 序号 | 菌株名称 | NT编号 |
|------|----------|--------|
| 1 | Akkermansia muciniphila | NT5021 |
| 2 | Bacteroides caccae | NT5050 |
| 3 | Bacteroides fragilis (ET) | NT5033 |
| 4 | Bacteroides fragilis (NT) | NT5003 |
| 5 | Bacteroides ovatus | NT5054 |
| 6 | Bacteroides thetaiotaomicron | NT5004 |
| 7 | Bacteroides uniformis | NT5002 |
| 8 | Bacteroides vulgatus | NT5001 |
| 9 | Bacteroides xylanisolvens | NT5064 |
| 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 |
| 11 | Klebsiella pneumoniae | NT5049 |
| 12 | Parabacteroides distasonis | NT5023 |
| 13 | Phocaeicola vulgatus | NT5001 |
| 14 | 其他肠道革兰阴性菌 | ... |
#### 革兰阳性菌17 种)
| 序号 | 菌株名称 | NT编号 |
|------|----------|--------|
| 1 | Bifidobacterium adolescentis | NT5022 |
| 2 | Bifidobacterium longum | NT5067 |
| 3 | Bifidobacterium pseudocatenulatum | NT5058 |
| 4 | Clostridium bolteae | NT5005 |
| 5 | Clostridium innocuum | NT5026 |
| 6 | Clostridium ramosum | NT5027 |
| 7 | Clostridium scindens | NT5029 |
| 8 | Clostridium symbiosum | NT5006 |
| 9 | Enterococcus faecalis | NT5034 |
| 10 | Enterococcus faecium | NT5043 |
| 11 | Lactobacillus plantarum | NT5035 |
| 12 | Lactobacillus reuteri | NT5032 |
| 13 | Lactobacillus rhamnosus | NT5037 |
| 14 | Streptococcus parasanguinis | NT5041 |
| 15 | Streptococcus salivarius | NT5040 |
| 16 | 其他肠道革兰阳性菌 | ... |
**注**: 完整列表可从 `maier_screening_results.tsv.gz``strain_info_SF2.xlsx` 文件中查看。
### 数据访问方式
#### 1. CSV 文件读取
```python
import pandas as pd
# 读取包含菌株预测的结果
df = pd.read_csv('output_with_strains.csv')
# 查看某个分子的所有菌株预测
mol1_strains = df[df['chem_id'] == 'mol1']
print(f"分子 mol1 的预测:")
print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']])
# 筛选被抑制的菌株
inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1]
print(f"\n被抑制的菌株数: {len(inhibited)}")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability']])
```
#### 2. Python API 访问
```python
from models import ParallelBroadSpectrumPredictor, MoleculeInput
predictor = ParallelBroadSpectrumPredictor()
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
# 包含菌株级别预测
result = predictor.predict_batch([molecule], include_strain_predictions=True)[0]
# 访问菌株预测 DataFrame
strain_df = result.strain_predictions
print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7)
print(f"列名: {strain_df.columns.tolist()}")
# 提取预测概率向量(用于强化学习)
probabilities = strain_df['antimicrobial_predictive_probability'].values
print(f"预测概率向量形状: {probabilities.shape}") # (40,)
# 筛选特定革兰染色类型
gram_negative = strain_df[strain_df['gram_stain'] == 'negative']
print(f"革兰阴性菌预测数: {len(gram_negative)}")
# 转换为类型安全的列表(可选)
strain_list = result.to_strain_predictions_list()
for strain_pred in strain_list[:3]:
print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}")
```
### 强化学习场景应用
#### 状态表示
```python
# 将 40 个菌株的预测概率作为状态向量
state = result.strain_predictions['antimicrobial_predictive_probability'].values
# state.shape = (40,)
# 或者包含更多特征
state_features = result.strain_predictions[[
'antimicrobial_predictive_probability',
'growth_inhibition'
]].values
# state_features.shape = (40, 2)
```
#### 奖励函数设计
```python
def calculate_reward(strain_predictions_df):
"""
基于菌株级别预测计算奖励
Args:
strain_predictions_df: 包含 40 个菌株预测的 DataFrame
Returns:
reward: 标量奖励值
"""
# 方案1: 基于抑制菌株数
reward = strain_predictions_df['growth_inhibition'].sum() / 40.0
# 方案2: 基于预测概率
reward = strain_predictions_df['antimicrobial_predictive_probability'].mean()
# 方案3: 加权奖励(考虑革兰染色)
gram_negative_score = strain_predictions_df[
strain_predictions_df['gram_stain'] == 'negative'
]['antimicrobial_predictive_probability'].mean()
gram_positive_score = strain_predictions_df[
strain_predictions_df['gram_stain'] == 'positive'
]['antimicrobial_predictive_probability'].mean()
reward = 0.6 * gram_negative_score + 0.4 * gram_positive_score
return reward
```
### 数据可视化
```python
import matplotlib.pyplot as plt
import seaborn as sns
# 读取菌株预测数据
strain_df = result.strain_predictions
# 按预测概率排序
strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False)
# 绘制柱状图
plt.figure(figsize=(15, 6))
plt.bar(range(len(strain_df_sorted)),
strain_df_sorted['antimicrobial_predictive_probability'],
color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']])
plt.xlabel('菌株索引')
plt.ylabel('抗菌预测概率')
plt.title('分子对 40 种菌株的抗菌活性预测')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()
# 按革兰染色分组可视化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for idx, gram_type in enumerate(['negative', 'positive']):
data = strain_df[strain_df['gram_stain'] == gram_type]
axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability'])
axes[idx].set_xlabel('预测概率')
axes[idx].set_title(f'革兰{gram_type}')
axes[idx].tick_params(axis='y', labelsize=8)
plt.tight_layout()
plt.show()
```
---
## 性能和存储建议
- **聚合结果**: 每个分子 1 行,适合快速筛选
- **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习
- **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小
- **推荐做法**:
- 初筛时使用聚合结果
- 对候选分子使用菌株级别预测进行深入分析

142
README.md
View File

@@ -153,9 +153,12 @@ conda install -c conda-forge rdkit
**基本用法:**
```bash
# 预测 CSV 文件
# 预测 CSV 文件(仅聚合结果)
python utils/mole_predictor.py input.csv
# 包含 40 种菌株的详细预测数据
python utils/mole_predictor.py input.csv --include-strain-predictions
# 指定输出路径
python utils/mole_predictor.py input.csv output.csv
@@ -171,6 +174,12 @@ python utils/mole_predictor.py input.csv --device cuda:0
python utils/mole_predictor.py input.csv \
--batch-size 200 \
--n-workers 8
# 完整示例:包含菌株预测 + GPU 加速
python utils/mole_predictor.py input.csv output.csv \
--include-strain-predictions \
--device cuda:0 \
--batch-size 200
```
**查看所有选项:**
@@ -196,7 +205,7 @@ python utils/mole_predictor.py Data/fragment/GDB11-27M.csv
```python
from utils.mole_predictor import predict_csv_file
# 基本使用
# 基本使用(仅聚合结果)
df_result = predict_csv_file(
input_path="Data/fragment/Frags-Enamine-18M.csv",
output_path="results/predictions.csv",
@@ -205,9 +214,24 @@ df_result = predict_csv_file(
device="auto"
)
# 包含 40 种菌株的详细预测数据
df_result_with_strains = predict_csv_file(
input_path="Data/fragment/Frags-Enamine-18M.csv",
output_path="results/predictions_with_strains.csv",
smiles_column="smiles",
batch_size=100,
device="auto",
include_strain_predictions=True # 启用菌株级别预测
)
# 查看结果
print(f"总分子数: {len(df_result)}")
print(f"广谱分子数: {df_result['broad_spectrum'].sum()}")
# 如果包含菌株预测,数据行数会增加(每个分子 40 行)
if 'strain_name' in df_result_with_strains.columns:
print(f"包含菌株预测的总行数: {len(df_result_with_strains)}")
print(f"菌株数: {df_result_with_strains['strain_name'].nunique()}")
```
**批量预测多个文件:**
@@ -247,7 +271,7 @@ config = PredictionConfig(
# 创建预测器
predictor = ParallelBroadSpectrumPredictor(config)
# 预测单个分子
# 预测单个分子(仅聚合结果)
molecule = MoleculeInput(smiles="CCO", chem_id="ethanol")
result = predictor.predict_single(molecule)
@@ -256,19 +280,41 @@ print(f"广谱抗菌: {result.broad_spectrum}")
print(f"抗菌得分: {result.apscore_total:.3f}")
print(f"抑制菌株数: {result.ginhib_total}")
# 批量预测
# 批量预测(包含 40 种菌株的详细预测)
smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
chem_ids = ["ethanol", "benzene", "acetic_acid"]
results = predictor.predict_from_smiles(smiles_list, chem_ids)
results = predictor.predict_batch(
[MoleculeInput(smiles=s, chem_id=c) for s, c in zip(smiles_list, chem_ids)],
include_strain_predictions=True # 启用菌株级别预测
)
for r in results:
print(f"{r.chem_id}: broad_spectrum={r.broad_spectrum}, "
f"apscore={r.apscore_total:.3f}")
print(f"\n{r.chem_id}:")
print(f" 广谱抗菌: {r.broad_spectrum}")
print(f" 抗菌得分: {r.apscore_total:.3f}")
print(f" 抑制菌株数: {r.ginhib_total}")
# 访问菌株级别预测数据DataFrame 格式)
if r.strain_predictions is not None:
print(f" 菌株预测数据形状: {r.strain_predictions.shape}")
print(f" 示例菌株预测:")
print(r.strain_predictions.head(3))
# 强化学习场景:提取特定菌株的预测概率
strain_probs = r.strain_predictions['antimicrobial_predictive_probability'].values
print(f" 预测概率向量形状: {strain_probs.shape}") # (40,)
# 或转换为类型安全的列表
strain_list = r.to_strain_predictions_list()
print(f" 第一个菌株: {strain_list[0].strain_name}")
print(f" 预测概率: {strain_list[0].antimicrobial_predictive_probability:.6f}")
```
### 输出说明
#### 1. 聚合结果(默认输出)
预测结果会添加以下 7 个新列:
| 列名 | 类型 | 说明 |
@@ -276,11 +322,85 @@ for r in results:
| `apscore_total` | float | 总体抗菌潜力分数(对数尺度,值越大抗菌活性越强) |
| `apscore_gnegative` | float | 革兰阴性菌抗菌潜力分数 |
| `apscore_gpositive` | float | 革兰阳性菌抗菌潜力分数 |
| `ginhib_total` | int | 被抑制的菌株总数 |
| `ginhib_total` | int | 被抑制的菌株总数0-40 |
| `ginhib_gnegative` | int | 被抑制的革兰阴性菌株数 |
| `ginhib_gpositive` | int | 被抑制的革兰阳性菌株数 |
| `broad_spectrum` | int | 是否为广谱抗菌1=是0=否) |
**输出示例**(每个分子一行):
```csv
smiles,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum
CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0
```
#### 2. 菌株级别预测(使用 `--include-strain-predictions` 时)
启用菌株级别预测后,输出会包含以下额外列(每个分子 40 行):
| 列名 | 类型 | 说明 |
|------|------|------|
| `pred_id` | str | 预测ID格式为 `chem_id:strain_name` |
| `strain_name` | str | 菌株名称40 种菌株之一) |
| `antimicrobial_predictive_probability` | float | XGBoost 预测的抗菌概率0-1 |
| `no_growth_probability` | float | 预测不抑制的概率1 - antimicrobial_predictive_probability |
| `growth_inhibition` | int | 二值化抑制结果0=不抑制1=抑制) |
| `gram_stain` | str | 革兰染色类型negative 或 positive |
**输出示例**(每个分子 40 行,对应 40 个菌株):
```csv
smiles,chem_id,apscore_total,...,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain
CCO,mol1,-9.93,...,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative
CCO,mol1,-9.93,...,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative
...(共 40 行,对应 40 个菌株)
```
**数据使用场景**
1. **强化学习状态表示**
```python
# 提取预测概率作为状态向量
state_vector = result.strain_predictions['antimicrobial_predictive_probability'].values
# 形状: (40,) - 可直接用于 RL 环境
```
2. **筛选特定菌株**
```python
# 筛选革兰阴性菌
gram_negative = result.strain_predictions[
result.strain_predictions['gram_stain'] == 'negative'
]
```
3. **可视化分析**
```python
import matplotlib.pyplot as plt
df = result.strain_predictions
df.plot(x='strain_name', y='antimicrobial_predictive_probability', kind='bar')
```
#### 40 种测试菌株列表
预测涵盖以下 40 种人类肠道菌株:
**革兰阴性菌23 种)**
- Akkermansia muciniphila (NT5021)
- Bacteroides 属: caccae, fragilis (ET/NT), ovatus, thetaiotaomicron, uniformis, vulgatus, xylanisolvens (8 种)
- Escherichia coli 各亚型 (4 种)
- Klebsiella pneumoniae (NT5049)
- 其他肠道革兰阴性菌
**革兰阳性菌17 种)**
- Bifidobacterium 属 (3 种)
- Clostridium 属 (5 种)
- Enterococcus 属 (2 种)
- Lactobacillus 属 (3 种)
- Streptococcus 属 (2 种)
- 其他肠道革兰阳性菌
完整菌株列表详见 `Data/mole/README.md`。
#### 广谱抗菌判断标准
默认情况下,如果一个分子能抑制 **10 个或更多菌株** (`ginhib_total >= 10`),则被认为是广谱抗菌分子。
@@ -291,6 +411,12 @@ for r in results:
- 输入: `Data/fragment/Frags-Enamine-18M.csv`
- 输出: `Data/fragment/Frags-Enamine-18M_predicted.csv`
- 输出(含菌株): `Data/fragment/Frags-Enamine-18M_predicted.csv`(每个分子 40 行)
#### 数据量说明
- **仅聚合结果**:输出行数 = 输入分子数
- **包含菌株预测**:输出行数 = 输入分子数 × 40
---

View File

@@ -65,9 +65,21 @@ class MoleculeInput:
chem_id: Optional[str] = None
@dataclass
class StrainPrediction:
"""单个菌株的预测结果"""
pred_id: str # 格式: "chem_id:strain_name"
chem_id: str
strain_name: str
antimicrobial_predictive_probability: float # 预测概率
no_growth_probability: float # 不生长概率1 - antimicrobial_predictive_probability
growth_inhibition: int # 二值化结果 (0/1)
gram_stain: Optional[str] = None # 革兰染色类型
@dataclass
class BroadSpectrumResult:
"""广谱抗菌预测结果"""
"""广谱抗菌预测结果(聚合结果)"""
chem_id: str
apscore_total: float
apscore_gnegative: float
@@ -76,9 +88,10 @@ class BroadSpectrumResult:
ginhib_gnegative: int
ginhib_gpositive: int
broad_spectrum: int
strain_predictions: Optional[pd.DataFrame] = None # 菌株级别预测40行
def to_dict(self) -> Dict[str, Union[str, float, int]]:
"""转换为字典格式"""
"""转换为字典格式(仅聚合字段)"""
return {
'chem_id': self.chem_id,
'apscore_total': self.apscore_total,
@@ -89,6 +102,31 @@ class BroadSpectrumResult:
'ginhib_gpositive': self.ginhib_gpositive,
'broad_spectrum': self.broad_spectrum
}
def to_strain_predictions_list(self) -> List['StrainPrediction']:
"""
将 DataFrame 转换为 StrainPrediction 列表(用于类型安全场景)
Returns:
StrainPrediction 对象列表
"""
if self.strain_predictions is None or self.strain_predictions.empty:
return []
strain_list = []
for _, row in self.strain_predictions.iterrows():
strain_pred = StrainPrediction(
pred_id=row['pred_id'],
chem_id=row['chem_id'],
strain_name=row['strain_name'],
antimicrobial_predictive_probability=row['antimicrobial_predictive_probability'],
no_growth_probability=row['no_growth_probability'],
growth_inhibition=row['growth_inhibition'],
gram_stain=row.get('gram_stain', None)
)
strain_list.append(strain_pred)
return strain_list
class BroadSpectrumPredictor:
@@ -259,6 +297,45 @@ class BroadSpectrumPredictor:
return df_label
def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
"""
准备菌株级别的预测数据
Args:
score_df: 原始预测分数DataFrame包含 pred_id, 0, 1, growth_inhibition 列
Returns:
格式化的菌株级别预测DataFrame
"""
# 创建副本
strain_df = score_df.copy()
# 分离化合物ID和菌株名
strain_df["chem_id"] = strain_df["pred_id"].str.split(":", expand=True)[0]
strain_df["strain_name"] = strain_df["pred_id"].str.split(":", expand=True)[1]
# 添加革兰染色信息
strain_df = self._gram_stain(strain_df)
# 重命名列为用户友好的名称
strain_df = strain_df.rename(columns={
"1": "antimicrobial_predictive_probability",
"0": "no_growth_probability"
})
# 选择并排序输出列
output_columns = [
"pred_id",
"chem_id",
"strain_name",
"antimicrobial_predictive_probability",
"no_growth_probability",
"growth_inhibition",
"gram_stain"
]
return strain_df[output_columns]
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
"""
计算抗菌潜力分数
@@ -371,12 +448,14 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
results = self.predict_batch([molecule])
return results[0]
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
def predict_batch(self, molecules: List[MoleculeInput],
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
"""
批量预测分子的广谱抗菌活性
Args:
molecules: 分子输入列表
include_strain_predictions: 是否在结果中包含菌株级别预测数据
Returns:
广谱抗菌预测结果列表
@@ -417,10 +496,16 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
# 合并结果
print("Merging prediction results...")
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
all_pred_df = all_pred_df.reset_index()
# 准备菌株级别数据(如果需要)
strain_level_data = None
if include_strain_predictions:
print("Preparing strain-level predictions...")
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
# 计算抗菌潜力
print("Calculating antimicrobial potential scores...")
all_pred_df = all_pred_df.reset_index()
agg_df = self._antimicrobial_potential(all_pred_df)
# 判断广谱抗菌
@@ -431,6 +516,13 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
# 转换为结果对象
results_list = []
for _, row in agg_df.iterrows():
# 获取该分子的菌株级别预测
mol_strain_preds = None
if strain_level_data is not None:
mol_strain_preds = strain_level_data[
strain_level_data['chem_id'] == row.name
].reset_index(drop=True)
result = BroadSpectrumResult(
chem_id=row.name,
apscore_total=row["apscore_total"],
@@ -439,7 +531,8 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
ginhib_total=int(row["ginhib_total"]),
ginhib_gnegative=int(row["ginhib_gnegative"]),
ginhib_gpositive=int(row["ginhib_gpositive"]),
broad_spectrum=int(row["broad_spectrum"])
broad_spectrum=int(row["broad_spectrum"]),
strain_predictions=mol_strain_preds
)
results_list.append(result)

View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测 Python API 示例
演示两种预测模式:
1. 聚合结果模式(默认)
2. 菌株级别预测模式
"""
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径(使用 pathlib
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from models import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput
)
def print_separator(title):
"""打印分隔线"""
print("\n" + "=" * 70)
print(f" {title}")
print("=" * 70 + "\n")
def demo_aggregated_mode():
"""演示聚合结果模式(默认)"""
print_separator("模式 1: 聚合结果模式(默认)")
# 创建配置
config = PredictionConfig(
batch_size=10,
device="auto" # 自动检测 CUDA
)
# 创建预测器
predictor = ParallelBroadSpectrumPredictor(config)
# 准备测试分子
test_molecules = [
MoleculeInput(smiles="CCO", chem_id="ethanol"),
MoleculeInput(smiles="c1ccccc1", chem_id="benzene"),
MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"),
]
print(f"测试分子数: {len(test_molecules)}")
print(f"SMILES 示例: {test_molecules[0].smiles}")
# 执行预测(不包含菌株级别预测)
print("\n开始预测(聚合模式)...")
results = predictor.predict_batch(test_molecules, include_strain_predictions=False)
# 打印结果
print(f"\n预测完成!共 {len(results)} 个结果\n")
for result in results:
print(f"化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}")
print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}")
print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}")
print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None
print()
# 返回结果供后续使用
return results
def demo_strain_level_mode():
"""演示菌株级别预测模式"""
print_separator("模式 2: 菌株级别预测模式")
# 创建预测器
predictor = ParallelBroadSpectrumPredictor()
# 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物)
test_molecule = MoleculeInput(
smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1",
chem_id="test_antibacterial"
)
print(f"测试分子: {test_molecule.chem_id}")
print(f"SMILES: {test_molecule.smiles}")
# 执行预测(包含菌株级别预测)
print("\n开始预测(菌株级别模式)...")
results = predictor.predict_batch([test_molecule], include_strain_predictions=True)
result = results[0]
# 1. 打印聚合结果
print(f"\n聚合结果:")
print(f" - 化合物ID: {result.chem_id}")
print(f" - 广谱抗菌: {'' if result.broad_spectrum else ''}")
print(f" - 总体抗菌得分: {result.apscore_total:.4f}")
print(f" - 抑制菌株总数: {result.ginhib_total} / 40")
# 2. 打印菌株级别预测数据结构
print(f"\n菌株级别预测数据:")
if result.strain_predictions is not None:
strain_df = result.strain_predictions
print(f" - 数据类型: {type(strain_df)}")
print(f" - 数据形状: {strain_df.shape} (行, 列)")
print(f" - 列名: {list(strain_df.columns)}")
print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB")
# 3. 展示前 5 个菌株的预测
print(f"\n前 5 个菌株的预测:")
print(strain_df.head(5).to_string(index=False))
# 4. 统计信息
print(f"\n统计信息:")
print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, "
f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]")
print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}")
print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}")
print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}")
print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}")
# 5. 展示被抑制的菌株(如果有)
inhibited = strain_df[strain_df['growth_inhibition'] == 1]
if len(inhibited) > 0:
print(f"\n被抑制的菌株 ({len(inhibited)} 个):")
print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False))
else:
print(f"\n该分子未预测抑制任何菌株")
# 6. 强化学习应用示例
print(f"\n强化学习应用示例:")
# 提取预测概率作为状态向量
state_vector = strain_df['antimicrobial_predictive_probability'].values
print(f" - 状态向量形状: {state_vector.shape}")
print(f" - 状态向量类型: {type(state_vector)}")
print(f" - 前 10 个值: {state_vector[:10]}")
# 提取多维特征
state_features = strain_df[[
'antimicrobial_predictive_probability',
'growth_inhibition'
]].values
print(f" - 多维特征形状: {state_features.shape}")
# 按革兰染色分组
gram_negative_probs = strain_df[
strain_df['gram_stain'] == 'negative'
]['antimicrobial_predictive_probability'].values
print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}")
# 7. 转换为类型安全的列表(可选)
print(f"\n转换为 StrainPrediction 列表:")
strain_list = result.to_strain_predictions_list()
print(f" - 列表长度: {len(strain_list)}")
print(f" - 元素类型: {type(strain_list[0])}")
print(f" - 第一个元素:")
first_strain = strain_list[0]
print(f" * pred_id: {first_strain.pred_id}")
print(f" * strain_name: {first_strain.strain_name}")
print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}")
print(f" * growth_inhibition: {first_strain.growth_inhibition}")
print(f" * gram_stain: {first_strain.gram_stain}")
else:
print(" 警告: strain_predictions 为 None未启用菌株级别预测")
return result
def main():
"""主函数"""
print("\n" + "🧪" * 35)
print(" MolE 抗菌活性预测 Python API 示例")
print("🧪" * 35)
# 演示模式 1: 聚合结果
aggregated_results = demo_aggregated_mode()
# 演示模式 2: 菌株级别预测
strain_level_result = demo_strain_level_mode()
# 总结
print_separator("总结")
print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标")
print(" - strain_predictions = None")
print()
print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习")
print(" - 每个分子返回 1 个 BroadSpectrumResult 对象")
print(" - 包含 8 个聚合指标 + 40 行菌株预测数据")
print(" - strain_predictions = DataFrame (40 rows × 7 columns)")
print(" - 可直接提取为 numpy array 用于 RL")
print()
print("🎯 推荐使用场景:")
print(" - 初筛: 模式 1")
print(" - 详细分析/RL 训练: 模式 2")
print()
if __name__ == "__main__":
main()

View File

@@ -49,7 +49,8 @@ def predict_csv_file(
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> pd.DataFrame:
"""
预测 CSV 文件中的分子抗菌活性
@@ -63,6 +64,7 @@ def predict_csv_file(
n_workers: 工作进程数
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame
@@ -118,7 +120,7 @@ def predict_csv_file(
# 执行预测
print("开始预测...")
results = predictor.predict_batch(molecules)
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
# 转换结果为 DataFrame
results_dicts = [r.to_dict() for r in results]
@@ -136,6 +138,36 @@ def predict_csv_file(
)
df_output = df_output.drop(columns=['_merge_id'])
# 如果包含菌株级别预测,将其添加到输出中
if include_strain_predictions:
print("合并菌株级别预测数据...")
# 收集所有菌株级别预测
all_strain_predictions = []
for result in results:
if result.strain_predictions is not None and not result.strain_predictions.empty:
all_strain_predictions.append(result.strain_predictions)
if all_strain_predictions:
# 合并所有菌株预测
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
# 将聚合结果和菌株预测合并
# 为了在同一个 CSV 中展示,我们使用重复行的方式
# 每个分子的聚合结果会重复40次每个菌株一次
df_output_expanded = df_output.merge(
df_strain_predictions,
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
right_on='chem_id',
how='left',
suffixes=('', '_strain')
)
# 移除重复的 chem_id 列
if 'chem_id_strain' in df_output_expanded.columns:
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
df_output = df_output_expanded
# 生成输出路径
if output_path is None:
if add_suffix:
@@ -164,7 +196,8 @@ def predict_multiple_files(
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> List[pd.DataFrame]:
"""
批量预测多个 CSV 文件
@@ -178,6 +211,7 @@ def predict_multiple_files(
n_workers: 工作进程数
device: 计算设备
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame 列表
@@ -210,7 +244,8 @@ def predict_multiple_files(
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
results.append(df_result)
except Exception as e:
@@ -240,7 +275,9 @@ def predict_multiple_files(
help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
@click.option('--include-strain-predictions', is_flag=True, default=False,
help='在输出中包含40种菌株的详细预测数据每个分子将产生40行数据对应每个菌株的预测概率和抑制情况')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
"""
使用 MolE 模型预测小分子 SMILES 的抗菌活性
@@ -248,12 +285,21 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
示例:
# 基本用法(仅输出聚合结果)
python mole_predictor.py input.csv output.csv
python mole_predictor.py input.csv -s SMILES -i ID
# 包含40种菌株的详细预测数据
python mole_predictor.py input.csv output.csv --include-strain-predictions
# 指定列名和设备
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
# 自定义批处理大小
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
"""
@@ -266,7 +312,8 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
except Exception as e:
click.echo(f"错误: {e}", err=True)