first commit
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
admet_ai
|
||||
*.sdf
|
||||
11
README.md
Normal file
11
README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
## 使用admet-ai预测
|
||||
|
||||
```bash
|
||||
git clone https://github.com/swansonk14/admet_ai
|
||||
cd admet_ai
|
||||
micromamba create -n admet_ai python=3.10
|
||||
micromamba activate admet_ai
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
cd ..  # return to the directory that contains fix_admet_load.py and data/
python fix_admet_load.py --data_path ./data/molecules.csv --save_path ./data/preds.csv --smiles_column smiles
|
||||
```
|
||||
56
data/molecules.csv
Normal file
56
data/molecules.csv
Normal file
File diff suppressed because one or more lines are too long
60
fix_admet_load.py
Normal file
60
fix_admet_load.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def patch_torch_load():
    """
    Monkey-patch ``torch.load`` so that ``weights_only`` defaults to ``False``.

    This works around the change of the ``weights_only`` default in
    PyTorch 2.6. A caller that passes ``weights_only`` explicitly keeps
    its own value; only the missing-argument case is filled in.
    """
    unpatched_load = torch.load

    def patched_load(f, map_location=None, pickle_module=None, **pickle_load_args):
        # Fill in the default only; an explicit weights_only from the caller wins.
        pickle_load_args.setdefault('weights_only', False)
        return unpatched_load(f, map_location, pickle_module, **pickle_load_args)

    torch.load = patched_load
    print("已修补torch.load函数,添加weights_only=False参数")
|
||||
|
||||
def main():
    """
    Run ``admet_ai.admet_predict`` with a patched ``torch.load``.

    Command-line arguments are parsed here, re-assembled into an argument
    vector for ``admet_ai.admet_predict``, and any unrecognised arguments are
    forwarded unchanged.

    The prediction module is executed inside this interpreter (the equivalent
    of ``python -m admet_ai.admet_predict``). Replacing the process with
    ``os.execv`` would start a fresh interpreter in which the monkey-patch
    applied by ``patch_torch_load`` no longer exists, so the patch would have
    no effect.
    """
    # Local import: only needed by this entry point.
    import runpy

    parser = argparse.ArgumentParser(description='运行admet_predict并修复torch.load问题')
    parser.add_argument('--data_path', type=str, required=True, help='输入数据路径')
    parser.add_argument('--save_path', type=str, required=True, help='保存结果路径')
    parser.add_argument('--smiles_column', type=str, default='smiles', help='SMILES列名')
    parser.add_argument('--models_dir', type=str, help='模型目录')
    parser.add_argument('--include_physchem', action='store_true', help='是否包含物理化学特性')
    parser.add_argument('--drugbank_path', type=str, help='DrugBank路径')
    parser.add_argument('--atc_code', type=str, help='ATC代码')

    args, unknown = parser.parse_known_args()

    # Apply the patch in the same process that will load the models.
    patch_torch_load()

    # Build the argument vector for admet_predict.
    module_name = 'admet_ai.admet_predict'
    forwarded = [
        '--data_path', args.data_path,
        '--save_path', args.save_path,
        '--smiles_column', args.smiles_column,
    ]

    if args.models_dir:
        forwarded.extend(['--models_dir', args.models_dir])
    if args.include_physchem:
        forwarded.append('--include_physchem')
    if args.drugbank_path:
        forwarded.extend(['--drugbank_path', args.drugbank_path])
    if args.atc_code:
        forwarded.extend(['--atc_code', args.atc_code])

    # Pass through anything this wrapper does not recognise.
    forwarded.extend(unknown)

    # Run the module as __main__ in-process so the patched torch.load is used.
    sys.argv = [module_name, *forwarded]
    runpy.run_module(module_name, run_name='__main__', alter_sys=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
44
scripts/sdf_to_smiles.py
Normal file
44
scripts/sdf_to_smiles.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from rdkit import Chem
|
||||
|
||||
# Directory that contains this script.
current_dir = os.path.split(os.path.abspath(__file__))[0]

# Input SDF file and output SMILES text file, both located in current_dir.
sdf_file, smiles_file = (
    os.path.join(current_dir, file_name)
    for file_name in ('fgbar_vina_SP_1_pv.sdf', 'molecules.txt')
)
|
||||
|
||||
def sdf_to_smiles(sdf_path, output_path):
    """Read the molecules in an SDF file and save them as SMILES, one per line.

    Args:
        sdf_path: Path of the SDF file to read.
        output_path: Path of the text file to write; opened in ``'w'`` mode.

    Returns:
        The number of molecules written to ``output_path``.
    """
    supplier = Chem.SDMolSupplier(sdf_path)
    written = 0

    with open(output_path, 'w') as out:
        for mol in supplier:
            # Skip entries the supplier yields as None (no valid molecule).
            if mol is None:
                continue
            out.write(f"{Chem.MolToSmiles(mol)}\n")
            written += 1

    return written
|
||||
|
||||
def main():
    """Convert the configured SDF file to SMILES and report the result on stdout."""
    print(f"正在读取SDF文件: {sdf_file}")
    converted = sdf_to_smiles(sdf_file, smiles_file)
    # One call, two lines: the emitted text is the same as two separate prints.
    print(
        f"成功处理了 {converted} 个分子",
        f"SMILES已保存到: {smiles_file}",
        sep="\n",
    )
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user