Files
SIME/utils/mole_predictor.py
hotwa 34102cf459 1. 代码修改
models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
2025-10-17 16:46:04 +08:00

326 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MolE 抗菌活性预测工具
这个脚本提供了使用 MolE 模型预测小分子 SMILES 抗菌活性的功能。
支持命令行和 Python API 调用两种方式。
命令行示例:
python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id
Python API 示例:
from utils.mole_predictor import predict_csv_file
predict_csv_file(
input_path="input.csv",
output_path="output.csv",
smiles_column="smiles",
id_column="chem_id"
)
"""
import sys
import os
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import click
import pandas as pd
from typing import Optional, List
from datetime import datetime
from models.broad_spectrum_predictor import (
ParallelBroadSpectrumPredictor,
PredictionConfig,
MoleculeInput,
BroadSpectrumResult
)
def predict_csv_file(
input_path: str,
output_path: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> pd.DataFrame:
"""
预测 CSV 文件中的分子抗菌活性
Args:
input_path: 输入 CSV 文件路径
output_path: 输出 CSV 文件路径,如果为 None 则自动生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame
"""
print(f"开始处理文件: {input_path}")
# 读取输入文件
input_path_obj = Path(input_path)
if not input_path_obj.exists():
raise FileNotFoundError(f"输入文件不存在: {input_path}")
# 读取 CSV
try:
df_input = pd.read_csv(input_path)
except Exception as e:
raise RuntimeError(f"读取 CSV 文件失败: {e}")
print(f"读取了 {len(df_input)} 条数据")
# 检查列是否存在(大小写不敏感)
columns_lower = {col.lower(): col for col in df_input.columns}
smiles_col_actual = columns_lower.get(smiles_column.lower())
if smiles_col_actual is None:
raise ValueError(
f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}"
)
# 处理 ID 列
id_col_actual = columns_lower.get(id_column.lower())
if id_col_actual is None:
print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
id_col_actual = id_column
# 创建预测器配置
config = PredictionConfig(
batch_size=batch_size,
n_workers=n_workers,
device=device
)
# 初始化预测器
print("初始化预测器...")
predictor = ParallelBroadSpectrumPredictor(config)
# 准备分子输入
molecules = [
MoleculeInput(smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]))
for _, row in df_input.iterrows()
]
# 执行预测
print("开始预测...")
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
# 转换结果为 DataFrame
results_dicts = [r.to_dict() for r in results]
df_results = pd.DataFrame(results_dicts)
# 合并原始数据和预测结果
# 使用 chem_id 作为键进行合并
df_input['_merge_id'] = df_input[id_col_actual].astype(str)
df_results['_merge_id'] = df_results['chem_id'].astype(str)
df_output = df_input.merge(
df_results.drop(columns=['chem_id']),
on='_merge_id',
how='left'
)
df_output = df_output.drop(columns=['_merge_id'])
# 如果包含菌株级别预测,将其添加到输出中
if include_strain_predictions:
print("合并菌株级别预测数据...")
# 收集所有菌株级别预测
all_strain_predictions = []
for result in results:
if result.strain_predictions is not None and not result.strain_predictions.empty:
all_strain_predictions.append(result.strain_predictions)
if all_strain_predictions:
# 合并所有菌株预测
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
# 将聚合结果和菌株预测合并
# 为了在同一个 CSV 中展示,我们使用重复行的方式
# 每个分子的聚合结果会重复40次每个菌株一次
df_output_expanded = df_output.merge(
df_strain_predictions,
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
right_on='chem_id',
how='left',
suffixes=('', '_strain')
)
# 移除重复的 chem_id 列
if 'chem_id_strain' in df_output_expanded.columns:
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
df_output = df_output_expanded
# 生成输出路径
if output_path is None:
if add_suffix:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(input_path_obj.parent / f"{input_path_obj.stem}_output{input_path_obj.suffix}")
elif add_suffix:
output_path_obj = Path(output_path)
output_path = str(output_path_obj.parent / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}")
# 保存结果
print(f"保存结果到: {output_path}")
df_output.to_csv(output_path, index=False)
print(f"完成! 预测了 {len(results)} 个分子")
print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌")
return df_output
def predict_multiple_files(
input_paths: List[str],
output_dir: Optional[str] = None,
smiles_column: str = "smiles",
id_column: str = "chem_id",
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> List[pd.DataFrame]:
"""
批量预测多个 CSV 文件
Args:
input_paths: 输入 CSV 文件路径列表
output_dir: 输出目录,如果为 None 则在原文件目录生成
smiles_column: SMILES 列名
id_column: 化合物 ID 列名
batch_size: 批处理大小
n_workers: 工作进程数
device: 计算设备
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame 列表
"""
results = []
for input_path in input_paths:
input_path_obj = Path(input_path)
# 确定输出路径
if output_dir is not None:
output_dir_obj = Path(output_dir)
output_dir_obj.mkdir(parents=True, exist_ok=True)
if add_suffix:
output_path = str(output_dir_obj / f"{input_path_obj.stem}_predicted{input_path_obj.suffix}")
else:
output_path = str(output_dir_obj / input_path_obj.name)
else:
output_path = None
# 预测单个文件
try:
df_result = predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
results.append(df_result)
except Exception as e:
print(f"处理文件 {input_path} 时出错: {e}")
continue
return results
# ============================================================================
# 命令行接口
# ============================================================================
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option('--smiles-column', '-s', default='smiles',
help='SMILES 列名 (默认: smiles)')
@click.option('--id-column', '-i', default='chem_id',
help='化合物 ID 列名 (默认: chem_id)')
@click.option('--batch-size', '-b', default=100, type=int,
help='批处理大小 (默认: 100)')
@click.option('--n-workers', '-w', default=None, type=int,
help='工作进程数 (默认: CPU 核心数)')
@click.option('--device', '-d', default='auto',
type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
@click.option('--include-strain-predictions', is_flag=True, default=False,
help='在输出中包含40种菌株的详细预测数据每个分子将产生40行数据对应每个菌株的预测概率和抑制情况')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
"""
使用 MolE 模型预测小分子 SMILES 的抗菌活性
INPUT_PATH: 输入 CSV 文件路径
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
示例:
# 基本用法(仅输出聚合结果)
python mole_predictor.py input.csv output.csv
# 包含40种菌株的详细预测数据
python mole_predictor.py input.csv output.csv --include-strain-predictions
# 指定列名和设备
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
# 自定义批处理大小
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
"""
try:
predict_csv_file(
input_path=input_path,
output_path=output_path,
smiles_column=smiles_column,
id_column=id_column,
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
except Exception as e:
click.echo(f"错误: {e}", err=True)
sys.exit(1)
if __name__ == '__main__':
cli()