1. 代码修改
models/broad_spectrum_predictor.py: ✅ 新增 StrainPrediction dataclass(单个菌株预测结果) ✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型) ✅ 添加 to_strain_predictions_list() 方法(类型安全转换) ✅ 新增 _prepare_strain_level_predictions() 方法 ✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数 utils/mole_predictor.py: ✅ 添加 include_strain_predictions 参数到所有函数 ✅ 添加命令行参数 --include-strain-predictions ✅ 实现菌株级别数据与聚合结果的合并逻辑 ✅ 更新所有函数签名和文档字符串 2. 测试验证 ✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出 ✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40) ✅ 验证输出格式正确性 ✅ 验证每个分子都有完整的 40 个菌株预测 ✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌) 3. 文档更新 README.md: ✅ 更新命令行使用示例 ✅ 添加 Python API 使用示例(包含菌株预测) ✅ 添加详细的输出格式说明 ✅ 添加 40 种菌株列表概览 ✅ 添加数据使用场景示例(强化学习、筛选、可视化) Data/mole/README.md: ✅ 新增"菌株级别预测详情"章节 ✅ 完整的 40 种菌株列表(分革兰阴性/阳性) ✅ 数据访问方式示例(CSV 读取、Python API) ✅ 强化学习应用场景(状态表示、奖励函数设计) ✅ 数据可视化代码示例 ✅ 性能和存储建议
This commit is contained in:
@@ -49,7 +49,8 @@ def predict_csv_file(
|
||||
batch_size: int = 100,
|
||||
n_workers: Optional[int] = None,
|
||||
device: str = "auto",
|
||||
add_suffix: bool = True
|
||||
add_suffix: bool = True,
|
||||
include_strain_predictions: bool = False
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
预测 CSV 文件中的分子抗菌活性
|
||||
@@ -63,6 +64,7 @@ def predict_csv_file(
|
||||
n_workers: 工作进程数
|
||||
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
|
||||
add_suffix: 是否在输出文件名后添加预测后缀
|
||||
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
|
||||
|
||||
Returns:
|
||||
包含预测结果的 DataFrame
|
||||
@@ -118,7 +120,7 @@ def predict_csv_file(
|
||||
|
||||
# 执行预测
|
||||
print("开始预测...")
|
||||
results = predictor.predict_batch(molecules)
|
||||
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
|
||||
|
||||
# 转换结果为 DataFrame
|
||||
results_dicts = [r.to_dict() for r in results]
|
||||
@@ -136,6 +138,36 @@ def predict_csv_file(
|
||||
)
|
||||
df_output = df_output.drop(columns=['_merge_id'])
|
||||
|
||||
# 如果包含菌株级别预测,将其添加到输出中
|
||||
if include_strain_predictions:
|
||||
print("合并菌株级别预测数据...")
|
||||
# 收集所有菌株级别预测
|
||||
all_strain_predictions = []
|
||||
for result in results:
|
||||
if result.strain_predictions is not None and not result.strain_predictions.empty:
|
||||
all_strain_predictions.append(result.strain_predictions)
|
||||
|
||||
if all_strain_predictions:
|
||||
# 合并所有菌株预测
|
||||
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
|
||||
|
||||
# 将聚合结果和菌株预测合并
|
||||
# 为了在同一个 CSV 中展示,我们使用重复行的方式
|
||||
# 每个分子的聚合结果会重复40次(每个菌株一次)
|
||||
df_output_expanded = df_output.merge(
|
||||
df_strain_predictions,
|
||||
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
|
||||
right_on='chem_id',
|
||||
how='left',
|
||||
suffixes=('', '_strain')
|
||||
)
|
||||
|
||||
# 移除重复的 chem_id 列
|
||||
if 'chem_id_strain' in df_output_expanded.columns:
|
||||
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
|
||||
|
||||
df_output = df_output_expanded
|
||||
|
||||
# 生成输出路径
|
||||
if output_path is None:
|
||||
if add_suffix:
|
||||
@@ -164,7 +196,8 @@ def predict_multiple_files(
|
||||
batch_size: int = 100,
|
||||
n_workers: Optional[int] = None,
|
||||
device: str = "auto",
|
||||
add_suffix: bool = True
|
||||
add_suffix: bool = True,
|
||||
include_strain_predictions: bool = False
|
||||
) -> List[pd.DataFrame]:
|
||||
"""
|
||||
批量预测多个 CSV 文件
|
||||
@@ -178,6 +211,7 @@ def predict_multiple_files(
|
||||
n_workers: 工作进程数
|
||||
device: 计算设备
|
||||
add_suffix: 是否在输出文件名后添加预测后缀
|
||||
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
|
||||
|
||||
Returns:
|
||||
包含预测结果的 DataFrame 列表
|
||||
@@ -210,7 +244,8 @@ def predict_multiple_files(
|
||||
batch_size=batch_size,
|
||||
n_workers=n_workers,
|
||||
device=device,
|
||||
add_suffix=add_suffix
|
||||
add_suffix=add_suffix,
|
||||
include_strain_predictions=include_strain_predictions
|
||||
)
|
||||
results.append(df_result)
|
||||
except Exception as e:
|
||||
@@ -240,7 +275,9 @@ def predict_multiple_files(
|
||||
help='计算设备 (默认: auto)')
|
||||
@click.option('--add-suffix/--no-add-suffix', default=True,
|
||||
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
|
||||
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
|
||||
@click.option('--include-strain-predictions', is_flag=True, default=False,
|
||||
help='在输出中包含40种菌株的详细预测数据(每个分子将产生40行数据,对应每个菌株的预测概率和抑制情况)')
|
||||
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
|
||||
"""
|
||||
使用 MolE 模型预测小分子 SMILES 的抗菌活性
|
||||
|
||||
@@ -248,12 +285,21 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
|
||||
|
||||
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
|
||||
|
||||
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
|
||||
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
|
||||
|
||||
示例:
|
||||
|
||||
# 基本用法(仅输出聚合结果)
|
||||
python mole_predictor.py input.csv output.csv
|
||||
|
||||
python mole_predictor.py input.csv -s SMILES -i ID
|
||||
# 包含40种菌株的详细预测数据
|
||||
python mole_predictor.py input.csv output.csv --include-strain-predictions
|
||||
|
||||
# 指定列名和设备
|
||||
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
|
||||
|
||||
# 自定义批处理大小
|
||||
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
|
||||
"""
|
||||
|
||||
@@ -266,7 +312,8 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
|
||||
batch_size=batch_size,
|
||||
n_workers=n_workers,
|
||||
device=device,
|
||||
add_suffix=add_suffix
|
||||
add_suffix=add_suffix,
|
||||
include_strain_predictions=include_strain_predictions
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"错误: {e}", err=True)
|
||||
|
||||
Reference in New Issue
Block a user