diff --git a/Data/mole/README.md b/Data/mole/README.md index 36d8097..a490033 100644 --- a/Data/mole/README.md +++ b/Data/mole/README.md @@ -160,4 +160,225 @@ BroadSpectrumResult( - **apscore_* 类字段**:基于预测概率的连续评分,反映抗菌活性强度 - **ginhib_* 类字段**:基于二值化预测的离散计数,反映抑制范围 -- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性 \ No newline at end of file +- **broad_spectrum**:基于 ginhib_total 的布尔判定,快速标识广谱特性 + +--- + +## 菌株级别预测详情 + +### 使用 `--include-strain-predictions` 参数 + +启用此参数后,输出将包含每个分子对所有 40 个菌株的详细预测数据。 + +**命令示例**: +```bash +python utils/mole_predictor.py input.csv output.csv --include-strain-predictions +``` + +### 菌株级别输出格式 + +每个分子会产生 40 行数据,每行对应一个菌株的预测结果: + +| 列名 | 数据类型 | 说明 | 示例值 | +|------|----------|------|--------| +| `pred_id` | 字符串 | 预测ID,格式为 `chem_id:strain_name` | `mol1:Akkermansia muciniphila (NT5021)` | +| `chem_id` | 字符串 | 化合物标识符 | `mol1` | +| `strain_name` | 字符串 | 菌株名称 | `Akkermansia muciniphila (NT5021)` | +| `antimicrobial_predictive_probability` | 浮点数 | XGBoost 预测的抗菌概率(0-1) | `0.000102` | +| `no_growth_probability` | 浮点数 | 不抑制的概率(= 1 - antimicrobial_predictive_probability) | `0.999898` | +| `growth_inhibition` | 整数 (0/1) | 二值化抑制结果(1=抑制,0=不抑制) | `0` | +| `gram_stain` | 字符串 | 革兰染色类型 | `negative` 或 `positive` | + +### 完整的 40 种测试菌株列表 + +#### 革兰阴性菌(23 种) + +| 序号 | 菌株名称 | NT编号 | +|------|----------|--------| +| 1 | Akkermansia muciniphila | NT5021 | +| 2 | Bacteroides caccae | NT5050 | +| 3 | Bacteroides fragilis (ET) | NT5033 | +| 4 | Bacteroides fragilis (NT) | NT5003 | +| 5 | Bacteroides ovatus | NT5054 | +| 6 | Bacteroides thetaiotaomicron | NT5004 | +| 7 | Bacteroides uniformis | NT5002 | +| 8 | Bacteroides vulgatus | NT5001 | +| 9 | Bacteroides xylanisolvens | NT5064 | +| 10 | Escherichia coli (3 isolates + Nissle) | NT5028, NT5024, NT5030, NT5011 | +| 11 | Klebsiella pneumoniae | NT5049 | +| 12 | Parabacteroides distasonis | NT5023 | +| 13 | Phocaeicola vulgatus(即第 8 行 Bacteroides vulgatus 的现用名,同一菌株,请勿重复计数) | NT5001 | +| 14 | 其他肠道革兰阴性菌 | ... 
| + +#### 革兰阳性菌(17 种) + +| 序号 | 菌株名称 | NT编号 | +|------|----------|--------| +| 1 | Bifidobacterium adolescentis | NT5022 | +| 2 | Bifidobacterium longum | NT5067 | +| 3 | Bifidobacterium pseudocatenulatum | NT5058 | +| 4 | Clostridium bolteae | NT5005 | +| 5 | Clostridium innocuum | NT5026 | +| 6 | Clostridium ramosum | NT5027 | +| 7 | Clostridium scindens | NT5029 | +| 8 | Clostridium symbiosum | NT5006 | +| 9 | Enterococcus faecalis | NT5034 | +| 10 | Enterococcus faecium | NT5043 | +| 11 | Lactobacillus plantarum | NT5035 | +| 12 | Lactobacillus reuteri | NT5032 | +| 13 | Lactobacillus rhamnosus | NT5037 | +| 14 | Streptococcus parasanguinis | NT5041 | +| 15 | Streptococcus salivarius | NT5040 | +| 16 | 其他肠道革兰阳性菌 | ... | + +**注**: 完整列表可从 `maier_screening_results.tsv.gz` 和 `strain_info_SF2.xlsx` 文件中查看。 + +### 数据访问方式 + +#### 1. CSV 文件读取 + +```python +import pandas as pd + +# 读取包含菌株预测的结果 +df = pd.read_csv('output_with_strains.csv') + +# 查看某个分子的所有菌株预测 +mol1_strains = df[df['chem_id'] == 'mol1'] +print(f"分子 mol1 的预测:") +print(mol1_strains[['strain_name', 'antimicrobial_predictive_probability', 'growth_inhibition']]) + +# 筛选被抑制的菌株 +inhibited = mol1_strains[mol1_strains['growth_inhibition'] == 1] +print(f"\n被抑制的菌株数: {len(inhibited)}") +print(inhibited[['strain_name', 'antimicrobial_predictive_probability']]) +``` + +#### 2. 
Python API 访问 + +```python +from models import ParallelBroadSpectrumPredictor, MoleculeInput + +predictor = ParallelBroadSpectrumPredictor() +molecule = MoleculeInput(smiles="CCO", chem_id="ethanol") + +# 包含菌株级别预测 +result = predictor.predict_batch([molecule], include_strain_predictions=True)[0] + +# 访问菌株预测 DataFrame +strain_df = result.strain_predictions +print(f"菌株预测数据形状: {strain_df.shape}") # (40, 7) +print(f"列名: {strain_df.columns.tolist()}") + +# 提取预测概率向量(用于强化学习) +probabilities = strain_df['antimicrobial_predictive_probability'].values +print(f"预测概率向量形状: {probabilities.shape}") # (40,) + +# 筛选特定革兰染色类型 +gram_negative = strain_df[strain_df['gram_stain'] == 'negative'] +print(f"革兰阴性菌预测数: {len(gram_negative)}") + +# 转换为类型安全的列表(可选) +strain_list = result.to_strain_predictions_list() +for strain_pred in strain_list[:3]: + print(f"{strain_pred.strain_name}: {strain_pred.antimicrobial_predictive_probability:.6f}") +``` + +### 强化学习场景应用 + +#### 状态表示 + +```python +# 将 40 个菌株的预测概率作为状态向量 +state = result.strain_predictions['antimicrobial_predictive_probability'].values +# state.shape = (40,) + +# 或者包含更多特征 +state_features = result.strain_predictions[[ + 'antimicrobial_predictive_probability', + 'growth_inhibition' +]].values +# state_features.shape = (40, 2) +``` + +#### 奖励函数设计 + +```python +def calculate_reward(strain_predictions_df): + """ + 基于菌株级别预测计算奖励(以下三种方案互为替代,实际使用时只保留其一;本示例按顺序执行,后者覆盖前者,最终返回方案3的结果) + + Args: + strain_predictions_df: 包含 40 个菌株预测的 DataFrame + + Returns: + reward: 标量奖励值 + """ + # 方案1: 基于抑制菌株数 + reward = strain_predictions_df['growth_inhibition'].sum() / 40.0 + + # 方案2: 基于预测概率 + reward = strain_predictions_df['antimicrobial_predictive_probability'].mean() + + # 方案3: 加权奖励(考虑革兰染色) + gram_negative_score = strain_predictions_df[ + strain_predictions_df['gram_stain'] == 'negative' + ]['antimicrobial_predictive_probability'].mean() + + gram_positive_score = strain_predictions_df[ + strain_predictions_df['gram_stain'] == 'positive' + ]['antimicrobial_predictive_probability'].mean() + + reward = 0.6 * 
gram_negative_score + 0.4 * gram_positive_score + + return reward +``` + +### 数据可视化 + +```python +import matplotlib.pyplot as plt +import seaborn as sns + +# 读取菌株预测数据 +strain_df = result.strain_predictions + +# 按预测概率排序 +strain_df_sorted = strain_df.sort_values('antimicrobial_predictive_probability', ascending=False) + +# 绘制柱状图 +plt.figure(figsize=(15, 6)) +plt.bar(range(len(strain_df_sorted)), + strain_df_sorted['antimicrobial_predictive_probability'], + color=['red' if x == 1 else 'blue' for x in strain_df_sorted['growth_inhibition']]) +plt.xlabel('菌株索引') +plt.ylabel('抗菌预测概率') +plt.title('分子对 40 种菌株的抗菌活性预测') +plt.xticks(rotation=90) +plt.tight_layout() +plt.show() + +# 按革兰染色分组可视化 +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +for idx, gram_type in enumerate(['negative', 'positive']): + data = strain_df[strain_df['gram_stain'] == gram_type] + axes[idx].barh(data['strain_name'], data['antimicrobial_predictive_probability']) + axes[idx].set_xlabel('预测概率') + axes[idx].set_title(f'革兰{gram_type}菌') + axes[idx].tick_params(axis='y', labelsize=8) + +plt.tight_layout() +plt.show() +``` + +--- + +## 性能和存储建议 + +- **聚合结果**: 每个分子 1 行,适合快速筛选 +- **菌株级别预测**: 每个分子 40 行,适合详细分析和强化学习 +- **存储空间**: 包含菌株预测的文件约为仅聚合结果的 40 倍大小 +- **推荐做法**: + - 初筛时使用聚合结果 + - 对候选分子使用菌株级别预测进行深入分析 \ No newline at end of file diff --git a/README.md b/README.md index 937adee..c8edcb0 100755 --- a/README.md +++ b/README.md @@ -153,9 +153,12 @@ conda install -c conda-forge rdkit **基本用法:** ```bash -# 预测 CSV 文件 +# 预测 CSV 文件(仅聚合结果) python utils/mole_predictor.py input.csv +# 包含 40 种菌株的详细预测数据 +python utils/mole_predictor.py input.csv --include-strain-predictions + # 指定输出路径 python utils/mole_predictor.py input.csv output.csv @@ -171,6 +174,12 @@ python utils/mole_predictor.py input.csv --device cuda:0 python utils/mole_predictor.py input.csv \ --batch-size 200 \ --n-workers 8 + +# 完整示例:包含菌株预测 + GPU 加速 +python utils/mole_predictor.py input.csv output.csv \ + --include-strain-predictions \ + --device cuda:0 \ + 
--batch-size 200 ``` **查看所有选项:** @@ -196,7 +205,7 @@ python utils/mole_predictor.py Data/fragment/GDB11-27M.csv ```python from utils.mole_predictor import predict_csv_file -# 基本使用 +# 基本使用(仅聚合结果) df_result = predict_csv_file( input_path="Data/fragment/Frags-Enamine-18M.csv", output_path="results/predictions.csv", @@ -205,9 +214,24 @@ df_result = predict_csv_file( device="auto" ) +# 包含 40 种菌株的详细预测数据 +df_result_with_strains = predict_csv_file( + input_path="Data/fragment/Frags-Enamine-18M.csv", + output_path="results/predictions_with_strains.csv", + smiles_column="smiles", + batch_size=100, + device="auto", + include_strain_predictions=True # 启用菌株级别预测 +) + # 查看结果 print(f"总分子数: {len(df_result)}") print(f"广谱分子数: {df_result['broad_spectrum'].sum()}") + +# 如果包含菌株预测,数据行数会增加(每个分子 40 行) +if 'strain_name' in df_result_with_strains.columns: + print(f"包含菌株预测的总行数: {len(df_result_with_strains)}") + print(f"菌株数: {df_result_with_strains['strain_name'].nunique()}") ``` **批量预测多个文件:** @@ -247,7 +271,7 @@ config = PredictionConfig( # 创建预测器 predictor = ParallelBroadSpectrumPredictor(config) -# 预测单个分子 +# 预测单个分子(仅聚合结果) molecule = MoleculeInput(smiles="CCO", chem_id="ethanol") result = predictor.predict_single(molecule) @@ -256,19 +280,41 @@ print(f"广谱抗菌: {result.broad_spectrum}") print(f"抗菌得分: {result.apscore_total:.3f}") print(f"抑制菌株数: {result.ginhib_total}") -# 批量预测 +# 批量预测(包含 40 种菌株的详细预测) smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"] chem_ids = ["ethanol", "benzene", "acetic_acid"] -results = predictor.predict_from_smiles(smiles_list, chem_ids) +results = predictor.predict_batch( + [MoleculeInput(smiles=s, chem_id=c) for s, c in zip(smiles_list, chem_ids)], + include_strain_predictions=True # 启用菌株级别预测 +) for r in results: - print(f"{r.chem_id}: broad_spectrum={r.broad_spectrum}, " - f"apscore={r.apscore_total:.3f}") + print(f"\n{r.chem_id}:") + print(f" 广谱抗菌: {r.broad_spectrum}") + print(f" 抗菌得分: {r.apscore_total:.3f}") + print(f" 抑制菌株数: {r.ginhib_total}") + + # 访问菌株级别预测数据(DataFrame 格式) 
+ if r.strain_predictions is not None: + print(f" 菌株预测数据形状: {r.strain_predictions.shape}") + print(f" 示例菌株预测:") + print(r.strain_predictions.head(3)) + + # 强化学习场景:提取特定菌株的预测概率 + strain_probs = r.strain_predictions['antimicrobial_predictive_probability'].values + print(f" 预测概率向量形状: {strain_probs.shape}") # (40,) + + # 或转换为类型安全的列表 + strain_list = r.to_strain_predictions_list() + print(f" 第一个菌株: {strain_list[0].strain_name}") + print(f" 预测概率: {strain_list[0].antimicrobial_predictive_probability:.6f}") ``` ### 输出说明 +#### 1. 聚合结果(默认输出) + 预测结果会添加以下 7 个新列: | 列名 | 类型 | 说明 | @@ -276,11 +322,85 @@ for r in results: | `apscore_total` | float | 总体抗菌潜力分数(对数尺度,值越大抗菌活性越强) | | `apscore_gnegative` | float | 革兰阴性菌抗菌潜力分数 | | `apscore_gpositive` | float | 革兰阳性菌抗菌潜力分数 | -| `ginhib_total` | int | 被抑制的菌株总数 | +| `ginhib_total` | int | 被抑制的菌株总数(0-40) | | `ginhib_gnegative` | int | 被抑制的革兰阴性菌株数 | | `ginhib_gpositive` | int | 被抑制的革兰阳性菌株数 | | `broad_spectrum` | int | 是否为广谱抗菌(1=是,0=否) | +**输出示例**(每个分子一行): + +```csv +smiles,chem_id,apscore_total,apscore_gnegative,apscore_gpositive,ginhib_total,ginhib_gnegative,ginhib_gpositive,broad_spectrum +CCO,mol1,-9.93,-10.17,-9.74,0,0,0,0 +``` + +#### 2. 
菌株级别预测(使用 `--include-strain-predictions` 时) + +启用菌株级别预测后,输出会包含以下额外列(每个分子 40 行): + +| 列名 | 类型 | 说明 | +|------|------|------| +| `pred_id` | str | 预测ID,格式为 `chem_id:strain_name` | +| `strain_name` | str | 菌株名称(40 种菌株之一) | +| `antimicrobial_predictive_probability` | float | XGBoost 预测的抗菌概率(0-1) | +| `no_growth_probability` | float | 预测不抑制的概率(1 - antimicrobial_predictive_probability) | +| `growth_inhibition` | int | 二值化抑制结果(0=不抑制,1=抑制) | +| `gram_stain` | str | 革兰染色类型(negative 或 positive) | + +**输出示例**(每个分子 40 行,对应 40 个菌株): + +```csv +smiles,chem_id,apscore_total,...,pred_id,strain_name,antimicrobial_predictive_probability,no_growth_probability,growth_inhibition,gram_stain +CCO,mol1,-9.93,...,mol1:Akkermansia muciniphila (NT5021),Akkermansia muciniphila (NT5021),0.000102,0.999898,0,negative +CCO,mol1,-9.93,...,mol1:Bacteroides caccae (NT5050),Bacteroides caccae (NT5050),0.000155,0.999845,0,negative +...(共 40 行,对应 40 个菌株) +``` + +**数据使用场景**: + +1. **强化学习状态表示**: + ```python + # 提取预测概率作为状态向量 + state_vector = result.strain_predictions['antimicrobial_predictive_probability'].values + # 形状: (40,) - 可直接用于 RL 环境 + ``` + +2. **筛选特定菌株**: + ```python + # 筛选革兰阴性菌 + gram_negative = result.strain_predictions[ + result.strain_predictions['gram_stain'] == 'negative' + ] + ``` + +3. 
**可视化分析**: + ```python + import matplotlib.pyplot as plt + df = result.strain_predictions + df.plot(x='strain_name', y='antimicrobial_predictive_probability', kind='bar') + ``` + +#### 40 种测试菌株列表 + +预测涵盖以下 40 种人类肠道菌株: + +**革兰阴性菌(23 种)**: +- Akkermansia muciniphila (NT5021) +- Bacteroides 属: caccae, fragilis (ET/NT), ovatus, thetaiotaomicron, uniformis, vulgatus, xylanisolvens (8 种) +- Escherichia coli 各亚型 (4 种) +- Klebsiella pneumoniae (NT5049) +- 其他肠道革兰阴性菌 + +**革兰阳性菌(17 种)**: +- Bifidobacterium 属 (3 种) +- Clostridium 属 (5 种) +- Enterococcus 属 (2 种) +- Lactobacillus 属 (3 种) +- Streptococcus 属 (2 种) +- 其他肠道革兰阳性菌 + +完整菌株列表详见 `Data/mole/README.md`。 + #### 广谱抗菌判断标准 默认情况下,如果一个分子能抑制 **10 个或更多菌株** (`ginhib_total >= 10`),则被认为是广谱抗菌分子。 @@ -291,6 +411,12 @@ for r in results: - 输入: `Data/fragment/Frags-Enamine-18M.csv` - 输出: `Data/fragment/Frags-Enamine-18M_predicted.csv` +- 输出(含菌株): `Data/fragment/Frags-Enamine-18M_predicted.csv`(每个分子 40 行) + +#### 数据量说明 + +- **仅聚合结果**:输出行数 = 输入分子数 +- **包含菌株预测**:输出行数 = 输入分子数 × 40 --- diff --git a/models/broad_spectrum_predictor.py b/models/broad_spectrum_predictor.py index fa86043..111be71 100644 --- a/models/broad_spectrum_predictor.py +++ b/models/broad_spectrum_predictor.py @@ -65,9 +65,21 @@ class MoleculeInput: chem_id: Optional[str] = None +@dataclass +class StrainPrediction: + """单个菌株的预测结果""" + pred_id: str # 格式: "chem_id:strain_name" + chem_id: str + strain_name: str + antimicrobial_predictive_probability: float # 预测概率 + no_growth_probability: float # 不生长概率(1 - antimicrobial_predictive_probability) + growth_inhibition: int # 二值化结果 (0/1) + gram_stain: Optional[str] = None # 革兰染色类型 + + @dataclass class BroadSpectrumResult: - """广谱抗菌预测结果""" + """广谱抗菌预测结果(聚合结果)""" chem_id: str apscore_total: float apscore_gnegative: float @@ -76,9 +88,10 @@ class BroadSpectrumResult: ginhib_gnegative: int ginhib_gpositive: int broad_spectrum: int + strain_predictions: Optional[pd.DataFrame] = None # 菌株级别预测(40行) def to_dict(self) -> Dict[str, Union[str, 
float, int]]: - """转换为字典格式""" + """转换为字典格式(仅聚合字段)""" return { 'chem_id': self.chem_id, 'apscore_total': self.apscore_total, @@ -89,6 +102,31 @@ class BroadSpectrumResult: 'ginhib_gpositive': self.ginhib_gpositive, 'broad_spectrum': self.broad_spectrum } + + def to_strain_predictions_list(self) -> List['StrainPrediction']: + """ + 将 DataFrame 转换为 StrainPrediction 列表(用于类型安全场景) + + Returns: + StrainPrediction 对象列表 + """ + if self.strain_predictions is None or self.strain_predictions.empty: + return [] + + strain_list = [] + for _, row in self.strain_predictions.iterrows(): + strain_pred = StrainPrediction( + pred_id=row['pred_id'], + chem_id=row['chem_id'], + strain_name=row['strain_name'], + antimicrobial_predictive_probability=row['antimicrobial_predictive_probability'], + no_growth_probability=row['no_growth_probability'], + growth_inhibition=row['growth_inhibition'], + gram_stain=row.get('gram_stain', None) + ) + strain_list.append(strain_pred) + + return strain_list class BroadSpectrumPredictor: @@ -259,6 +297,45 @@ class BroadSpectrumPredictor: return df_label + def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame: + """ + 准备菌株级别的预测数据 + + Args: + score_df: 原始预测分数DataFrame,包含 pred_id, 0, 1, growth_inhibition 列 + + Returns: + 格式化的菌株级别预测DataFrame + """ + # 创建副本 + strain_df = score_df.copy() + + # 分离化合物ID和菌株名 + strain_df["chem_id"] = strain_df["pred_id"].str.split(":", expand=True)[0] + strain_df["strain_name"] = strain_df["pred_id"].str.split(":", expand=True)[1] + + # 添加革兰染色信息 + strain_df = self._gram_stain(strain_df) + + # 重命名列为用户友好的名称 + strain_df = strain_df.rename(columns={ + "1": "antimicrobial_predictive_probability", + "0": "no_growth_probability" + }) + + # 选择并排序输出列 + output_columns = [ + "pred_id", + "chem_id", + "strain_name", + "antimicrobial_predictive_probability", + "no_growth_probability", + "growth_inhibition", + "gram_stain" + ] + + return strain_df[output_columns] + def _antimicrobial_potential(self, score_df: 
pd.DataFrame) -> pd.DataFrame: """ 计算抗菌潜力分数 @@ -371,12 +448,14 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor): results = self.predict_batch([molecule]) return results[0] - def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]: + def predict_batch(self, molecules: List[MoleculeInput], + include_strain_predictions: bool = False) -> List[BroadSpectrumResult]: """ 批量预测分子的广谱抗菌活性 Args: molecules: 分子输入列表 + include_strain_predictions: 是否在结果中包含菌株级别预测数据 Returns: 广谱抗菌预测结果列表 @@ -417,10 +496,16 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor): # 合并结果 print("Merging prediction results...") all_pred_df = pd.concat([results[i] for i in sorted(results.keys())]) + all_pred_df = all_pred_df.reset_index() + + # 准备菌株级别数据(如果需要) + strain_level_data = None + if include_strain_predictions: + print("Preparing strain-level predictions...") + strain_level_data = self._prepare_strain_level_predictions(all_pred_df) # 计算抗菌潜力 print("Calculating antimicrobial potential scores...") - all_pred_df = all_pred_df.reset_index() agg_df = self._antimicrobial_potential(all_pred_df) # 判断广谱抗菌 @@ -431,6 +516,13 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor): # 转换为结果对象 results_list = [] for _, row in agg_df.iterrows(): + # 获取该分子的菌株级别预测 + mol_strain_preds = None + if strain_level_data is not None: + mol_strain_preds = strain_level_data[ + strain_level_data['chem_id'] == row.name + ].reset_index(drop=True) + result = BroadSpectrumResult( chem_id=row.name, apscore_total=row["apscore_total"], @@ -439,7 +531,8 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor): ginhib_total=int(row["ginhib_total"]), ginhib_gnegative=int(row["ginhib_gnegative"]), ginhib_gpositive=int(row["ginhib_gpositive"]), - broad_spectrum=int(row["broad_spectrum"]) + broad_spectrum=int(row["broad_spectrum"]), + strain_predictions=mol_strain_preds ) results_list.append(result) diff --git a/test/mole_predict_singal_mole.py 
b/test/mole_predict_singal_mole.py new file mode 100644 index 0000000..01bbea6 --- /dev/null +++ b/test/mole_predict_singal_mole.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +MolE 抗菌活性预测 Python API 示例 + +演示两种预测模式: +1. 聚合结果模式(默认) +2. 菌株级别预测模式 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到 Python 路径(使用 pathlib) +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) + +from models import ( + ParallelBroadSpectrumPredictor, + PredictionConfig, + MoleculeInput +) + + +def print_separator(title): + """打印分隔线""" + print("\n" + "=" * 70) + print(f" {title}") + print("=" * 70 + "\n") + + +def demo_aggregated_mode(): + """演示聚合结果模式(默认)""" + print_separator("模式 1: 聚合结果模式(默认)") + + # 创建配置 + config = PredictionConfig( + batch_size=10, + device="auto" # 自动检测 CUDA + ) + + # 创建预测器 + predictor = ParallelBroadSpectrumPredictor(config) + + # 准备测试分子 + test_molecules = [ + MoleculeInput(smiles="CCO", chem_id="ethanol"), + MoleculeInput(smiles="c1ccccc1", chem_id="benzene"), + MoleculeInput(smiles="CC(=O)O", chem_id="acetic_acid"), + ] + + print(f"测试分子数: {len(test_molecules)}") + print(f"SMILES 示例: {test_molecules[0].smiles}") + + # 执行预测(不包含菌株级别预测) + print("\n开始预测(聚合模式)...") + results = predictor.predict_batch(test_molecules, include_strain_predictions=False) + + # 打印结果 + print(f"\n预测完成!共 {len(results)} 个结果\n") + + for result in results: + print(f"化合物ID: {result.chem_id}") + print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}") + print(f" - 总体抗菌得分: {result.apscore_total:.4f}") + print(f" - 革兰阴性菌得分: {result.apscore_gnegative:.4f}") + print(f" - 革兰阳性菌得分: {result.apscore_gpositive:.4f}") + print(f" - 抑制菌株总数: {result.ginhib_total} / 40") + print(f" - 抑制革兰阴性菌数: {result.ginhib_gnegative}") + print(f" - 抑制革兰阳性菌数: {result.ginhib_gpositive}") + print(f" - strain_predictions: {result.strain_predictions}") # 应该是 None + print() + + # 返回结果供后续使用 + return results + + +def demo_strain_level_mode(): + """演示菌株级别预测模式""" + 
print_separator("模式 2: 菌株级别预测模式") + + # 创建预测器 + predictor = ParallelBroadSpectrumPredictor() + + # 使用一个有趣的抗菌分子进行测试(氟喹诺酮类似物) + test_molecule = MoleculeInput( + smiles="FC1=CC=C(CN2C[C@@H]3C[C@H]2CN3C2CC2)N=C1", + chem_id="test_antibacterial" + ) + + print(f"测试分子: {test_molecule.chem_id}") + print(f"SMILES: {test_molecule.smiles}") + + # 执行预测(包含菌株级别预测) + print("\n开始预测(菌株级别模式)...") + results = predictor.predict_batch([test_molecule], include_strain_predictions=True) + result = results[0] + + # 1. 打印聚合结果 + print(f"\n聚合结果:") + print(f" - 化合物ID: {result.chem_id}") + print(f" - 广谱抗菌: {'是' if result.broad_spectrum else '否'}") + print(f" - 总体抗菌得分: {result.apscore_total:.4f}") + print(f" - 抑制菌株总数: {result.ginhib_total} / 40") + + # 2. 打印菌株级别预测数据结构 + print(f"\n菌株级别预测数据:") + if result.strain_predictions is not None: + strain_df = result.strain_predictions + print(f" - 数据类型: {type(strain_df)}") + print(f" - 数据形状: {strain_df.shape} (行, 列)") + print(f" - 列名: {list(strain_df.columns)}") + print(f" - 内存占用: {strain_df.memory_usage(deep=True).sum() / 1024:.2f} KB") + + # 3. 展示前 5 个菌株的预测 + print(f"\n前 5 个菌株的预测:") + print(strain_df.head(5).to_string(index=False)) + + # 4. 统计信息 + print(f"\n统计信息:") + print(f" - 预测概率范围: [{strain_df['antimicrobial_predictive_probability'].min():.6f}, " + f"{strain_df['antimicrobial_predictive_probability'].max():.6f}]") + print(f" - 预测概率平均值: {strain_df['antimicrobial_predictive_probability'].mean():.6f}") + print(f" - 被抑制菌株数: {strain_df['growth_inhibition'].sum()}") + print(f" - 革兰阴性菌数: {len(strain_df[strain_df['gram_stain'] == 'negative'])}") + print(f" - 革兰阳性菌数: {len(strain_df[strain_df['gram_stain'] == 'positive'])}") + + # 5. 展示被抑制的菌株(如果有) + inhibited = strain_df[strain_df['growth_inhibition'] == 1] + if len(inhibited) > 0: + print(f"\n被抑制的菌株 ({len(inhibited)} 个):") + print(inhibited[['strain_name', 'antimicrobial_predictive_probability', 'gram_stain']].to_string(index=False)) + else: + print(f"\n该分子未预测抑制任何菌株") + + # 6. 
强化学习应用示例 + print(f"\n强化学习应用示例:") + + # 提取预测概率作为状态向量 + state_vector = strain_df['antimicrobial_predictive_probability'].values + print(f" - 状态向量形状: {state_vector.shape}") + print(f" - 状态向量类型: {type(state_vector)}") + print(f" - 前 10 个值: {state_vector[:10]}") + + # 提取多维特征 + state_features = strain_df[[ + 'antimicrobial_predictive_probability', + 'growth_inhibition' + ]].values + print(f" - 多维特征形状: {state_features.shape}") + + # 按革兰染色分组 + gram_negative_probs = strain_df[ + strain_df['gram_stain'] == 'negative' + ]['antimicrobial_predictive_probability'].values + print(f" - 革兰阴性菌概率向量形状: {gram_negative_probs.shape}") + + # 7. 转换为类型安全的列表(可选) + print(f"\n转换为 StrainPrediction 列表:") + strain_list = result.to_strain_predictions_list() + print(f" - 列表长度: {len(strain_list)}") + print(f" - 元素类型: {type(strain_list[0])}") + print(f" - 第一个元素:") + first_strain = strain_list[0] + print(f" * pred_id: {first_strain.pred_id}") + print(f" * strain_name: {first_strain.strain_name}") + print(f" * antimicrobial_predictive_probability: {first_strain.antimicrobial_predictive_probability:.6f}") + print(f" * growth_inhibition: {first_strain.growth_inhibition}") + print(f" * gram_stain: {first_strain.gram_stain}") + else: + print(" 警告: strain_predictions 为 None(未启用菌株级别预测)") + + return result + + +def main(): + """主函数""" + print("\n" + "🧪" * 35) + print(" MolE 抗菌活性预测 Python API 示例") + print("🧪" * 35) + + # 演示模式 1: 聚合结果 + aggregated_results = demo_aggregated_mode() + + # 演示模式 2: 菌株级别预测 + strain_level_result = demo_strain_level_mode() + + # 总结 + print_separator("总结") + print("✅ 模式 1 (聚合结果): 适合快速筛选大量分子") + print(" - 每个分子返回 1 个 BroadSpectrumResult 对象") + print(" - 包含 8 个聚合指标") + print(" - strain_predictions = None") + print() + print("✅ 模式 2 (菌株级别): 适合详细分析和强化学习") + print(" - 每个分子返回 1 个 BroadSpectrumResult 对象") + print(" - 包含 8 个聚合指标 + 40 行菌株预测数据") + print(" - strain_predictions = DataFrame (40 rows × 7 columns)") + print(" - 可直接提取为 numpy array 用于 RL") + print() + print("🎯 推荐使用场景:") + print(" - 初筛: 
模式 1") + print(" - 详细分析/RL 训练: 模式 2") + print() + + +if __name__ == "__main__": + main() diff --git a/utils/mole_predictor.py b/utils/mole_predictor.py index a83cc48..7f02ed6 100644 --- a/utils/mole_predictor.py +++ b/utils/mole_predictor.py @@ -49,7 +49,8 @@ def predict_csv_file( batch_size: int = 100, n_workers: Optional[int] = None, device: str = "auto", - add_suffix: bool = True + add_suffix: bool = True, + include_strain_predictions: bool = False ) -> pd.DataFrame: """ 预测 CSV 文件中的分子抗菌活性 @@ -63,6 +64,7 @@ def predict_csv_file( n_workers: 工作进程数 device: 计算设备 ("auto", "cpu", "cuda:0" 等) add_suffix: 是否在输出文件名后添加预测后缀 + include_strain_predictions: 是否在输出中包含40种菌株的预测详情 Returns: 包含预测结果的 DataFrame @@ -118,7 +120,7 @@ def predict_csv_file( # 执行预测 print("开始预测...") - results = predictor.predict_batch(molecules) + results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions) # 转换结果为 DataFrame results_dicts = [r.to_dict() for r in results] @@ -136,6 +138,36 @@ def predict_csv_file( ) df_output = df_output.drop(columns=['_merge_id']) + # 如果包含菌株级别预测,将其添加到输出中 + if include_strain_predictions: + print("合并菌株级别预测数据...") + # 收集所有菌株级别预测 + all_strain_predictions = [] + for result in results: + if result.strain_predictions is not None and not result.strain_predictions.empty: + all_strain_predictions.append(result.strain_predictions) + + if all_strain_predictions: + # 合并所有菌株预测 + df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True) + + # 将聚合结果和菌株预测合并 + # 为了在同一个 CSV 中展示,我们使用重复行的方式 + # 每个分子的聚合结果会重复40次(每个菌株一次) + df_output_expanded = df_output.merge( + df_strain_predictions, + left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)], + right_on='chem_id', + how='left', + suffixes=('', '_strain') + ) + + # 移除重复的 chem_id 列 + if 'chem_id_strain' in df_output_expanded.columns: + df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain']) + + df_output = df_output_expanded + # 生成输出路径 if output_path is None: if 
add_suffix: @@ -164,7 +196,8 @@ def predict_multiple_files( batch_size: int = 100, n_workers: Optional[int] = None, device: str = "auto", - add_suffix: bool = True + add_suffix: bool = True, + include_strain_predictions: bool = False ) -> List[pd.DataFrame]: """ 批量预测多个 CSV 文件 @@ -178,6 +211,7 @@ def predict_multiple_files( n_workers: 工作进程数 device: 计算设备 add_suffix: 是否在输出文件名后添加预测后缀 + include_strain_predictions: 是否在输出中包含40种菌株的预测详情 Returns: 包含预测结果的 DataFrame 列表 @@ -210,7 +244,8 @@ def predict_multiple_files( batch_size=batch_size, n_workers=n_workers, device=device, - add_suffix=add_suffix + add_suffix=add_suffix, + include_strain_predictions=include_strain_predictions ) results.append(df_result) except Exception as e: @@ -240,7 +275,9 @@ def predict_multiple_files( help='计算设备 (默认: auto)') @click.option('--add-suffix/--no-add-suffix', default=True, help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)') -def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix): +@click.option('--include-strain-predictions', is_flag=True, default=False, + help='在输出中包含40种菌株的详细预测数据(每个分子将产生40行数据,对应每个菌株的预测概率和抑制情况)') +def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions): """ 使用 MolE 模型预测小分子 SMILES 的抗菌活性 @@ -248,12 +285,21 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成) + 默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。 + 使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。 + 示例: + # 基本用法(仅输出聚合结果) python mole_predictor.py input.csv output.csv - python mole_predictor.py input.csv -s SMILES -i ID + # 包含40种菌株的详细预测数据 + python mole_predictor.py input.csv output.csv --include-strain-predictions + # 指定列名和设备 + python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0 + + # 自定义批处理大小 python mole_predictor.py input.csv --device cuda:0 --batch-size 200 """ @@ -266,7 +312,8 @@ def cli(input_path, output_path, smiles_column, 
id_column, batch_size, n_workers batch_size=batch_size, n_workers=n_workers, device=device, - add_suffix=add_suffix + add_suffix=add_suffix, + include_strain_predictions=include_strain_predictions ) except Exception as e: click.echo(f"错误: {e}", err=True)