Files
admet-ai/fix_admet_load.py
mm644706215 c5620ad4e3 first commit
2025-08-27 21:16:45 +08:00

60 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import functools
import os
import runpy
import sys
from pathlib import Path

import torch
def patch_torch_load():
    """Make ``torch.load`` default to ``weights_only=False`` in this process.

    PyTorch 2.6 changed the default of ``torch.load``'s ``weights_only``
    argument, which breaks loading checkpoints that pickle arbitrary Python
    objects. This replaces ``torch.load`` with a thin wrapper that supplies
    ``weights_only=False`` whenever the caller did not pass ``weights_only``
    explicitly. An explicit value (``True``, ``False`` or ``None``) is always
    forwarded untouched, as are all other positional and keyword arguments.

    The patch only affects the current interpreter. A process started via
    ``os.execv`` or ``subprocess`` gets an unpatched ``torch``. Calling this
    function again after a successful patch does nothing.

    Security note: ``weights_only=False`` unpickles arbitrary objects, which can
    execute code. Only load checkpoints from trusted sources.
    """
    if getattr(torch.load, '_forces_weights_only_false', False):
        return  # Already patched; do not stack another wrapper on top.

    original_torch_load = torch.load

    @functools.wraps(original_torch_load)
    def patched_load(*args, **kwargs):
        # Pass-through keeps compatibility with every torch.load signature;
        # setdefault respects a weights_only value chosen by the caller.
        kwargs.setdefault('weights_only', False)
        return original_torch_load(*args, **kwargs)

    # Marker used by the idempotence check above.
    patched_load._forces_weights_only_false = True
    torch.load = patched_load
    print("已修补torch.load函数添加weights_only=False参数")
def main():
    """Parse CLI options, patch ``torch.load`` and run ``admet_ai.admet_predict``.

    The options declared below are forwarded to ``admet_ai.admet_predict``.
    Optional ones are forwarded only when given. Any option this parser does
    not recognize is appended unchanged after them.

    The target module runs in this same interpreter via
    :func:`runpy.run_module`, the in-process equivalent of
    ``python -m admet_ai.admet_predict``. Do not switch back to ``os.execv``
    or ``subprocess``. A new interpreter starts with an unpatched
    ``torch.load``, so the fix would silently never apply.
    """
    parser = argparse.ArgumentParser(description='运行admet_predict并修复torch.load问题')
    parser.add_argument('--data_path', type=str, required=True, help='输入数据路径')
    parser.add_argument('--save_path', type=str, required=True, help='保存结果路径')
    parser.add_argument('--smiles_column', type=str, default='smiles', help='SMILES列名')
    parser.add_argument('--models_dir', type=str, help='模型目录')
    parser.add_argument('--include_physchem', action='store_true', help='是否包含物理化学特性')
    parser.add_argument('--drugbank_path', type=str, help='DrugBank路径')
    parser.add_argument('--atc_code', type=str, help='ATC代码')
    # Unrecognized options are kept so they can be passed through verbatim.
    args, unknown = parser.parse_known_args()

    # Patch before admet_ai (or anything it imports) can call torch.load.
    patch_torch_load()

    module_name = 'admet_ai.admet_predict'
    forwarded = [
        '--data_path', args.data_path,
        '--save_path', args.save_path,
        '--smiles_column', args.smiles_column,
    ]
    if args.models_dir:
        forwarded.extend(['--models_dir', args.models_dir])
    if args.include_physchem:
        forwarded.append('--include_physchem')
    if args.drugbank_path:
        forwarded.extend(['--drugbank_path', args.drugbank_path])
    if args.atc_code:
        forwarded.extend(['--atc_code', args.atc_code])
    forwarded.extend(unknown)

    # Emulate `python -m`: the module reads its options from sys.argv[1:],
    # and alter_sys=True sets sys.argv[0] to the module's file while it runs.
    sys.argv = [module_name, *forwarded]
    runpy.run_module(module_name, run_name='__main__', alter_sys=True)
if __name__ == "__main__":
    # Entry point: run only when executed as a script, never on import.
    main()