first commit

This commit is contained in:
mm644706215
2025-08-27 21:16:45 +08:00
commit c5620ad4e3
5 changed files with 173 additions and 0 deletions

60
fix_admet_load.py Normal file
View File

@@ -0,0 +1,60 @@
import torch
import sys
import os
import argparse
from pathlib import Path
def patch_torch_load():
"""
修补torch.load函数添加weights_only=False参数
这是为了解决PyTorch 2.6版本中weights_only默认值变更导致的问题
"""
original_torch_load = torch.load
def patched_load(f, map_location=None, pickle_module=None, **pickle_load_args):
# 显式添加weights_only=False参数
if 'weights_only' not in pickle_load_args:
pickle_load_args['weights_only'] = False
return original_torch_load(f, map_location, pickle_module, **pickle_load_args)
torch.load = patched_load
print("已修补torch.load函数添加weights_only=False参数")
def main():
parser = argparse.ArgumentParser(description='运行admet_predict并修复torch.load问题')
parser.add_argument('--data_path', type=str, required=True, help='输入数据路径')
parser.add_argument('--save_path', type=str, required=True, help='保存结果路径')
parser.add_argument('--smiles_column', type=str, default='smiles', help='SMILES列名')
parser.add_argument('--models_dir', type=str, help='模型目录')
parser.add_argument('--include_physchem', action='store_true', help='是否包含物理化学特性')
parser.add_argument('--drugbank_path', type=str, help='DrugBank路径')
parser.add_argument('--atc_code', type=str, help='ATC代码')
args, unknown = parser.parse_known_args()
# 应用修补
patch_torch_load()
# 构建admet_predict命令
cmd = [sys.executable, '-m', 'admet_ai.admet_predict']
cmd.extend(['--data_path', args.data_path])
cmd.extend(['--save_path', args.save_path])
cmd.extend(['--smiles_column', args.smiles_column])
if args.models_dir:
cmd.extend(['--models_dir', args.models_dir])
if args.include_physchem:
cmd.append('--include_physchem')
if args.drugbank_path:
cmd.extend(['--drugbank_path', args.drugbank_path])
if args.atc_code:
cmd.extend(['--atc_code', args.atc_code])
# 添加未知参数
cmd.extend(unknown)
# 执行命令
os.execv(sys.executable, cmd)
if __name__ == '__main__':
main()