update
This commit is contained in:
59
.dockerignore
Normal file
59
.dockerignore
Normal file
@@ -0,0 +1,59 @@
|
||||
# Git 相关
|
||||
.git
|
||||
.gitignore
|
||||
.gitattributes
|
||||
|
||||
# Docker 相关
|
||||
docker/
|
||||
.dockerignore
|
||||
|
||||
# Python 相关
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# 数据文件
|
||||
data/
|
||||
results/
|
||||
output/
|
||||
*.pdb
|
||||
*.sdf
|
||||
*.mol2
|
||||
*.xyz
|
||||
|
||||
# 日志文件
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# 临时文件
|
||||
*.tmp
|
||||
*.temp
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# IDE 相关
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 系统文件
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# 文档
|
||||
*.md
|
||||
docs/
|
||||
|
||||
# 测试文件
|
||||
test/
|
||||
tests/
|
||||
*_test.py
|
||||
test_*.py
|
||||
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# SCM syntax highlighting & preventing 3-way merges
|
||||
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
|
||||
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# pixi environments
|
||||
.pixi/*
|
||||
!.pixi/config.toml
|
||||
bin/
|
||||
141
README.md
Normal file
141
README.md
Normal file
@@ -0,0 +1,141 @@
|
||||
## 环境管理 (使用 pixi)
|
||||
|
||||
### 安装 pixi
|
||||
|
||||
```bash
|
||||
# 安装 pixi
|
||||
curl -fsSL https://pixi.sh/install.sh | bash
|
||||
|
||||
# 重新加载 shell 配置
|
||||
source ~/.bashrc # 或 source ~/.zshrc
|
||||
```
|
||||
|
||||
### 初始化项目环境
|
||||
|
||||
```bash
|
||||
# 在项目目录中初始化 pixi 环境
|
||||
pixi init
|
||||
|
||||
# 添加所需的包
|
||||
pixi add rdkit openbabel meeko
|
||||
|
||||
# 激活环境
|
||||
pixi shell
|
||||
```
|
||||
|
||||
### 使用环境
|
||||
|
||||
```bash
|
||||
# 激活 pixi 环境
|
||||
pixi shell
|
||||
|
||||
# 在环境中运行脚本
|
||||
pixi run python scripts/your_script.py
|
||||
|
||||
# 或者直接使用 pixi 执行命令
|
||||
pixi run vina --help
|
||||
```
|
||||
|
||||
## AutoDock Vina 安装
|
||||
|
||||
### 下载二进制文件
|
||||
|
||||
[Download](https://github.com/ccsb-scripps/AutoDock-Vina/releases/tag/v1.2.7)
|
||||
|
||||
for macos:
|
||||
|
||||
```bash
|
||||
|
||||
wget -O ./bin/vina_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_1.2.7_mac_aarch64
|
||||
wget -O ./bin/vina_split_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_split_1.2.7_mac_aarch64
|
||||
|
||||
# 或者使用 curl:
|
||||
curl -L -o ./bin/vina_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_1.2.7_mac_aarch64
|
||||
curl -L -o ./bin/vina_split_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_split_1.2.7_mac_aarch64
|
||||
|
||||
chmod +x ./bin/vina_*
|
||||
```
|
||||
|
||||
## 项目使用
|
||||
|
||||
### 快速开始
|
||||
|
||||
```bash
|
||||
# 1. 克隆项目
|
||||
git clone <your-repo-url>
|
||||
cd vinatools
|
||||
|
||||
# 2. 初始化 pixi 环境
|
||||
pixi init
|
||||
pixi add rdkit openbabel meeko
|
||||
|
||||
# 3. 下载 AutoDock Vina 二进制文件
|
||||
mkdir -p bin
|
||||
curl -L -o ./bin/vina_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_1.2.7_mac_aarch64
|
||||
curl -L -o ./bin/vina_split_1.2.7_mac_aarch64 https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v1.2.7/vina_split_1.2.7_mac_aarch64
|
||||
chmod +x ./bin/vina_*
|
||||
|
||||
# 4. 激活环境并运行
|
||||
pixi shell
|
||||
```
|
||||
|
||||
### 环境管理命令
|
||||
|
||||
```bash
|
||||
# 查看已安装的包
|
||||
pixi list
|
||||
|
||||
# 添加新包
|
||||
pixi add package_name
|
||||
|
||||
# 移除包
|
||||
pixi remove package_name
|
||||
|
||||
# 更新所有包
|
||||
pixi update
|
||||
|
||||
# 导出环境配置
|
||||
pixi export --format conda-lock
|
||||
```
|
||||
|
||||
### 项目结构
|
||||
|
||||
```
|
||||
vinatools/
|
||||
├── bin/ # AutoDock Vina 二进制文件
|
||||
│ ├── vina_1.2.7_mac_aarch64
|
||||
│ └── vina_split_1.2.7_mac_aarch64
|
||||
├── scripts/ # Python 脚本
|
||||
│ ├── batch_prepare_ligands.sh
|
||||
│ ├── batch_docking.sh
|
||||
│ ├── calculate_qed_values.py
|
||||
│ └── analyze_results.py
|
||||
├── pixi.toml # pixi 环境配置文件
|
||||
├── docker/ # Docker 配置文件
|
||||
│ ├── Dockerfile
|
||||
│ ├── docker-compose.yml
|
||||
│ └── README.md
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Docker 环境
|
||||
|
||||
### 快速使用 Docker
|
||||
|
||||
```bash
|
||||
# 构建并运行 Docker 容器
|
||||
docker-compose -f docker/docker-compose.yml up -d
|
||||
|
||||
# 进入容器
|
||||
docker-compose -f docker/docker-compose.yml exec vinatools bash
|
||||
|
||||
# 运行脚本
|
||||
docker-compose -f docker/docker-compose.yml exec vinatools pixi run python scripts/calculate_qed_values.py
|
||||
```
|
||||
|
||||
### Docker 服务说明
|
||||
|
||||
- **vinatools**: 主服务,包含 pixi 环境和所有依赖包
|
||||
- **jupyter**: Jupyter Notebook 服务,访问 http://localhost:8888
|
||||
|
||||
详细使用说明请参考 [docker/README.md](docker/README.md)
|
||||
6
config/example.txt
Normal file
6
config/example.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
center_x = -12.7
|
||||
center_y = -9.1
|
||||
center_z = -0.3
|
||||
size_x = 49.1
|
||||
size_y = 37.6
|
||||
size_z = 35.2
|
||||
69
docker/Dockerfile
Normal file
69
docker/Dockerfile
Normal file
@@ -0,0 +1,69 @@
|
||||
# syntax=docker/dockerfile:1
# Base image: Ubuntu 22.04, pulled through the Tencent Cloud mirror registry.
FROM ccr.ccs.tencentyun.com/library/ubuntu:22.04

# Build-time only: keep debconf non-interactive. Declared as ARG (not ENV) so
# it does not leak into the runtime environment of containers.
ARG DEBIAN_FRONTEND=noninteractive

# Make pixi available on PATH for subsequent build steps.
ENV PATH="/root/.pixi/bin:$PATH"

# Use the Aliyun APT mirror to speed up package installation.
RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.aliyun.com@g' /etc/apt/sources.list && \
    sed -i 's@//.*security.ubuntu.com@//mirrors.aliyun.com@g' /etc/apt/sources.list

# System dependencies. --no-install-recommends keeps the image smaller, and
# the apt lists are removed in the same layer so they never reach the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    curl \
    git \
    gnupg \
    lsb-release \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Point pip at the Tsinghua mirror.
RUN mkdir -p /root/.pip && \
    echo "[global]" > /root/.pip/pip.conf && \
    echo "index-url = https://pypi.tuna.tsinghua.edu.cn/simple" >> /root/.pip/pip.conf && \
    echo "trusted-host = pypi.tuna.tsinghua.edu.cn" >> /root/.pip/pip.conf

# Fail the build if any stage of a pipeline fails; the default /bin/sh would
# otherwise mask a failed download in "curl | bash" below (hadolint DL4006).
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# Install pixi.
RUN curl -fsSL https://pixi.sh/install.sh | bash

# Work from the application directory (created automatically by WORKDIR).
WORKDIR /app

# Copy the project (filtered by .dockerignore).
COPY . .

# Directory for the AutoDock Vina binaries.
RUN mkdir -p bin

# Download the AutoDock Vina binaries.
# Default platform is linux_x86_64: this image is Linux-based, so the former
# mac_aarch64 default fetched binaries that could not run inside the container.
ARG VINA_VERSION=1.2.7
ARG VINA_PLATFORM=linux_x86_64
ARG DOWNLOAD_VINA=true

RUN if [ "$DOWNLOAD_VINA" = "true" ]; then \
        curl -L -o ./bin/vina_${VINA_VERSION}_${VINA_PLATFORM} https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v${VINA_VERSION}/vina_${VINA_VERSION}_${VINA_PLATFORM} && \
        curl -L -o ./bin/vina_split_${VINA_VERSION}_${VINA_PLATFORM} https://github.com/ccsb-scripps/AutoDock-Vina/releases/download/v${VINA_VERSION}/vina_split_${VINA_VERSION}_${VINA_PLATFORM} && \
        chmod +x ./bin/vina_*; \
    fi

# Register the platform and install the project packages with pixi.
# NOTE(review): linux-aarch64 is hardcoded; on an x86_64 build host this
# should be linux-64 — consider deriving it from the TARGETARCH build arg.
RUN /root/.pixi/bin/pixi workspace platform add linux-aarch64 && \
    /root/.pixi/bin/pixi add rdkit openbabel meeko

# Runtime PATH: pixi plus the project-local Vina binaries.
ENV PATH="/root/.pixi/bin:/app/bin:$PATH"

# Entrypoint wrapper: source the shell profile, then exec the requested
# command so it becomes PID 1 and receives stop signals. printf is used
# instead of echo because echo's handling of "\n" escapes is shell-dependent.
RUN printf '#!/bin/bash\nsource /root/.bashrc\nexec "$@"\n' > /entrypoint.sh && \
    chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]

# Default command: drop into an activated pixi shell.
CMD ["/root/.pixi/bin/pixi", "shell"]
|
||||
201
docker/README.md
Normal file
201
docker/README.md
Normal file
@@ -0,0 +1,201 @@
|
||||
# Docker 环境使用说明
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 环境变量配置
|
||||
|
||||
```bash
|
||||
# 复制环境变量模板
|
||||
cp docker/docker.env.example docker/.env
|
||||
|
||||
# 编辑环境变量
|
||||
vim docker/.env
|
||||
```
|
||||
|
||||
### 2. 构建镜像
|
||||
|
||||
```bash
|
||||
# 使用默认配置构建
|
||||
docker-compose -f docker/docker-compose.yml build
|
||||
|
||||
# 使用环境变量构建
|
||||
docker-compose -f docker/docker-compose.yml --env-file docker/.env build
|
||||
|
||||
# 或者直接使用 docker build
|
||||
docker build -f docker/Dockerfile -t vinatools:latest .
|
||||
```
|
||||
|
||||
### 3. 环境变量说明
|
||||
|
||||
| 变量名 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| `VINA_VERSION` | `1.2.7` | AutoDock Vina 版本 |
|
||||
| `VINA_PLATFORM` | `mac_aarch64` | 平台架构 |
|
||||
| `DOWNLOAD_VINA` | `true` | 是否下载 AutoDock Vina |
|
||||
|
||||
**支持的平台:**
|
||||
- `mac_aarch64` - Apple Silicon Mac
|
||||
- `mac_x86_64` - Intel Mac
|
||||
- `linux_x86_64` - Linux x86_64
- `linux_aarch64` - Linux ARM64
|
||||
- `windows_x86_64` - Windows x86_64
|
||||
|
||||
### 4. 运行容器
|
||||
|
||||
```bash
|
||||
# 启动主服务
|
||||
docker-compose -f docker/docker-compose.yml up -d vinatools
|
||||
|
||||
# 进入容器
|
||||
docker-compose -f docker/docker-compose.yml exec vinatools bash
|
||||
|
||||
# 或者直接运行
|
||||
docker run -it --rm -v $(pwd):/app vinatools:latest bash
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 不同平台构建
|
||||
|
||||
```bash
|
||||
# Linux x86_64 平台
|
||||
VINA_PLATFORM=linux_x86_64 docker-compose -f docker/docker-compose.yml build
|
||||
# Linux aarch64 平台
VINA_PLATFORM=linux_aarch64 docker-compose -f docker/docker-compose.yml build
|
||||
|
||||
# Intel Mac 平台
|
||||
VINA_PLATFORM=mac_x86_64 docker-compose -f docker/docker-compose.yml build
|
||||
|
||||
# 不下载 AutoDock Vina
|
||||
DOWNLOAD_VINA=false docker-compose -f docker/docker-compose.yml build
|
||||
```
|
||||
|
||||
### 使用环境文件
|
||||
|
||||
```bash
|
||||
# 创建自定义环境文件
|
||||
cat > docker/my.env << EOF
|
||||
VINA_VERSION=1.2.6
|
||||
VINA_PLATFORM=linux_x86_64
|
||||
DOWNLOAD_VINA=true
|
||||
EOF
|
||||
|
||||
# 使用自定义环境文件构建
|
||||
docker-compose -f docker/docker-compose.yml --env-file docker/my.env build
|
||||
```
|
||||
|
||||
### 使用 Jupyter Notebook
|
||||
|
||||
```bash
|
||||
# 启动 Jupyter 服务
|
||||
docker-compose -f docker/docker-compose.yml up -d jupyter
|
||||
|
||||
# 访问 http://localhost:8888
|
||||
```
|
||||
|
||||
## 环境说明
|
||||
|
||||
### 镜像源配置
|
||||
为了加速构建过程,Dockerfile 中配置了以下镜像源:
|
||||
|
||||
- **APT 源**: 阿里云镜像 (mirrors.aliyun.com)
|
||||
- **pip 源**: 清华大学镜像 (pypi.tuna.tsinghua.edu.cn)
|
||||
- **conda 源**: 清华大学镜像 (mirrors.tuna.tsinghua.edu.cn)
|
||||
|
||||
### 包含的包
|
||||
- **rdkit**: 化学信息学工具包
|
||||
- **openbabel**: 分子格式转换工具
|
||||
- **meeko**: 分子准备工具
|
||||
- **AutoDock Vina**: 分子对接工具
|
||||
|
||||
### 目录结构
|
||||
```
|
||||
/app/
|
||||
├── bin/ # AutoDock Vina 二进制文件
|
||||
├── scripts/ # Python 脚本
|
||||
├── data/ # 输入数据(挂载)
|
||||
└── results/ # 输出结果(挂载)
|
||||
```
|
||||
|
||||
## 常用命令
|
||||
|
||||
### 运行脚本
|
||||
```bash
|
||||
# 在容器中运行 Python 脚本
|
||||
pixi run python scripts/calculate_qed_values.py
|
||||
|
||||
# 运行批处理脚本
|
||||
pixi run bash scripts/batch_docking.sh
|
||||
```
|
||||
|
||||
### 数据管理
|
||||
```bash
|
||||
# 挂载数据目录
|
||||
docker run -it --rm \
|
||||
-v $(pwd)/data:/app/data \
|
||||
-v $(pwd)/results:/app/results \
|
||||
vinatools:latest bash
|
||||
```
|
||||
|
||||
### 清理
|
||||
```bash
|
||||
# 停止所有服务
|
||||
docker-compose -f docker/docker-compose.yml down
|
||||
|
||||
# 删除镜像
|
||||
docker rmi vinatools:latest
|
||||
|
||||
# 清理未使用的资源
|
||||
docker system prune -a
|
||||
```
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 网络连接问题
|
||||
如果遇到网络连接问题,可以尝试以下解决方案:
|
||||
|
||||
```bash
|
||||
# 1. 使用代理构建
|
||||
docker build --build-arg HTTP_PROXY=http://proxy:port \
|
||||
--build-arg HTTPS_PROXY=http://proxy:port \
|
||||
-f docker/Dockerfile -t vinatools:latest .
|
||||
|
||||
# 2. 使用不同的镜像源
|
||||
# 编辑 Dockerfile,替换镜像源:
|
||||
# - APT: mirrors.ustc.edu.cn (中科大)
|
||||
# - pip: pypi.douban.com (豆瓣)
|
||||
# - conda: mirrors.ustc.edu.cn/anaconda/cloud/
|
||||
|
||||
# 3. 离线构建(如果网络完全不可用)
|
||||
# 预先下载所有依赖包,然后使用本地构建
|
||||
```
|
||||
|
||||
### 权限问题
|
||||
```bash
|
||||
# 修复文件权限
|
||||
sudo chown -R $USER:$USER data/ results/
|
||||
```
|
||||
|
||||
### 内存不足
|
||||
```bash
|
||||
# 增加 Docker 内存限制
|
||||
docker run -it --rm --memory=8g vinatools:latest bash
|
||||
```
|
||||
|
||||
### 网络问题
|
||||
```bash
|
||||
# 使用主机网络
|
||||
docker run -it --rm --network=host vinatools:latest bash
|
||||
```
|
||||
|
||||
### 镜像源切换
|
||||
如果需要使用其他镜像源,可以修改 Dockerfile 中的配置:
|
||||
|
||||
```dockerfile
|
||||
# APT 源切换
|
||||
RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.ustc.edu.cn@g' /etc/apt/sources.list
|
||||
|
||||
# pip 源切换
|
||||
RUN echo "index-url = https://pypi.douban.com/simple" > /root/.pip/pip.conf
|
||||
|
||||
# conda 源切换
|
||||
RUN /root/.pixi/bin/pixi config set channels https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/
|
||||
```
|
||||
70
docker/docker-compose.yml
Normal file
70
docker/docker-compose.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
version: '3.8'

services:
  vinatools:
    build:
      context: ..
      dockerfile: docker/Dockerfile
      args:
        VINA_VERSION: ${VINA_VERSION:-1.2.7}
        # Default to a valid Linux platform suffix; the bare "linux" value is
        # not a real AutoDock Vina release-asset suffix, so the download in
        # the Dockerfile would 404. Matches the jupyter service's default.
        VINA_PLATFORM: ${VINA_PLATFORM:-linux_x86_64}
        DOWNLOAD_VINA: ${DOWNLOAD_VINA:-true}
    image: vinatools:latest
    container_name: vinatools-container
    volumes:
      # Mount the project directory into the container.
      - ..:/app
      # Data directories live at the project root, one level above this
      # compose file; "./data" would resolve to docker/data.
      - ../data:/app/data
      - ../results:/app/results
    working_dir: /app
    environment:
      - PIXI_ROOT=/root/.pixi
      - PATH=/root/.pixi/bin:/app/bin:$PATH
    # Keep the container running for interactive use.
    tty: true
    stdin_open: true
    # Network mode
    network_mode: host
    # Restart policy
    restart: unless-stopped
    # Resource limits (honored only by runtimes that support the deploy key)
    deploy:
      resources:
        limits:
          memory: 4G
          cpus: '2.0'
        reservations:
          memory: 2G
          cpus: '1.0'

  # Optional: Jupyter Notebook service
  jupyter:
    build:
      context: ..
      dockerfile: docker/Dockerfile
      args:
        VINA_VERSION: ${VINA_VERSION:-1.2.7}
        VINA_PLATFORM: ${VINA_PLATFORM:-linux_x86_64}
        DOWNLOAD_VINA: ${DOWNLOAD_VINA:-true}
    image: vinatools:latest
    container_name: vinatools-jupyter
    ports:
      - "8888:8888"
    volumes:
      - ..:/app
      - ../data:/app/data
      - ../results:/app/results
    working_dir: /app
    environment:
      - PIXI_ROOT=/root/.pixi
      - PATH=/root/.pixi/bin:/app/bin:$PATH
    command: >
      bash -c "
      /root/.pixi/bin/pixi workspace platform add linux-aarch64 &&
      /root/.pixi/bin/pixi add jupyter notebook &&
      /root/.pixi/bin/pixi run jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser --allow-root --NotebookApp.token='' --NotebookApp.password=''
      "
    restart: unless-stopped
    depends_on:
      - vinatools
|
||||
12
docker/docker.env.example
Normal file
12
docker/docker.env.example
Normal file
@@ -0,0 +1,12 @@
|
||||
# AutoDock Vina 配置
|
||||
VINA_VERSION=1.2.7
|
||||
VINA_PLATFORM=mac_aarch64
|
||||
DOWNLOAD_VINA=true
|
||||
|
||||
# 其他平台选项:
|
||||
# VINA_PLATFORM=linux_x86_64
|
||||
# VINA_PLATFORM=mac_x86_64
|
||||
# VINA_PLATFORM=windows_x86_64
|
||||
|
||||
# 禁用 AutoDock Vina 下载
|
||||
# DOWNLOAD_VINA=false
|
||||
15
pixi.toml
Normal file
15
pixi.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[workspace]
|
||||
authors = ["hotwa <pylyzeng@gmail.com>"]
|
||||
channels = ["conda-forge"]
|
||||
name = "vinatools"
|
||||
platforms = ["osx-arm64", "linux-aarch64"]
|
||||
version = "0.1.0"
|
||||
|
||||
[tasks]
|
||||
|
||||
[dependencies]
|
||||
rdkit = ">=2025.9.1,<2026"
|
||||
openbabel = ">=3.1.1,<4"
|
||||
meeko = ">=0.5.0,<0.6"
|
||||
jupyter = ">=1.1.1,<2"
|
||||
notebook = ">=7.4.7,<8"
|
||||
551
scripts/analyze_qed_mw_distribution.py
Normal file
551
scripts/analyze_qed_mw_distribution.py
Normal file
@@ -0,0 +1,551 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file :analyze_qed_mw_distribution.py
|
||||
@Description :Analysis of QED and molecular weight distribution with KDE plots
|
||||
@Date :2025/08/05
|
||||
@Author :lyzeng
|
||||
'''
|
||||
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from rdkit import Chem
|
||||
import logging
|
||||
import ast
|
||||
import json
|
||||
import click
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_dataset(csv_file):
    """Read a QED / molecular-weight dataset from a CSV file.

    Logs the record count and basic distribution statistics for the
    'qed' and 'molecular_weight' columns.

    Args:
        csv_file (str): Path to the CSV file.

    Returns:
        pd.DataFrame: The loaded dataset.
    """
    dataset = pd.read_csv(csv_file)
    logger.info(f"Loaded {len(dataset)} records from {csv_file}")

    # Summarize the two key columns so the run log documents the input.
    logger.info(f"Statistics for {Path(csv_file).stem}:")
    logger.info(
        f"QED - Min: {dataset['qed'].min():.3f}, "
        f"Max: {dataset['qed'].max():.3f}, "
        f"Mean: {dataset['qed'].mean():.3f}"
    )
    logger.info(
        f"Molecular Weight - Min: {dataset['molecular_weight'].min():.2f}, "
        f"Max: {dataset['molecular_weight'].max():.2f}, "
        f"Mean: {dataset['molecular_weight'].mean():.2f}"
    )

    return dataset
|
||||
|
||||
def load_reference_molecules(dataset_name):
    """
    Load reference molecules from the dataset's CSV file in the current directory.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        pd.DataFrame: Reference molecules with QED and molecular weight,
            or an empty DataFrame when no qed_values_<dataset_name>.csv exists.
    """
    # The reference molecules live in the main per-dataset CSV file.
    csv_files = list(Path(".").glob(f"qed_values_{dataset_name}.csv"))
    if not csv_files:
        logger.warning(f"No CSV file found for {dataset_name}")
        return pd.DataFrame()

    df = pd.read_csv(csv_files[0])

    # Reference molecules are the rows whose filename matches
    # "align_*_out_converted.sdf". Raw string so "\." is a regex escape
    # (a plain string makes it an invalid string escape — SyntaxWarning
    # on Python >= 3.12).
    reference_df = df[df['filename'].str.contains(r'align_.*_out_converted\.sdf', na=False, regex=True)]

    logger.info(f"Loaded {len(reference_df)} reference molecules for {dataset_name}")
    return reference_df
|
||||
|
||||
def extract_vina_scores_from_sdf(sdf_file_path):
    """
    Extract Vina scores from all conformers in an SDF file.

    Each conformer is expected to carry a "meeko" molecule property holding a
    JSON object; its 'free_energy' field is the Vina docking score.

    Args:
        sdf_file_path (str): Path to the SDF file

    Returns:
        list: List of Vina scores (free_energy values) or empty list if failed
    """
    scores = []
    try:
        # removeHs=False keeps explicit hydrogens as written in the SDF.
        supplier = Chem.SDMolSupplier(sdf_file_path, removeHs=False)
        for mol in supplier:
            # SDMolSupplier yields None for records RDKit cannot parse.
            if mol is None:
                continue

            # Get the meeko property which contains docking information
            if mol.HasProp("meeko"):
                meeko_raw = mol.GetProp("meeko")
                try:
                    meeko_dict = json.loads(meeko_raw)
                    # Extract free energy (Vina score)
                    if 'free_energy' in meeko_dict:
                        scores.append(meeko_dict['free_energy'])
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse meeko JSON for {sdf_file_path}")
            else:
                logger.warning(f"No meeko property found in molecule from {sdf_file_path}")
    except Exception as e:
        # Best-effort extraction: any I/O or RDKit failure downgrades to a
        # warning and an empty/partial result rather than aborting the run.
        logger.warning(f"Failed to extract Vina scores from {sdf_file_path}: {e}")

    return scores
|
||||
|
||||
def load_vina_scores_from_csv(df, max_files=1000):
    """Collect Vina scores from the 'vina_scores' column of a dataframe.

    Args:
        df (pd.DataFrame): DataFrame with 'filename' and 'vina_scores'
            columns, where 'vina_scores' holds a string-encoded Python list.
        max_files (int): Maximum number of rows to parse (memory guard).

    Returns:
        list: All Vina scores gathered from the successfully parsed rows.
    """
    collected = []
    parsed_count = 0

    for _, record in df.iterrows():
        # Respect the cap on parsed files to avoid memory issues.
        if parsed_count >= max_files:
            break

        # Reference molecules are stored with a .mol2 extension; skip them.
        if '.mol2' in record['filename']:
            continue

        try:
            # The column stores a Python-list literal; decode it safely.
            scores_for_file = ast.literal_eval(record['vina_scores'])
        except (ValueError, SyntaxError) as e:
            logger.warning(f"Failed to parse Vina scores for {record['filename']}: {e}")
            continue

        collected.extend(scores_for_file)
        parsed_count += 1

    logger.info(f"Loaded {len(collected)} Vina scores from {parsed_count} files")
    return collected
|
||||
|
||||
def get_min_vina_scores_length(df):
    """Return the smallest number of conformer scores over all non-reference rows.

    Args:
        df (pd.DataFrame): DataFrame with 'filename' and 'vina_scores'
            columns, where 'vina_scores' holds a string-encoded Python list.

    Returns:
        int: Minimum vina_scores list length, or 0 when no row could be parsed.
    """
    shortest = float('inf')

    for _, record in df.iterrows():
        # Reference molecules are stored with a .mol2 extension; skip them.
        if '.mol2' in record['filename']:
            continue

        try:
            # Decode the string-encoded list of scores.
            parsed_scores = ast.literal_eval(record['vina_scores'])
        except (ValueError, SyntaxError) as e:
            logger.warning(f"Failed to parse Vina scores for {record['filename']}: {e}")
            continue

        if len(parsed_scores) < shortest:
            shortest = len(parsed_scores)

    # No parsable row leaves the sentinel untouched -> report 0.
    return shortest if shortest != float('inf') else 0
|
||||
|
||||
def get_reference_vina_scores(dataset_name, rank=0):
    """
    Get Vina scores for reference molecules.

    Scans result/refence/<dataset_name>/ for "*_converted.sdf" files, pulls
    each file's conformer scores via extract_vina_scores_from_sdf, and keeps
    the score at the requested rank, keyed by the ligand identifier derived
    from the filename.

    Args:
        dataset_name (str): Name of the dataset (fgbar or trpe)
        rank (int): Rank of the conformation to use (0 for best/first)

    Returns:
        dict: Dictionary with reference molecule identifiers and their Vina scores

    Raises:
        ValueError: If rank is >= the number of conformers in some file.
    """
    reference_scores = {}

    # Keep the original (misspelled) directory name "refence" as produced
    # by the upstream pipeline.
    reference_dir = Path("result") / "refence" / dataset_name

    if not reference_dir.exists():
        logger.warning(f"Reference directory {reference_dir} does not exist")
        return reference_scores

    # Find reference SDF files
    reference_sdf_files = list(reference_dir.glob("*_converted.sdf"))
    logger.info(f"Processing {len(reference_sdf_files)} reference SDF files in {reference_dir}")

    for sdf_file in reference_sdf_files:
        vina_scores = extract_vina_scores_from_sdf(str(sdf_file))
        if vina_scores:
            # Check if rank is valid
            if rank >= len(vina_scores):
                raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {len(vina_scores)}. Please choose a rank less than {len(vina_scores)}.")

            # Get the score at the specified rank
            reference_score = vina_scores[rank]

            # Extract identifier from filename by stripping the pipeline
            # suffixes/prefixes, e.g. "align_xxx_9NY_addH_out_converted" -> "9NY".
            filename_stem = sdf_file.stem
            if '_out_converted' in filename_stem:
                filename_stem = filename_stem.replace('_out_converted', '')
            if '_addH' in filename_stem:
                filename_stem = filename_stem.replace('_addH', '')
            if 'align_' in filename_stem:
                filename_stem = filename_stem.split('_')[-1]  # Get the last part (e.g., 9NY or 0GA)

            # Use filename_stem as key for reference_scores
            reference_scores[filename_stem] = reference_score
            logger.info(f"Reference Vina score for {filename_stem} (rank {rank}): {reference_score}")

    return reference_scores
|
||||
|
||||
def plot_combined_kde_distribution_normalized(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot combined KDE distribution for QED, molecular weight, and Vina scores,
    min-max normalized to [0, 1] so all three share one axis. Saves the figure
    to kde_distribution_<dataset_name>_normalized.png.

    Args:
        df (pd.DataFrame): Main dataset
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional; currently unused here)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores (optional)
    """
    # Create figure
    plt.figure(figsize=(15, 8))

    # Min-max normalize the data to make them comparable on the same scale
    qed_normalized = (df['qed'] - df['qed'].min()) / (df['qed'].max() - df['qed'].min())
    mw_normalized = (df['molecular_weight'] - df['molecular_weight'].min()) / (df['molecular_weight'].max() - df['molecular_weight'].min())

    # Plot KDE for normalized QED
    sns.kdeplot(qed_normalized, label='QED (normalized)', fill=True, alpha=0.5, color='blue')

    # Plot KDE for normalized molecular weight
    sns.kdeplot(mw_normalized, label='Molecular Weight (normalized)', fill=True, alpha=0.5, color='red')

    # Plot KDE for normalized Vina scores if available
    if vina_scores and len(vina_scores) > 0:
        # Note: lower Vina scores are better; the values are only min-max
        # scaled here (no sign flip), so "better" ends up near 0.
        vina_series = pd.Series(vina_scores)
        vina_normalized = (vina_series - vina_series.min()) / (vina_series.max() - vina_series.min())
        sns.kdeplot(vina_normalized, label='Vina Score (normalized)', fill=True, alpha=0.5, color='green')

    # Mark reference molecules if provided
    if reference_df is not None and len(reference_df) > 0:
        # Normalize reference data using the same scale as main dataset
        ref_qed_normalized = (reference_df['qed'] - df['qed'].min()) / (df['qed'].max() - df['qed'].min())
        ref_mw_normalized = (reference_df['molecular_weight'] - df['molecular_weight'].min()) / (df['molecular_weight'].max() - df['molecular_weight'].min())

        # Collected proxy artists for the reference-marker legend
        legend_handles = []

        # Mark reference molecules for QED, MW, and Vina scores
        for i, (idx, row) in enumerate(reference_df.iterrows()):
            filename_stem = Path(row['filename']).stem
            # Extract actual identifier from filename
            if '_addH' in filename_stem:
                filename_stem = filename_stem.replace('_addH', '')
            if 'align_' in filename_stem:
                filename_stem = filename_stem.split('_')[-1]  # Get the last part (e.g., 9NY or 0GA)

            # Get values from the reference dataframe
            qed_value = row['qed']
            mw_value = row['molecular_weight']

            # Build score text
            score_text = f"{filename_stem}\n(QED: {qed_value:.2f}, MW: {mw_value:.2f}"

            # Add Vina score if available
            if reference_vina_scores and filename_stem in reference_vina_scores:
                vina_score = reference_vina_scores[filename_stem]
                score_text += f", Vina: {vina_score:.2f}"
            score_text += ")"

            # QED marker
            x_pos = ref_qed_normalized.iloc[i]
            plt.scatter(x_pos, 0, color='darkblue', s=100, marker='v', zorder=5)

            # Molecular weight marker
            x_pos = ref_mw_normalized.iloc[i]
            plt.scatter(x_pos, 0, color='darkred', s=100, marker='^', zorder=5)

            # Vina score marker if available
            if reference_vina_scores and filename_stem in reference_vina_scores:
                # Normalize reference Vina score against the full score range.
                # NOTE(review): assumes vina_scores is non-empty whenever
                # reference_vina_scores is provided — min()/max() would raise
                # on an empty list; confirm callers guarantee this.
                vina_min = min(vina_scores)
                vina_max = max(vina_scores)
                ref_vina_normalized = (reference_vina_scores[filename_stem] - vina_min) / (vina_max - vina_min)
                plt.scatter(ref_vina_normalized, 0, color='darkgreen', s=100, marker='o', zorder=5)

            # Annotate with combined information.
            # NOTE(review): x_pos still holds the MW marker position here, so
            # the combined annotation is anchored at the MW marker — confirm
            # this is intentional.
            plt.annotate(score_text,
                         (x_pos, 0),
                         xytext=(10, 30),
                         textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                         fontsize=8)

            # Add to legend
            legend_handles.append(plt.Line2D([0], [0], marker='v', color='darkblue', label=f"{filename_stem} - QED",
                                             markerfacecolor='darkblue', markersize=8, linestyle=''))
            legend_handles.append(plt.Line2D([0], [0], marker='^', color='darkred', label=f"{filename_stem} - MW",
                                             markerfacecolor='darkred', markersize=8, linestyle=''))
            if reference_vina_scores and filename_stem in reference_vina_scores:
                legend_handles.append(plt.Line2D([0], [0], marker='o', color='darkgreen', label=f"{filename_stem} - Vina",
                                                 markerfacecolor='darkgreen', markersize=8, linestyle=''))

        # Add combined legend
        plt.legend(handles=legend_handles, loc='upper right', fontsize=10)

    plt.title(f'Combined KDE Distribution (Normalized) - {dataset_name.upper()}', fontsize=16)
    plt.xlabel('Normalized Values (0-1)')
    plt.ylabel('Density')
    # NOTE(review): this second legend() call (without handles) rebuilds the
    # legend from the KDE labels and replaces the reference-marker legend
    # created above — confirm which legend is intended to survive.
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_normalized.png', dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved combined KDE distribution plot (normalized) for {dataset_name} as kde_distribution_{dataset_name}_normalized.png")
|
||||
|
||||
def plot_combined_kde_distribution_actual(df, dataset_name, reference_df=None, reference_scores=None, vina_scores=None, reference_vina_scores=None):
    """
    Plot KDE distributions for QED, molecular weight, and Vina scores on three
    side-by-side subplots using the actual (unnormalized) values. Saves the
    figure to kde_distribution_<dataset_name>_actual.png.

    Args:
        df (pd.DataFrame): Main dataset
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_df (pd.DataFrame): Reference molecules dataset (optional)
        reference_scores (dict): Reference molecule scores (optional; currently unused here)
        vina_scores (list): Vina scores for all molecules (optional)
        reference_vina_scores (dict): Reference molecule Vina scores (optional)
    """
    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(f'KDE Distribution (Actual Values) - {dataset_name.upper()}', fontsize=16)

    # Get reference molecule identifier (stem of the FIRST reference row's
    # SDF filename). NOTE(review): every annotation below uses this single
    # identifier rather than the per-row filename_stem — with multiple
    # reference molecules all labels show the first one; confirm intended.
    reference_filename_stem = None
    if reference_df is not None and len(reference_df) > 0:
        reference_filename_stem = Path(reference_df.iloc[0]['filename']).stem
        if '_out_converted' in reference_filename_stem:
            reference_filename_stem = reference_filename_stem.replace('_out_converted', '')
        if '_addH' in reference_filename_stem:
            reference_filename_stem = reference_filename_stem.replace('_addH', '')
        if 'align_' in reference_filename_stem:
            reference_filename_stem = reference_filename_stem.split('_')[-1]  # Get the last part (e.g., 9NY or 0GA)

    # Plot 1: QED distribution
    sns.kdeplot(df['qed'], ax=axes[0], fill=True, alpha=0.5, color='blue')
    axes[0].set_title('QED Distribution')
    axes[0].set_xlabel('QED Value')
    axes[0].set_ylabel('Density')
    axes[0].grid(True, alpha=0.3)

    # Mark reference molecules for QED
    if reference_df is not None and len(reference_df) > 0:
        for i, (idx, row) in enumerate(reference_df.iterrows()):
            filename_stem = Path(row['filename']).stem
            # Extract actual identifier from filename
            if '_addH' in filename_stem:
                filename_stem = filename_stem.replace('_addH', '')
            if 'align_' in filename_stem:
                filename_stem = filename_stem.split('_')[-1]  # Get the last part (e.g., 9NY or 0GA)

            # Get QED value from the reference dataframe
            qed_value = row['qed']
            score_text = f"{reference_filename_stem}\n({qed_value:.2f})"

            axes[0].scatter(row['qed'], 0, color='darkblue', s=100, marker='v', zorder=5)
            axes[0].annotate(score_text,
                             (row['qed'], 0),
                             xytext=(10, 20),
                             textcoords='offset points',
                             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                             arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                             fontsize=8)

    # Plot 2: Molecular weight distribution
    sns.kdeplot(df['molecular_weight'], ax=axes[1], fill=True, alpha=0.5, color='red')
    axes[1].set_title('Molecular Weight Distribution')
    axes[1].set_xlabel('Molecular Weight (Daltons)')
    axes[1].set_ylabel('Density')
    axes[1].grid(True, alpha=0.3)

    # Mark reference molecules for molecular weight
    if reference_df is not None and len(reference_df) > 0:
        for i, (idx, row) in enumerate(reference_df.iterrows()):
            filename_stem = Path(row['filename']).stem
            # Extract actual identifier from filename
            if '_addH' in filename_stem:
                filename_stem = filename_stem.replace('_addH', '')
            if 'align_' in filename_stem:
                filename_stem = filename_stem.split('_')[-1]  # Get the last part (e.g., 9NY or 0GA)

            # Get MW value from the reference dataframe
            mw_value = row['molecular_weight']
            score_text = f"{reference_filename_stem}\n({mw_value:.2f})"

            axes[1].scatter(row['molecular_weight'], 0, color='darkred', s=100, marker='^', zorder=5)
            axes[1].annotate(score_text,
                             (row['molecular_weight'], 0),
                             xytext=(10, -30),
                             textcoords='offset points',
                             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                             arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                             fontsize=8)

    # Plot 3: Vina scores distribution (only drawn when scores are supplied)
    if vina_scores and len(vina_scores) > 0:
        vina_series = pd.Series(vina_scores)
        sns.kdeplot(vina_series, ax=axes[2], fill=True, alpha=0.5, color='green')
        axes[2].set_title('Vina Score Distribution')
        axes[2].set_xlabel('Vina Score (kcal/mol)')
        axes[2].set_ylabel('Density')
        axes[2].grid(True, alpha=0.3)

        # Mark reference molecules for Vina scores
        if reference_vina_scores:
            for filename_stem, vina_score in reference_vina_scores.items():
                score_text = f"{reference_filename_stem}\n({vina_score:.2f})"

                axes[2].scatter(vina_score, 0, color='darkgreen', s=100, marker='o', zorder=5)
                axes[2].annotate(score_text,
                                 (vina_score, 0),
                                 xytext=(10, -60),
                                 textcoords='offset points',
                                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                                 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                                 fontsize=8)

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig(f'kde_distribution_{dataset_name}_actual.png', dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved combined KDE distribution plot (actual values) for {dataset_name} as kde_distribution_{dataset_name}_actual.png")
|
||||
|
||||
def analyze_dataset(csv_file, dataset_name, reference_scores=None, rank=0):
    """
    Analyze a dataset and generate KDE plots.

    Args:
        csv_file (str): Path to the CSV file
        dataset_name (str): Name of the dataset (fgbar or trpe)
        reference_scores (dict): Reference scores for each dataset
        rank (int): Rank of the conformation to use for reference Vina scores (0 for best/first)

    Raises:
        ValueError: If `rank` exceeds the smallest conformer count in the dataset.
    """
    df = load_dataset(csv_file)

    # Validate the requested rank against the smallest per-molecule conformer count.
    min_vina_length = get_min_vina_scores_length(df)
    if rank >= min_vina_length:
        raise ValueError(f"Rank {rank} is out of range. The minimum number of conformers across all molecules is {min_vina_length}. Please choose a rank less than {min_vina_length}.")

    # Gather the reference molecules and score data used by both figures.
    reference_df = load_reference_molecules(dataset_name)
    vina_scores = load_vina_scores_from_csv(df)
    reference_vina_scores = get_reference_vina_scores(dataset_name, rank)

    # Render both the normalized and the actual-value KDE figures.
    for plot_fn in (plot_combined_kde_distribution_normalized,
                    plot_combined_kde_distribution_actual):
        plot_fn(df, dataset_name, reference_df, reference_scores,
                vina_scores, reference_vina_scores)
|
||||
|
||||
def get_default_reference_scores():
    """
    Default reference Vina scores taken from README.md.

    Returns:
        dict: Mapping of dataset name -> {reference ligand ID: Vina score}.
    """
    # One reference ligand per dataset; values documented in README.md.
    fgbar_refs = {'9NY': -5.268}
    trpe_refs = {'0GA': -6.531}
    return {'fgbar': fgbar_refs, 'trpe': trpe_refs}
|
||||
|
||||
@click.command()
@click.argument('csv_files', type=click.Path(exists=True), nargs=-1)
@click.option('--dataset-names', '-d', multiple=True, help='Names of the datasets corresponding to CSV files')
@click.option('--reference-scores', '-r', type=str, help='Reference scores in JSON format')
@click.option('--rank', '-k', default=0, type=int, help='Rank of conformation to use for reference Vina scores (default: 0 for best/first)')
def main_cli(csv_files, dataset_names, reference_scores, rank):
    """
    Analyze QED and molecular weight distributions and generate KDE plots

    CSV_FILES: Paths to the CSV files with QED and molecular weight data
    """
    # At least one input file is mandatory.
    if not csv_files:
        logger.error("At least one CSV file must be provided")
        return

    # click passes a tuple; downstream expects a list (or None to auto-derive names).
    dataset_names_list = list(dataset_names) if dataset_names else None

    # Reference scores come from the CLI as a JSON string, or fall back to defaults.
    if not reference_scores:
        reference_scores_dict = get_default_reference_scores()
    else:
        try:
            reference_scores_dict = json.loads(reference_scores)
        except json.JSONDecodeError:
            logger.error("Invalid JSON format for reference scores")
            return

    # Delegate to the API entry point; log and re-raise so the exit code reflects failure.
    try:
        main_api(csv_files, dataset_names_list, reference_scores_dict, rank)
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        raise
|
||||
|
||||
def main_api(csv_files, dataset_names=None, reference_scores=None, rank=0):
    """
    Main function for API usage.

    Args:
        csv_files (list): List of CSV files to analyze
        dataset_names (list): List of dataset names corresponding to CSV files;
            derived from the file stems when None.
        reference_scores (dict): Reference scores for each dataset
        rank (int): Rank of the conformation to use for reference Vina scores (0 for best/first)

    Raises:
        ValueError: If dataset_names is given but its length differs from csv_files.
    """
    if dataset_names is None:
        # Derive names from the file stems, e.g. 'qed_values_fgbar.csv' -> 'fgbar'.
        dataset_names = [Path(f).stem.replace('qed_values_', '') for f in csv_files]
    elif len(dataset_names) != len(csv_files):
        # zip() below would silently drop the unmatched entries; fail loudly instead.
        raise ValueError(
            f"Number of dataset names ({len(dataset_names)}) does not match "
            f"number of CSV files ({len(csv_files)})"
        )

    if reference_scores is None:
        reference_scores = get_default_reference_scores()

    for csv_file, dataset_name in zip(csv_files, dataset_names):
        try:
            logger.info(f"Analyzing dataset: {dataset_name}")
            # Missing datasets get an empty reference-score mapping.
            analyze_dataset(csv_file, dataset_name, reference_scores.get(dataset_name, {}), rank)
        except Exception as e:
            logger.error(f"Error analyzing {dataset_name}: {e}")
            raise
|
||||
|
||||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main_cli()
|
||||
151
scripts/analyze_results.py
Normal file
151
scripts/analyze_results.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
分析AutoDock Vina对接结果脚本
|
||||
用法: python analyze_results.py poses_directory output_csv
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
def parse_vina_output(pdbqt_file):
    """Parse a Vina output PDBQT file and extract the energy of every pose.

    Each docked pose is reported on a line of the form:
        REMARK VINA RESULT: <binding_energy> <rmsd_lb> <rmsd_ub>

    Args:
        pdbqt_file (str): Path to a *_out.pdbqt file produced by AutoDock Vina.

    Returns:
        list[dict]: One record per pose with ligand name, 1-based model index,
        binding energy, RMSD lower/upper bounds, and the source file path.
    """
    results = []
    # Ligand name is the file stem without the '_out' suffix.
    ligand_name = Path(pdbqt_file).stem.replace('_out', '')
    model_num = 0

    try:
        with open(pdbqt_file, 'r') as f:
            # Stream line by line instead of readlines(): output files can be large.
            for line in f:
                if not line.startswith('REMARK VINA RESULT:'):
                    continue
                parts = line.strip().split()
                if len(parts) < 4:
                    continue
                try:
                    binding_energy = float(parts[3])
                    # RMSD columns may be absent; default to 0.0.
                    rmsd_lb = float(parts[4]) if len(parts) > 4 else 0.0
                    rmsd_ub = float(parts[5]) if len(parts) > 5 else 0.0
                except ValueError:
                    # A single malformed line must not abort the rest of the file.
                    print(f"解析文件失败 {pdbqt_file}: malformed RESULT line: {line.strip()}")
                    continue

                model_num += 1
                results.append({
                    'ligand_name': ligand_name,
                    'model': model_num,
                    'binding_energy': binding_energy,
                    'rmsd_lb': rmsd_lb,
                    'rmsd_ub': rmsd_ub,
                    'file_path': pdbqt_file
                })
    except Exception as e:
        print(f"解析文件失败 {pdbqt_file}: {e}")

    return results
|
||||
|
||||
def analyze_poses_directory(poses_dir):
    """Walk a poses directory and parse every Vina result file found in it.

    Args:
        poses_dir (str): Directory containing *_out.pdbqt files.

    Returns:
        list[dict]: Concatenated pose records from all result files;
        empty list when the directory is missing or has no result files.
    """
    poses_path = Path(poses_dir)

    # Bail out early when the directory does not exist.
    if not poses_path.exists():
        print(f"错误: 目录不存在 {poses_dir}")
        return []

    # Vina writes docked poses as <ligand>_out.pdbqt files.
    pdbqt_files = list(poses_path.glob("*_out.pdbqt"))
    print(f"找到 {len(pdbqt_files)} 个结果文件")

    all_results = []
    for pdbqt_file in pdbqt_files:
        all_results.extend(parse_vina_output(str(pdbqt_file)))
    return all_results
|
||||
|
||||
def generate_summary_stats(df):
    """Compute aggregate statistics over all docking poses.

    Args:
        df (pd.DataFrame): Pose-level results with 'ligand_name' and
            'binding_energy' columns.

    Returns:
        dict: Aggregate statistics; empty dict when df has no rows.
    """
    if df.empty:
        return {}

    # The best (lowest-energy) pose per ligand drives the threshold counts.
    best_poses = df.loc[df.groupby('ligand_name')['binding_energy'].idxmin()]
    energies = df['binding_energy']

    stats = {
        'total_ligands': df['ligand_name'].nunique(),
        'total_poses': len(df),
        'best_binding_energy': energies.min(),
        'worst_binding_energy': energies.max(),
        'mean_binding_energy': energies.mean(),
        'median_binding_energy': energies.median(),
        'std_binding_energy': energies.std(),
    }
    # Count ligands whose best pose beats each energy threshold (kcal/mol).
    for threshold in (7, 8, 9):
        stats[f'ligands_better_than_minus_{threshold}'] = (
            best_poses['binding_energy'] < -float(threshold)
        ).sum()
    return stats
|
||||
|
||||
def main():
    # CLI entry point: analyze_results.py <poses_dir> <output_csv>.
    if len(sys.argv) != 3:
        print("用法: python analyze_results.py <poses目录> <输出CSV文件>")
        sys.exit(1)

    poses_dir = sys.argv[1]
    output_csv = sys.argv[2]

    print("开始分析对接结果...")

    # Parse every *_out.pdbqt file under the poses directory.
    results = analyze_poses_directory(poses_dir)

    if not results:
        print("没有找到有效的结果文件")
        sys.exit(1)

    # Convert the pose records to a DataFrame.
    df = pd.DataFrame(results)

    # Save the detailed per-pose results.
    df.to_csv(output_csv, index=False)
    print(f"详细结果已保存到: {output_csv}")

    # Build aggregate statistics (assumes non-empty df, so stats has all keys).
    stats = generate_summary_stats(df)

    # Print the summary: counts, energy extremes/central tendency, thresholds.
    print("\n=== 对接结果汇总 ===")
    print(f"总配体数量: {stats['total_ligands']}")
    print(f"总构象数量: {stats['total_poses']}")
    print(f"最佳结合能: {stats['best_binding_energy']:.2f} kcal/mol")
    print(f"最差结合能: {stats['worst_binding_energy']:.2f} kcal/mol")
    print(f"平均结合能: {stats['mean_binding_energy']:.2f} kcal/mol")
    print(f"中位结合能: {stats['median_binding_energy']:.2f} kcal/mol")
    print(f"标准差: {stats['std_binding_energy']:.2f} kcal/mol")
    print(f"结合能 < -7.0 的配体: {stats['ligands_better_than_minus_7']}")
    print(f"结合能 < -8.0 的配体: {stats['ligands_better_than_minus_8']}")
    print(f"结合能 < -9.0 的配体: {stats['ligands_better_than_minus_9']}")

    # Show the ten best ligands, ranked by each ligand's best pose energy.
    best_results = df.loc[df.groupby('ligand_name')['binding_energy'].idxmin()]
    best_results = best_results.sort_values('binding_energy').head(10)

    print("\n=== 前10个最佳结果 ===")
    for _, row in best_results.iterrows():
        print(f"{row['ligand_name']}: {row['binding_energy']:.2f} kcal/mol")

    # Persist the summary statistics next to the CSV output.
    summary_file = output_csv.replace('.csv', '_summary.txt')
    with open(summary_file, 'w') as f:
        f.write("AutoDock Vina 对接结果汇总\n")
        f.write("=" * 30 + "\n\n")
        for key, value in stats.items():
            f.write(f"{key}: {value}\n")

    print(f"\n汇总统计已保存到: {summary_file}")
|
||||
|
||||
# Script entry point: run the analysis only when executed directly.
if __name__ == "__main__":
    main()
|
||||
228
scripts/calculate_qed_values.py
Normal file
228
scripts/calculate_qed_values.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file :calculate_qed_values.py
|
||||
@Description :Calculate QED values for molecules in poses_all directories and reference molecules
|
||||
@Date :2025/08/04
|
||||
@Author :lyzeng
|
||||
'''
|
||||
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import QED
|
||||
from rdkit.Chem.Descriptors import MolWt
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import json
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_smiles_from_sdf(sdf_file_path):
    """
    Extract SMILES from the first molecule in an SDF file.

    Args:
        sdf_file_path (Path): Path to the SDF file

    Returns:
        str: Canonical SMILES of the first molecule, or None if parsing fails.
    """
    try:
        # Only the first record of the SDF is inspected.
        first_mol = next(Chem.SDMolSupplier(str(sdf_file_path)))
        if first_mol is None:
            return None
        return Chem.MolToSmiles(first_mol)
    except Exception as e:
        # Covers unreadable files and empty suppliers (StopIteration).
        logger.warning(f"Failed to extract SMILES from {sdf_file_path}: {e}")
        return None
|
||||
|
||||
def get_smiles_from_mol2(mol2_file_path):
    """
    Extract SMILES from a mol2 file.

    Args:
        mol2_file_path (Path): Path to the mol2 file

    Returns:
        str: Canonical SMILES of the molecule, or None if parsing fails.
    """
    try:
        parsed = Chem.MolFromMol2File(str(mol2_file_path))
        if parsed is not None:
            return Chem.MolToSmiles(parsed)
    except Exception as e:
        logger.warning(f"Failed to extract SMILES from {mol2_file_path}: {e}")
    # Either parsing returned None or an exception was logged above.
    return None
|
||||
|
||||
def extract_vina_scores_from_sdf(sdf_file_path):
    """
    Extract Vina scores from all conformers in an SDF file.

    Meeko stores docking metadata as a JSON string in the 'meeko' SD tag;
    this reads that tag's 'free_energy' entry for each conformer.

    Args:
        sdf_file_path (Path): Path to the SDF file

    Returns:
        list: Vina scores (free_energy values); empty list on failure.
    """
    scores = []
    try:
        # removeHs=False keeps the hydrogens present in the file.
        for mol in Chem.SDMolSupplier(str(sdf_file_path), removeHs=False):
            if mol is None:
                continue
            if not mol.HasProp("meeko"):
                logger.warning(f"No meeko property found in molecule from {sdf_file_path}")
                continue
            try:
                meeko_dict = json.loads(mol.GetProp("meeko"))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse meeko JSON for {sdf_file_path}")
                continue
            if 'free_energy' in meeko_dict:
                scores.append(meeko_dict['free_energy'])
    except Exception as e:
        logger.warning(f"Failed to extract Vina scores from {sdf_file_path}: {e}")

    return scores
|
||||
|
||||
def calculate_qed_for_poses_all(base_dir, dataset_name):
    """
    Calculate QED values for all SDF files in a dataset's poses_all directory.

    Args:
        base_dir (Path): Base directory containing the dataset folders
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        list: Dicts with smiles, filename, qed, molecular_weight and
        vina_scores (list of per-conformer scores) per parsed molecule.
    """
    results = []
    poses_all_dir = base_dir / dataset_name / "poses_all"

    if not poses_all_dir.exists():
        logger.warning(f"Directory {poses_all_dir} does not exist")
        return results

    sdf_files = list(poses_all_dir.glob("*.sdf"))
    logger.info(f"Processing {len(sdf_files)} SDF files in {poses_all_dir}")

    for sdf_file in sdf_files:
        smiles = get_smiles_from_sdf(sdf_file)
        vina_scores = extract_vina_scores_from_sdf(sdf_file)
        if smiles is None:
            continue
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
            results.append({
                'smiles': smiles,
                'filename': sdf_file.name,
                'qed': QED.qed(mol),
                'molecular_weight': MolWt(mol),
                # Raw list of per-conformer Vina scores from the meeko SD tag.
                'vina_scores': vina_scores,
            })
        except Exception as e:
            logger.warning(f"Failed to calculate QED for {sdf_file}: {e}")

    return results
|
||||
|
||||
def calculate_qed_for_reference(base_dir, dataset_name):
    """
    Calculate QED values for reference molecules in SDF format.

    Args:
        base_dir (Path): Base directory containing the reference folder
        dataset_name (str): Name of the dataset (fgbar or trpe)

    Returns:
        list: Dicts with smiles, filename, qed, molecular_weight and
        vina_scores, matching the record shape of calculate_qed_for_poses_all.
    """
    results = []
    # NOTE: the on-disk directory is spelled "refence" (sic); keep it as-is.
    reference_dir = base_dir / "refence" / dataset_name

    if not reference_dir.exists():
        logger.warning(f"Directory {reference_dir} does not exist")
        return results

    # Reference ligands are stored as *_out_converted.sdf files.
    reference_sdf_files = list(reference_dir.glob("*_out_converted.sdf"))
    logger.info(f"Processing {len(reference_sdf_files)} reference SDF files in {reference_dir}")

    for sdf_file in reference_sdf_files:
        smiles = get_smiles_from_sdf(sdf_file)
        vina_scores = extract_vina_scores_from_sdf(sdf_file)

        if smiles is not None:
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is not None:
                    results.append({
                        'smiles': smiles,
                        'filename': sdf_file.name,
                        'qed': QED.qed(mol),
                        'molecular_weight': MolWt(mol),
                        # Store the raw list (was str(vina_scores)) so the record
                        # shape matches calculate_qed_for_poses_all; pandas.to_csv
                        # serializes both forms to identical text.
                        'vina_scores': vina_scores,
                    })
            except Exception as e:
                logger.warning(f"Failed to calculate QED for {sdf_file}: {e}")

    return results
|
||||
|
||||
def process_dataset(result_dir, dataset_name):
    """
    Process a single dataset (fgbar or trpe) and write its QED values to CSV.

    Args:
        result_dir (Path): Base result directory
        dataset_name (str): Name of the dataset (fgbar or trpe)
    """
    csv_filename = f"qed_values_{dataset_name}.csv"

    # Collect QED records from both the docked poses and the reference ligands.
    all_results = (calculate_qed_for_poses_all(result_dir, dataset_name)
                   + calculate_qed_for_reference(result_dir, dataset_name))

    if not all_results:
        logger.warning(f"No QED values were calculated for {dataset_name}")
        # Still emit a CSV (headers only) so downstream tooling finds the file.
        empty = pd.DataFrame(columns=['smiles', 'filename', 'qed', 'molecular_weight', 'vina_scores'])
        empty.to_csv(csv_filename, index=False)
        return

    df = pd.DataFrame(all_results)
    df.to_csv(csv_filename, index=False)
    logger.info(f"Saved {len(df)} QED values to {csv_filename}")
    print(f"First few rows of {csv_filename}:")
    print(df.head())
|
||||
|
||||
def main():
    """
    Calculate QED values for every dataset under the result directory.
    """
    result_dir = Path("result")
    # Each dataset is processed independently and written to its own CSV.
    for dataset in ("fgbar", "trpe"):
        process_dataset(result_dir, dataset)
|
||||
|
||||
# Script entry point: compute QED CSVs only when executed directly.
if __name__ == "__main__":
    main()
|
||||
30
scripts/example_api_usage.py
Normal file
30
scripts/example_api_usage.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Example usage of the analyze_qed_mw_distribution API
|
||||
"""
|
||||
|
||||
from analyze_qed_mw_distribution import main_api
|
||||
|
||||
# Demonstrates the four supported calling patterns of main_api.
# NOTE(review): this runs at import time and expects the two CSV files to exist
# in the working directory — run after calculate_qed_values.py.
print("Running analysis examples...")

# Example 1: Basic usage — explicit dataset names, defaults for everything else.
print("\nExample 1: Basic usage")
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'])

# Example 2: With custom reference scores (ligand ID -> Vina score per dataset).
print("\nExample 2: With custom reference scores")
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'],
         reference_scores={'fgbar': {'9NY': -5.268}, 'trpe': {'0GA': -6.531}})

# Example 3: With specific conformation rank (0 = first/best conformer).
print("\nExample 3: With specific conformation rank")
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'], rank=0)

# Example 4: With both custom reference scores and specific conformation rank.
print("\nExample 4: With both custom reference scores and specific conformation rank")
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'],
         reference_scores={'fgbar': {'9NY': -5.268}, 'trpe': {'0GA': -6.531}}, rank=0)

print("\nAnalysis complete! Check the generated PNG files.")
|
||||
201
scripts/qed.py
Normal file
201
scripts/qed.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file :qed_calculator.py
|
||||
@Description :QED calculator with joblib parallel support
|
||||
@Date :2025/08/04
|
||||
@Author :lyzeng
|
||||
'''
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import QED
|
||||
import pandas as pd
|
||||
from typing import List, Union, Tuple, Optional
|
||||
import joblib
|
||||
from joblib import Parallel, delayed
|
||||
import logging
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_single_qed(smiles: str) -> Optional[Tuple[str, float]]:
    """
    Calculate QED value for a single molecule.

    Args:
        smiles (str): SMILES representation of the molecule

    Returns:
        tuple: (smiles, qed_value), or None if parsing or scoring fails.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            return (smiles, QED.qed(mol))
    except Exception as e:
        # Treat any RDKit error as a soft failure: warn and fall through.
        # (Removed the redundant `pass` that followed this warning.)
        logger.warning(f"Failed to calculate QED for {smiles}: {e}")
    # Unparseable SMILES (mol is None) also ends up here.
    return None
|
||||
|
||||
|
||||
def parallel_qed_calculation(
    smiles_list: List[str],
    n_jobs: int = -1,
    batch_size: Union[int, str] = "auto",
    backend: str = "loky"
) -> pd.DataFrame:
    """
    Calculate QED values for a list of SMILES in parallel using joblib.

    Args:
        smiles_list (List[str]): List of SMILES strings
        n_jobs (int): Number of parallel jobs. -1 means using all processors
        batch_size (int or str): Batch size for parallel processing
        backend (str): Joblib backend to use ('loky', 'threading', 'multiprocessing')

    Returns:
        pd.DataFrame: DataFrame with 'smiles' and 'qed' columns
    """
    logger.info(f"Calculating QED values for {len(smiles_list)} molecules...")

    # Fan the per-molecule work out across joblib workers.
    runner = Parallel(n_jobs=n_jobs, batch_size=batch_size, backend=backend)
    raw = runner(delayed(calculate_single_qed)(s) for s in smiles_list)

    # Failed molecules come back as None; keep only the successful pairs.
    pairs = [item for item in raw if item is not None]
    if not pairs:
        logger.warning("No valid QED values calculated")
        return pd.DataFrame(columns=['smiles', 'qed'])

    smiles_col, qed_col = zip(*pairs)
    df = pd.DataFrame({'smiles': smiles_col, 'qed': qed_col})
    logger.info(f"Successfully calculated QED values for {len(df)} molecules")
    return df
|
||||
|
||||
|
||||
def calculate_qed_series(
    smiles_series: Union[List[str], pd.Series],
    n_jobs: int = -1,
    batch_size: Union[int, str] = "auto",
    backend: str = "loky"
) -> pd.Series:
    """
    Calculate QED values for a pandas Series or list of SMILES and return as Series.

    Args:
        smiles_series: Series or list of SMILES strings
        n_jobs (int): Number of parallel jobs
        batch_size (int or str): Batch size for parallel processing
        backend (str): Joblib backend to use

    Returns:
        pd.Series: QED values, index-aligned with the input when it is a Series.
    """
    # Normalize the input and remember the original index (if any).
    if isinstance(smiles_series, pd.Series):
        original_index = smiles_series.index
        smiles_list = smiles_series.tolist()
    else:
        original_index = None
        smiles_list = smiles_series

    runner = Parallel(n_jobs=n_jobs, batch_size=batch_size, backend=backend)
    raw = runner(delayed(calculate_single_qed)(s) for s in smiles_list)

    # calculate_single_qed yields (smiles, qed) or None; keep None for failures.
    qed_values = [None if item is None else item[1] for item in raw]

    if original_index is None:
        return pd.Series(qed_values, name='qed')
    return pd.Series(qed_values, index=original_index, name='qed')
|
||||
|
||||
|
||||
class QEDCalculator:
    """
    QED calculator with configurable joblib parallelism.

    Wraps parallel_qed_calculation / calculate_qed_series so the parallelism
    settings are chosen once at construction time.
    """

    def __init__(self, n_jobs: int = -1, batch_size: Union[int, str] = "auto"):
        """
        Initialize the QEDCalculator.

        Args:
            n_jobs (int): Number of parallel jobs. -1 means using all processors
            batch_size (int or str): Batch size for parallel processing
        """
        self.n_jobs = n_jobs
        self.batch_size = batch_size
        # Fixed joblib backend for both entry points.
        self.backend = "loky"

    def _parallel_kwargs(self):
        # Shared joblib settings forwarded to both calculation functions.
        return {
            'n_jobs': self.n_jobs,
            'batch_size': self.batch_size,
            'backend': self.backend,
        }

    def calculate(self, smiles_list: List[str]) -> pd.DataFrame:
        """
        Calculate QED values for a list of SMILES.

        Args:
            smiles_list (List[str]): List of SMILES strings

        Returns:
            pd.DataFrame: DataFrame with 'smiles' and 'qed' columns
        """
        return parallel_qed_calculation(smiles_list, **self._parallel_kwargs())

    def calculate_series(self, smiles_series: Union[List[str], pd.Series]) -> pd.Series:
        """
        Calculate QED values for a pandas Series or list of SMILES and return as Series.

        Args:
            smiles_series: Series or list of SMILES strings

        Returns:
            pd.Series: Series of QED values
        """
        return calculate_qed_series(smiles_series, **self._parallel_kwargs())
|
||||
|
||||
"""
|
||||
usage
|
||||
|
||||
from utils.qed_calculator import parallel_qed_calculation, QEDCalculator
|
||||
|
||||
# 方式1:直接使用函数
|
||||
smiles_list = ['CCO', 'CCN', 'CCC']
|
||||
qed_df = parallel_qed_calculation(smiles_list, n_jobs=-1)
|
||||
|
||||
# 方式2:使用类
|
||||
calculator = QEDCalculator(n_jobs=-1)
|
||||
qed_df = calculator.calculate(smiles_list)
|
||||
|
||||
# 方式3:处理pandas Series
|
||||
import pandas as pd
|
||||
smiles_series = pd.Series(['CCO', 'CCN', 'CCC'], name='smiles')
|
||||
qed_series = calculator.calculate_series(smiles_series)
|
||||
"""
|
||||
Reference in New Issue
Block a user