60 lines
2.2 KiB
Python
60 lines
2.2 KiB
Python
import torch
|
||
import sys
|
||
import os
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
def patch_torch_load():
    """Make ``torch.load`` default to ``weights_only=False`` for this process.

    PyTorch 2.6 changed the default value of ``weights_only``, which breaks
    loading checkpoints that were saved with arbitrary pickled objects. This
    function replaces ``torch.load`` with a thin wrapper that fills in
    ``weights_only=False`` whenever the caller did not choose a value, and
    forwards every other argument untouched.

    Calling this function again after the patch is installed is a no-op, so
    wrappers are never stacked on top of each other.

    SECURITY: with ``weights_only=False`` the unpickler can execute arbitrary
    code embedded in a checkpoint file. Only load files from trusted sources.
    """
    import functools

    # Already patched by an earlier call: do not wrap the wrapper again.
    if getattr(torch.load, "_weights_only_patched", False):
        return

    original_torch_load = torch.load

    @functools.wraps(original_torch_load)
    def patched_load(f, *args, **kwargs):
        # An explicit ``weights_only=None`` means "use the library default",
        # which is exactly the default this patch exists to override, so it
        # is treated the same as an omitted argument.
        if kwargs.get("weights_only") is None:
            kwargs["weights_only"] = False
        # Forward positionals as given instead of forcing map_location /
        # pickle_module to None when the caller never supplied them.
        return original_torch_load(f, *args, **kwargs)

    patched_load._weights_only_patched = True
    torch.load = patched_load
    print("已修补torch.load函数,添加weights_only=False参数")
|
||
|
||
def main():
    """Run ``admet_ai.admet_predict`` with the ``torch.load`` patch applied.

    Parses the known ``admet_predict`` options, installs the ``torch.load``
    patch, rebuilds the argument list (known options first, then any
    unrecognised arguments passed through verbatim) and executes the
    ``admet_ai.admet_predict`` module as ``__main__``.

    The module is run inside this interpreter via ``runpy`` rather than with
    ``os.execv``: exec replaces the process with a fresh Python in which the
    monkey patch no longer exists, so the patch would never take effect.
    """
    import runpy

    parser = argparse.ArgumentParser(description='运行admet_predict并修复torch.load问题')
    parser.add_argument('--data_path', type=str, required=True, help='输入数据路径')
    parser.add_argument('--save_path', type=str, required=True, help='保存结果路径')
    parser.add_argument('--smiles_column', type=str, default='smiles', help='SMILES列名')
    parser.add_argument('--models_dir', type=str, help='模型目录')
    parser.add_argument('--include_physchem', action='store_true', help='是否包含物理化学特性')
    parser.add_argument('--drugbank_path', type=str, help='DrugBank路径')
    parser.add_argument('--atc_code', type=str, help='ATC代码')

    args, unknown = parser.parse_known_args()

    # Install the patch before the target module (and whatever it imports)
    # gets a chance to call torch.load.
    patch_torch_load()

    # Rebuild the admet_predict command line from the parsed options.
    module_name = 'admet_ai.admet_predict'
    forwarded = [
        '--data_path', args.data_path,
        '--save_path', args.save_path,
        '--smiles_column', args.smiles_column,
    ]

    if args.models_dir:
        forwarded.extend(['--models_dir', args.models_dir])
    if args.include_physchem:
        forwarded.append('--include_physchem')
    if args.drugbank_path:
        forwarded.extend(['--drugbank_path', args.drugbank_path])
    if args.atc_code:
        forwarded.extend(['--atc_code', args.atc_code])

    # Pass through any arguments this wrapper does not know about.
    forwarded.extend(unknown)

    # In-process equivalent of `python -m admet_ai.admet_predict ...`;
    # alter_sys=True makes runpy set sys.argv[0] and sys.modules['__main__']
    # for the duration of the run.
    sys.argv = [module_name, *forwarded]
    runpy.run_module(module_name, run_name='__main__', alter_sys=True)
|
||
|
||
# Script entry point: run the wrapper only when executed directly,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()