1. 代码修改

models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
This commit is contained in:
2025-10-17 16:46:04 +08:00
parent 62e0f3d6aa
commit 34102cf459
5 changed files with 716 additions and 21 deletions

View File

@@ -49,7 +49,8 @@ def predict_csv_file(
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> pd.DataFrame:
"""
预测 CSV 文件中的分子抗菌活性
@@ -63,6 +64,7 @@ def predict_csv_file(
n_workers: 工作进程数
device: 计算设备 ("auto", "cpu", "cuda:0" 等)
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame
@@ -118,7 +120,7 @@ def predict_csv_file(
# 执行预测
print("开始预测...")
results = predictor.predict_batch(molecules)
results = predictor.predict_batch(molecules, include_strain_predictions=include_strain_predictions)
# 转换结果为 DataFrame
results_dicts = [r.to_dict() for r in results]
@@ -136,6 +138,36 @@ def predict_csv_file(
)
df_output = df_output.drop(columns=['_merge_id'])
# 如果包含菌株级别预测,将其添加到输出中
if include_strain_predictions:
print("合并菌株级别预测数据...")
# 收集所有菌株级别预测
all_strain_predictions = []
for result in results:
if result.strain_predictions is not None and not result.strain_predictions.empty:
all_strain_predictions.append(result.strain_predictions)
if all_strain_predictions:
# 合并所有菌株预测
df_strain_predictions = pd.concat(all_strain_predictions, ignore_index=True)
# 将聚合结果和菌株预测合并
# 为了在同一个 CSV 中展示,我们使用重复行的方式
# 每个分子的聚合结果会重复40次每个菌株一次
df_output_expanded = df_output.merge(
df_strain_predictions,
left_on=df_output.columns[df_output.columns.get_loc(id_col_actual)],
right_on='chem_id',
how='left',
suffixes=('', '_strain')
)
# 移除重复的 chem_id 列
if 'chem_id_strain' in df_output_expanded.columns:
df_output_expanded = df_output_expanded.drop(columns=['chem_id_strain'])
df_output = df_output_expanded
# 生成输出路径
if output_path is None:
if add_suffix:
@@ -164,7 +196,8 @@ def predict_multiple_files(
batch_size: int = 100,
n_workers: Optional[int] = None,
device: str = "auto",
add_suffix: bool = True
add_suffix: bool = True,
include_strain_predictions: bool = False
) -> List[pd.DataFrame]:
"""
批量预测多个 CSV 文件
@@ -178,6 +211,7 @@ def predict_multiple_files(
n_workers: 工作进程数
device: 计算设备
add_suffix: 是否在输出文件名后添加预测后缀
include_strain_predictions: 是否在输出中包含40种菌株的预测详情
Returns:
包含预测结果的 DataFrame 列表
@@ -210,7 +244,8 @@ def predict_multiple_files(
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
results.append(df_result)
except Exception as e:
@@ -240,7 +275,9 @@ def predict_multiple_files(
help='计算设备 (默认: auto)')
@click.option('--add-suffix/--no-add-suffix', default=True,
help='是否在输出文件名后添加 "_predicted" 后缀 (默认: 添加)')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix):
@click.option('--include-strain-predictions', is_flag=True, default=False,
help='在输出中包含40种菌株的详细预测数据每个分子将产生40行数据对应每个菌株的预测概率和抑制情况')
def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers, device, add_suffix, include_strain_predictions):
"""
使用 MolE 模型预测小分子 SMILES 的抗菌活性
@@ -248,12 +285,21 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
OUTPUT_PATH: 输出 CSV 文件路径 (可选,默认在原文件目录生成)
默认输出包含聚合的抗菌活性指标(广谱抗菌评分、抑制菌株数等)。
使用 --include-strain-predictions 可以额外包含每个菌株的详细预测数据。
示例:
# 基本用法(仅输出聚合结果)
python mole_predictor.py input.csv output.csv
python mole_predictor.py input.csv -s SMILES -i ID
# 包含40种菌株的详细预测数据
python mole_predictor.py input.csv output.csv --include-strain-predictions
# 指定列名和设备
python mole_predictor.py input.csv -s SMILES -i ID --device cuda:0
# 自定义批处理大小
python mole_predictor.py input.csv --device cuda:0 --batch-size 200
"""
@@ -266,7 +312,8 @@ def cli(input_path, output_path, smiles_column, id_column, batch_size, n_workers
batch_size=batch_size,
n_workers=n_workers,
device=device,
add_suffix=add_suffix
add_suffix=add_suffix,
include_strain_predictions=include_strain_predictions
)
except Exception as e:
click.echo(f"错误: {e}", err=True)