models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
326 lines
11 KiB
Python
326 lines
11 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
MolE antibacterial activity prediction tool.

This script predicts the antibacterial activity of small molecules from their
SMILES strings using the MolE model. It can be used either from the command
line or through the Python API.

Command-line example:
    python mole_predictor.py input.csv output.csv --smiles-column smiles --id-column chem_id

Python API example:
    from utils.mole_predictor import predict_csv_file

    predict_csv_file(
        input_path="input.csv",
        output_path="output.csv",
        smiles_column="smiles",
        id_column="chem_id"
    )
"""
|
||
|
||
import sys
|
||
import os
|
||
from pathlib import Path
|
||
|
||
# Put the project root (the directory one level above this file's folder) at
# the front of the import path so the project-local `models` package imported
# below resolves when this file is run directly as a script.
project_root = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(project_root))
|
||
|
||
import click
|
||
import pandas as pd
|
||
from typing import Optional, List
|
||
from datetime import datetime
|
||
|
||
from models.broad_spectrum_predictor import (
|
||
ParallelBroadSpectrumPredictor,
|
||
PredictionConfig,
|
||
MoleculeInput,
|
||
BroadSpectrumResult
|
||
)
|
||
|
||
|
||
def _resolve_output_path(
    input_path_obj: Path,
    output_path: Optional[str],
    add_suffix: bool
) -> str:
    """Return the path of the CSV file the predictions are written to."""
    if output_path is None:
        # No explicit target: write next to the input file.
        tag = "_predicted" if add_suffix else "_output"
        return str(
            input_path_obj.parent
            / f"{input_path_obj.stem}{tag}{input_path_obj.suffix}"
        )
    if add_suffix:
        output_path_obj = Path(output_path)
        return str(
            output_path_obj.parent
            / f"{output_path_obj.stem}_predicted{output_path_obj.suffix}"
        )
    return output_path


def _merge_strain_predictions(
    df_output: pd.DataFrame,
    results: List,
    id_col_actual: str
) -> pd.DataFrame:
    """Expand the aggregated table to one row per (molecule, strain) prediction."""
    strain_frames = [
        result.strain_predictions
        for result in results
        if result.strain_predictions is not None
        and not result.strain_predictions.empty
    ]
    if not strain_frames:
        return df_output

    df_strain_predictions = pd.concat(strain_frames, ignore_index=True)

    # Join on a string key, exactly like the aggregate merge does. Joining the
    # raw ID column (numeric when the CSV holds integer IDs) against the
    # string-valued `chem_id` makes pandas reject the merge.
    left = df_output.assign(_merge_id=df_output[id_col_actual].astype(str))
    df_strain_predictions['_merge_id'] = (
        df_strain_predictions['chem_id'].astype(str)
    )

    # The aggregated columns are repeated once per strain row of a molecule.
    df_expanded = left.merge(
        df_strain_predictions,
        on='_merge_id',
        how='left',
        suffixes=('', '_strain')
    ).drop(columns=['_merge_id'])

    # Drop the duplicated chem_id column produced when both sides carry one.
    if 'chem_id_strain' in df_expanded.columns:
        df_expanded = df_expanded.drop(columns=['chem_id_strain'])

    return df_expanded


def predict_csv_file(
    input_path: str,
    output_path: Optional[str] = None,
    smiles_column: str = "smiles",
    id_column: str = "chem_id",
    batch_size: int = 100,
    n_workers: Optional[int] = None,
    device: str = "auto",
    add_suffix: bool = True,
    include_strain_predictions: bool = False
) -> pd.DataFrame:
    """
    Predict the antibacterial activity of the molecules listed in a CSV file.

    The input rows are merged with the prediction results and the merged
    table is written to a CSV file and returned.

    Args:
        input_path: Path of the input CSV file.
        output_path: Path of the output CSV file. When None, the output is
            written next to the input file as ``<stem>_predicted<ext>``
            (``<stem>_output<ext>`` when ``add_suffix`` is False).
        smiles_column: Name of the SMILES column (matched case-insensitively).
        id_column: Name of the compound ID column (matched
            case-insensitively). When the column is missing, IDs
            ``mol1``, ``mol2``, ... are generated.
        batch_size: Batch size forwarded to ``PredictionConfig``.
        n_workers: Number of workers forwarded to ``PredictionConfig``.
        device: Compute device ("auto", "cpu", "cuda:0", ...).
        add_suffix: Whether to append ``_predicted`` to the output file name.
            This also applies to an explicitly given ``output_path``.
        include_strain_predictions: Whether to add the per-strain prediction
            details to the output (one row per molecule and strain).

    Returns:
        DataFrame with the input columns plus the prediction columns.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
        RuntimeError: If the CSV file cannot be read.
        ValueError: If the SMILES column is not present.
    """
    print(f"开始处理文件: {input_path}")

    input_path_obj = Path(input_path)
    if not input_path_obj.exists():
        raise FileNotFoundError(f"输入文件不存在: {input_path}")

    try:
        df_input = pd.read_csv(input_path)
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise RuntimeError(f"读取 CSV 文件失败: {e}") from e

    print(f"读取了 {len(df_input)} 条数据")

    # Column lookup is case-insensitive: map lower-cased names to real names.
    columns_lower = {col.lower(): col for col in df_input.columns}

    smiles_col_actual = columns_lower.get(smiles_column.lower())
    if smiles_col_actual is None:
        raise ValueError(
            f"未找到 SMILES 列 '{smiles_column}'。可用列: {list(df_input.columns)}"
        )

    id_col_actual = columns_lower.get(id_column.lower())
    if id_col_actual is None:
        print(f"未找到 ID 列 '{id_column}',将自动生成 ID")
        df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))]
        id_col_actual = id_column

    config = PredictionConfig(
        batch_size=batch_size,
        n_workers=n_workers,
        device=device
    )

    print("初始化预测器...")
    predictor = ParallelBroadSpectrumPredictor(config)

    # Iterate the two columns directly; this avoids building a Series per row.
    molecules = [
        MoleculeInput(smiles=smiles, chem_id=str(chem_id))
        for smiles, chem_id in zip(
            df_input[smiles_col_actual], df_input[id_col_actual]
        )
    ]

    print("开始预测...")
    results = predictor.predict_batch(
        molecules,
        include_strain_predictions=include_strain_predictions
    )

    df_results = pd.DataFrame([r.to_dict() for r in results])

    if df_results.empty:
        # No predictions came back (e.g. a header-only input file). The empty
        # frame has no `chem_id` column to merge on, so keep the input rows.
        df_output = df_input
    else:
        # Merge on the stringified ID because chem_id was passed to the
        # predictor as a string.
        # NOTE(review): duplicate IDs in the input would match several result
        # rows each and multiply the output rows - confirm IDs are unique.
        df_input['_merge_id'] = df_input[id_col_actual].astype(str)
        df_results['_merge_id'] = df_results['chem_id'].astype(str)

        df_output = df_input.merge(
            df_results.drop(columns=['chem_id']),
            on='_merge_id',
            how='left'
        ).drop(columns=['_merge_id'])

    if include_strain_predictions:
        print("合并菌株级别预测数据...")
        df_output = _merge_strain_predictions(df_output, results, id_col_actual)

    output_path = _resolve_output_path(input_path_obj, output_path, add_suffix)

    # Create the target directory so the finished prediction is not lost to a
    # missing-folder error at write time.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    print(f"保存结果到: {output_path}")
    df_output.to_csv(output_path, index=False)

    print(f"完成! 预测了 {len(results)} 个分子")
    print(f"其中 {sum(r.broad_spectrum for r in results)} 个分子被预测为广谱抗菌")

    return df_output
|
||
|
||
|
||
def predict_multiple_files(
    input_paths: List[str],
    output_dir: Optional[str] = None,
    smiles_column: str = "smiles",
    id_column: str = "chem_id",
    batch_size: int = 100,
    n_workers: Optional[int] = None,
    device: str = "auto",
    add_suffix: bool = True,
    include_strain_predictions: bool = False
) -> List[pd.DataFrame]:
    """
    Run the prediction on several CSV files, one after another.

    A file that fails is reported on stdout and skipped, so the returned list
    may be shorter than ``input_paths``.

    Args:
        input_paths: Paths of the input CSV files.
        output_dir: Directory for the output files (created if needed). When
            None, each output is written next to its input file.
        smiles_column: Name of the SMILES column.
        id_column: Name of the compound ID column.
        batch_size: Batch size.
        n_workers: Number of workers.
        device: Compute device.
        add_suffix: Whether to append ``_predicted`` to the output file name.
        include_strain_predictions: Whether to add the per-strain prediction
            details to the output.

    Returns:
        The prediction DataFrames of the files that were processed
        successfully, in input order.
    """
    # The output directory does not depend on the file; prepare it once.
    output_dir_obj = None
    if output_dir is not None:
        output_dir_obj = Path(output_dir)
        output_dir_obj.mkdir(parents=True, exist_ok=True)

    results = []

    for input_path in input_paths:
        input_path_obj = Path(input_path)

        if output_dir_obj is not None:
            if add_suffix:
                file_name = f"{input_path_obj.stem}_predicted{input_path_obj.suffix}"
            else:
                file_name = input_path_obj.name
            output_path = str(output_dir_obj / file_name)
            # The name above is final. Forwarding add_suffix=True would make
            # predict_csv_file append "_predicted" a second time.
            apply_suffix = False
        else:
            # Let predict_csv_file derive the name next to the input file.
            output_path = None
            apply_suffix = add_suffix

        try:
            df_result = predict_csv_file(
                input_path=input_path,
                output_path=output_path,
                smiles_column=smiles_column,
                id_column=id_column,
                batch_size=batch_size,
                n_workers=n_workers,
                device=device,
                add_suffix=apply_suffix,
                include_strain_predictions=include_strain_predictions
            )
        except Exception as e:
            # Best effort: one bad file must not abort the whole batch.
            print(f"处理文件 {input_path} 时出错: {e}")
            continue

        results.append(df_result)

    return results
|
||
|
||
|
||
# ============================================================================
|
||
# 命令行接口
|
||
# ============================================================================
|
||
|
||
@click.command()
@click.argument('input_path', type=click.Path(exists=True))
@click.argument('output_path', type=click.Path(), required=False)
@click.option(
    '--smiles-column', '-s',
    default='smiles',
    help='SMILES 列名 (默认: smiles)',
)
@click.option(
    '--id-column', '-i',
    default='chem_id',
    help='化合物 ID 列名 (默认: chem_id)',
)
@click.option(
    '--batch-size', '-b',
    default=100,
    type=int,
    help='批处理大小 (默认: 100)',
)
@click.option(
    '--n-workers', '-w',
    default=None,
    type=int,
    help='工作进程数 (默认: CPU 核心数)',
)
@click.option(
    '--device', '-d',
    default='auto',
    type=click.Choice(['auto', 'cpu', 'cuda:0', 'cuda:1'], case_sensitive=False),
    help='计算设备 (默认: auto)',
)
@click.option(
    '--add-suffix/--no-add-suffix',
    default=True,
    help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)',
)
@click.option(
    '--include-strain-predictions',
    is_flag=True,
    default=False,
    help='在输出中包含40种菌株的详细预测数据(每个分子将产生40行数据,对应每个菌株的预测概率和抑制情况)',
)
def cli(
    input_path,
    output_path,
    smiles_column,
    id_column,
    batch_size,
    n_workers,
    device,
    add_suffix,
    include_strain_predictions,
):
    """
    使用 MolE 模型预测小分子 SMILES 的抗菌活性

    INPUT_PATH: 输入 CSV 文件路径

    OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)

    默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
    使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。

    示例:

        # 基本用法(仅输出聚合结果)
        python mole_predictor.py input.csv output.csv

        # 包含40种菌株的详细预测数据
        python mole_predictor.py input.csv output.csv --include-strain-predictions

        # 指定列名和设备
        python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0

        # 自定义批处理大小
        python mole_predictor.py input.csv --device cuda:0 --batch-size 200
    """
    # NOTE: click renders the docstring above as the command's --help text, so
    # it is user-facing runtime output and is deliberately kept verbatim.

    # Command-line entry point: forward every parsed option unchanged to
    # predict_csv_file.
    options = dict(
        input_path=input_path,
        output_path=output_path,
        smiles_column=smiles_column,
        id_column=id_column,
        batch_size=batch_size,
        n_workers=n_workers,
        device=device,
        add_suffix=add_suffix,
        include_strain_predictions=include_strain_predictions,
    )

    try:
        predict_csv_file(**options)
    except Exception as exc:
        # Top-level boundary: report the failure on stderr and exit non-zero.
        click.echo(f"错误: {exc}", err=True)
        sys.exit(1)
|
||
|
||
|
||
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    cli()
|
||
|