#!/usr/bin/env python3 """ 批量抗菌活性预测工具 这个工具用于大规模预测分子的抗菌活性,支持: - 单进程或多进程并行预测 - 自动处理大型 CSV 文件 - 灵活的 GPU 配置 - 临时文件管理和断点续传 技术细节: - 使用 ParallelBroadSpectrumPredictor(单进程 + XGBoost OpenMP 多线程) - 避免 CUDA fork 死锁问题 - 每个模型实例约占用 2.5GB GPU 显存 - XGBoost 自动利用所有 CPU 核心 作者: AI Assistant 日期: 2025-10-17 """ import sys import os import time import click import pandas as pd from pathlib import Path from tqdm import tqdm from contextlib import redirect_stdout, redirect_stderr import multiprocessing as mp # 添加项目根目录到 Python 路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from models.broad_spectrum_predictor import ( ParallelBroadSpectrumPredictor, MoleculeInput, PredictionConfig ) def predict_single_process( input_path: str, output_path: str, smiles_column: str, id_column: str, device: str, batch_size: int, start_from: int, max_molecules: int, temp_dir: Path, verbose: bool = True ) -> pd.DataFrame: """ 单进程预测分子抗菌活性 Args: input_path: 输入 CSV 文件路径 output_path: 输出 CSV 文件路径 smiles_column: SMILES 列名 id_column: 化合物 ID 列名 device: GPU 设备(如 'cuda:0' 或 'cpu') batch_size: 批处理大小 start_from: 从第几行开始 max_molecules: 最多处理多少个分子 temp_dir: 临时文件目录 verbose: 是否显示详细信息 Returns: 预测结果 DataFrame """ # 读取数据 df_input = pd.read_csv(input_path) # 检查列是否存在(大小写不敏感) columns_lower = {col.lower(): col for col in df_input.columns} smiles_col_actual = columns_lower.get(smiles_column.lower()) if smiles_col_actual is None: raise ValueError( f"SMILES 列 '{smiles_column}' 不存在。可用列: {list(df_input.columns)}" ) # 处理 ID 列 id_col_actual = columns_lower.get(id_column.lower()) if id_col_actual is None: if verbose: print(f"未找到 ID 列 '{id_column}',将自动生成 ID") df_input[id_column] = [f"mol{i+1}" for i in range(len(df_input))] id_col_actual = id_column # 应用限制 if start_from > 0: df_input = df_input.iloc[start_from:] if max_molecules: df_input = df_input.iloc[:max_molecules] if verbose: print(f" 处理 {len(df_input):,} 个分子") # 初始化预测器 config = PredictionConfig( batch_size=10000, device=device ) # 抑制模型加载时的输出 if not verbose: with open(os.devnull, 'w') as devnull: with redirect_stdout(devnull): predictor = ParallelBroadSpectrumPredictor(config) else: predictor = ParallelBroadSpectrumPredictor(config) # 分批处理 all_results = [] n_batches = (len(df_input) + batch_size - 1) // batch_size iterator = range(0, len(df_input), batch_size) if verbose: iterator = tqdm(iterator, desc="处理进度", unit="批") for i in iterator: batch_df = df_input.iloc[i:i+batch_size] # 准备分子输入 molecules = [ MoleculeInput( smiles=row[smiles_col_actual], chem_id=str(row[id_col_actual]) ) for _, row in batch_df.iterrows() ] # 执行预测 try: # 抑制详细输出 with open(os.devnull, 'w') as devnull: with redirect_stdout(devnull): results = predictor.predict_batch( molecules, include_strain_predictions=False ) # 转换结果 for result in results: result_dict = result.to_dict() mol_idx = int(result.chem_id.replace('mol', '')) - 1 if mol_idx < len(batch_df): result_dict['smiles'] = batch_df.iloc[mol_idx][smiles_col_actual] all_results.append(result_dict) except Exception as e: if verbose: print(f"\n❌ 批次 {i//batch_size + 1} 失败: {e}") continue # 定期保存临时结果(每 10 批) if (i // batch_size + 1) % 10 == 0: temp_df = pd.DataFrame(all_results) temp_file = temp_dir / f"batch_{i//batch_size+1}.csv" temp_df.to_csv(temp_file, index=False) if verbose: print(f"\n💾 临时保存: {temp_file}") # 转换为 DataFrame df_results = pd.DataFrame(all_results) # 重新排列列顺序 if 'smiles' in df_results.columns: cols = ['smiles', 'chem_id'] + [col for col in df_results.columns if col not in ['smiles', 'chem_id']] df_results = df_results[cols] # 保存结果 df_results.to_csv(output_path, index=False) return df_results def predict_chunk_worker(args): """ 多进程工作函数:处理单个数据块 Args: args: (chunk_data, chunk_id, output_file, params_dict) Returns: (chunk_id, output_file, success) """ chunk_data, chunk_id, output_file, params = args try: # 保存 chunk 数据到临时文件 temp_input = params['temp_dir'] / f"chunk_{chunk_id}_input.csv" chunk_data.to_csv(temp_input, index=False) # 调用单进程预测 predict_single_process( input_path=str(temp_input), output_path=str(output_file), smiles_column=params['smiles_column'], id_column=params['id_column'], device=params['device'], batch_size=params['batch_size'], start_from=0, max_molecules=None, temp_dir=params['temp_dir'], verbose=(chunk_id == 0) # 只有第一个进程显示详细信息 ) # 清理临时输入文件 temp_input.unlink() return chunk_id, output_file, True except Exception as e: print(f"❌ Chunk {chunk_id} 失败: {e}") import traceback traceback.print_exc() return chunk_id, output_file, False @click.command() @click.option( '--input', '-i', required=True, type=click.Path(exists=True), help='输入 CSV 文件路径' ) @click.option( '--output', '-o', required=True, type=click.Path(), help='输出 CSV 文件路径' ) @click.option( '--smiles-column', '-s', default='smiles', help='SMILES 列名(默认: smiles)' ) @click.option( '--id-column', '-d', default='chem_id', help='化合物 ID 列名(默认: chem_id,如不存在则自动生成)' ) @click.option( '--device', '-g', default='cuda:0', help='GPU 设备(默认: cuda:0)。可选: cuda:0, cuda:1, cpu' ) @click.option( '--n-processes', '-n', default=1, type=int, help='并行进程数(默认: 1)。建议值 = GPU显存(GB) / 2.5。例如 32GB 显存可用 ~12 个进程' ) @click.option( '--batch-size', '-b', default=1000, type=int, help='每批处理的分子数量(默认: 1000)' ) @click.option( '--start-from', default=0, type=int, help='从第几行开始处理(默认: 0,用于断点续传)' ) @click.option( '--max-molecules', '-m', default=None, type=int, help='最多处理多少个分子(默认: None,处理全部)' ) @click.option( '--temp-dir', default=None, type=click.Path(), help='临时文件目录(默认: {输入文件名}_temp)' ) @click.option( '--keep-temp/--no-keep-temp', default=True, help='是否保留临时文件(默认: 保留)' ) @click.option( '--verbose/--quiet', default=True, help='是否显示详细信息(默认: 显示)' ) def main( input: str, output: str, smiles_column: str, id_column: str, device: str, n_processes: int, batch_size: int, start_from: int, max_molecules: int, temp_dir: str, keep_temp: bool, verbose: bool ): """ 批量预测分子抗菌活性 \b 示例用法: 1. 单进程预测(最稳定): pixi run python utils/batch_predictor.py -i data.csv -o output.csv 2. 多进程并行(4个进程,适合32GB显存): pixi run python utils/batch_predictor.py -i data.csv -o output.csv -n 4 3. 指定 GPU 和列名: pixi run python utils/batch_predictor.py -i data.csv -o output.csv \\ -g cuda:1 -s SMILES -d ID -n 8 4. 断点续传(从第 100000 行开始): pixi run python utils/batch_predictor.py -i data.csv -o output.csv \\ --start-from 100000 \b 显存计算公式: - 每个模型实例约占用 2.5GB GPU 显存 - 建议并行进程数 = GPU显存(GB) / 2.5 - 例如: * 12GB 显存 → 建议 4 个进程 * 24GB 显存 → 建议 9 个进程 * 32GB 显存 → 建议 12 个进程 * 48GB 显存 → 建议 19 个进程 \b 注意事项: - 单 GPU 上的多进程会串行使用 GPU,不是真正的并行 - 预期加速比约 2-3x(而非线性加速) - 如果 GPU 内存不足,减少 n-processes - 临时文件保存在 {输入文件名}_temp/ 目录 """ print("=" * 80) print("🚀 批量抗菌活性预测") print("=" * 80) # 设置路径 input_path = Path(input) output_path = Path(output) if temp_dir is None: temp_dir = input_path.parent / f"{input_path.stem}_temp" else: temp_dir = Path(temp_dir) temp_dir.mkdir(parents=True, exist_ok=True) # 显示配置 if verbose: print(f"\n配置:") print(f" 输入文件: {input_path}") print(f" 输出文件: {output_path}") print(f" SMILES 列: {smiles_column}") print(f" ID 列: {id_column}") print(f" GPU 设备: {device}") print(f" 并行进程数: {n_processes}") print(f" 批处理大小: {batch_size}") print(f" 临时目录: {temp_dir}") if start_from > 0: print(f" 开始行: {start_from}") if max_molecules: print(f" 最多处理: {max_molecules:,} 个分子") print("=" * 80) start_time = time.time() # 单进程模式 if n_processes == 1: if verbose: print("\n📦 使用单进程模式") df_results = predict_single_process( input_path=str(input_path), output_path=str(output_path), smiles_column=smiles_column, id_column=id_column, device=device, batch_size=batch_size, start_from=start_from, max_molecules=max_molecules, temp_dir=temp_dir, verbose=verbose ) # 多进程模式 else: if verbose: print(f"\n📦 使用多进程模式({n_processes} 个进程)") # 读取数据 df_input = pd.read_csv(input_path) # 应用限制 if start_from > 0: df_input = df_input.iloc[start_from:] if max_molecules: df_input = df_input.iloc[:max_molecules] # 分割数据 chunk_size = len(df_input) // n_processes chunks = [] for i in range(n_processes): start_idx = i * chunk_size if i == n_processes - 1: end_idx = len(df_input) else: end_idx = (i + 1) * chunk_size chunk = df_input.iloc[start_idx:end_idx].copy() output_file = temp_dir / f"part_{i}.csv" params = { 'smiles_column': smiles_column, 'id_column': id_column, 'device': device, 'batch_size': batch_size, 'temp_dir': temp_dir } chunks.append((chunk, i, output_file, params)) if verbose: print(f" 数据分成 {n_processes} 块") print(f" 每块约 {chunk_size:,} 个分子") print(f"\n🔄 开始并行处理...") # 使用 spawn 模式避免 CUDA fork 问题 mp.set_start_method('spawn', force=True) # 并行处理 with mp.Pool(processes=n_processes) as pool: results = list(tqdm( pool.imap(predict_chunk_worker, chunks), total=len(chunks), desc="处理进度", disable=not verbose )) # 合并结果 if verbose: print("\n📦 合并结果...") all_dfs = [] for chunk_id, output_file, success in sorted(results, key=lambda x: x[0]): if success and output_file.exists(): df = pd.read_csv(output_file) all_dfs.append(df) if verbose: print(f" ✓ part_{chunk_id}.csv: {len(df):,} 行") else: print(f" ❌ part_{chunk_id}.csv: 失败") df_results = pd.concat(all_dfs, ignore_index=True) df_results.to_csv(output_path, index=False) # 统计信息 elapsed_time = time.time() - start_time if verbose: print("\n" + "=" * 80) print("✅ 预测完成") print("=" * 80) print(f"\n📈 统计信息:") print(f" 处理分子数: {len(df_results):,}") print(f" 总耗时: {elapsed_time:.2f} 秒 ({elapsed_time/60:.2f} 分钟)") print(f" 平均速度: {len(df_results)/elapsed_time:.2f} 分子/秒") # 广谱抗菌统计 n_broad = df_results['broad_spectrum'].sum() print(f"\n🎯 预测结果:") print(f" 广谱抗菌: {n_broad:,} 个 ({n_broad/len(df_results)*100:.2f}%)") print(f" 非广谱: {len(df_results)-n_broad:,} 个") # 抑制菌株数分布 print(f"\n📊 抑制菌株数分布:") for threshold in [0, 5, 10, 15, 20, 30]: n = (df_results['ginhib_total'] >= threshold).sum() print(f" ≥{threshold:2d} 个菌株: {n:,} ({n/len(df_results)*100:.2f}%)") print("\n" + "=" * 80) if not keep_temp: print(f"\n🗑️ 清理临时文件...") import shutil shutil.rmtree(temp_dir) print(f"✓ 已删除: {temp_dir}") else: print(f"\n📁 临时文件保留在: {temp_dir}") if __name__ == '__main__': main()