import torch
import sys
import os
import argparse
import functools
import runpy
from pathlib import Path


def patch_torch_load():
    """Make ``torch.load`` default to ``weights_only=False``.

    PyTorch 2.6 changed the default of ``weights_only`` to ``True``, which
    breaks loading checkpoints that pickle arbitrary Python objects. This
    replaces the module attribute ``torch.load`` with a wrapper that injects
    ``weights_only=False`` unless the caller passes ``weights_only`` itself.

    The patch only exists inside the current interpreter process, so any code
    that needs it must run in this same process (not via exec/subprocess).
    """
    original_torch_load = torch.load

    @functools.wraps(original_torch_load)
    def patched_load(*args, **kwargs):
        # Forward positional arguments untouched rather than re-spelling
        # torch.load's own defaults (e.g. pickle_module); only fill in
        # weights_only when the caller did not choose a value explicitly.
        kwargs.setdefault('weights_only', False)
        return original_torch_load(*args, **kwargs)

    torch.load = patched_load
    print("已修补torch.load函数,添加weights_only=False参数")


def _build_forwarded_args(args, unknown):
    """Rebuild the admet_predict command-line argument list.

    Args:
        args: Namespace of the options this wrapper understands.
        unknown: Extra command-line tokens not recognised by this wrapper;
            they are appended verbatim.

    Returns:
        List of argument strings (without program name) for admet_predict.
    """
    forwarded = [
        '--data_path', args.data_path,
        '--save_path', args.save_path,
        '--smiles_column', args.smiles_column,
    ]
    # Optional options are only forwarded when set, so admet_predict's own
    # defaults apply otherwise.
    if args.models_dir:
        forwarded.extend(['--models_dir', args.models_dir])
    if args.include_physchem:
        forwarded.append('--include_physchem')
    if args.drugbank_path:
        forwarded.extend(['--drugbank_path', args.drugbank_path])
    if args.atc_code:
        forwarded.extend(['--atc_code', args.atc_code])
    forwarded.extend(unknown)
    return forwarded


def main(argv=None):
    """Run ``admet_ai.admet_predict`` with ``torch.load`` patched.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='运行admet_predict并修复torch.load问题')
    parser.add_argument('--data_path', type=str, required=True, help='输入数据路径')
    parser.add_argument('--save_path', type=str, required=True, help='保存结果路径')
    parser.add_argument('--smiles_column', type=str, default='smiles', help='SMILES列名')
    parser.add_argument('--models_dir', type=str, help='模型目录')
    parser.add_argument('--include_physchem', action='store_true', help='是否包含物理化学特性')
    parser.add_argument('--drugbank_path', type=str, help='DrugBank路径')
    parser.add_argument('--atc_code', type=str, help='ATC代码')

    args, unknown = parser.parse_known_args(argv)

    # Patch before admet_ai (and its dependencies) are imported, so even a
    # later `from torch import load` picks up the patched function.
    patch_torch_load()

    forwarded = _build_forwarded_args(args, unknown)

    # Run the module in THIS process, equivalent to `python -m`. The previous
    # os.execv replaced the interpreter and silently discarded the patch.
    # alter_sys=True makes sys.argv[0] the module's file, as `-m` would.
    # NOTE(review): like `python -m`, this only does work if the module has
    # an `if __name__ == "__main__"` entry point - confirm for admet_ai.
    saved_argv = sys.argv
    sys.argv = ['admet_ai.admet_predict', *forwarded]
    try:
        runpy.run_module('admet_ai.admet_predict', run_name='__main__', alter_sys=True)
    finally:
        sys.argv = saved_argv


if __name__ == '__main__':
    main()