代码移动位置

This commit is contained in:
2025-08-05 20:37:33 +08:00
parent f6c182f38e
commit aef322a86d
4 changed files with 60 additions and 22 deletions

View File

@@ -17,11 +17,6 @@ import logging
import ast
import json
import click
import sys
import os
# Add the parent directory to the path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Setup logging
logging.basicConfig(level=logging.INFO)
@@ -125,8 +120,8 @@ def load_vina_scores_from_csv(df, max_files=1000):
if processed_files >= max_files:
break
# Skip reference molecules (those with align_ and _out_converted.sdf in their filename)
if 'align_' in row['filename'] and '_out_converted.sdf' in row['filename']:
# Skip reference molecules (those with mol2 extension)
if '.mol2' in row['filename']:
continue
try:
@@ -179,8 +174,8 @@ def get_reference_vina_scores(dataset_name, rank=0):
"""
reference_scores = {}
# 使用更新后的路径以适应新目录结构
reference_dir = Path("../result") / "refence" / dataset_name
# 使用原始目录名称 "refence"
reference_dir = Path("result") / "refence" / dataset_name
if not reference_dir.exists():
logger.warning(f"Reference directory {reference_dir} does not exist")
@@ -207,7 +202,7 @@ def get_reference_vina_scores(dataset_name, rank=0):
if '_addH' in filename_stem:
filename_stem = filename_stem.replace('_addH', '')
if 'align_' in filename_stem:
filename_stem = filename_stem.split('_')[-1]
filename_stem = filename_stem.split('_')[-1] # Get the last part (e.g., 9NY or 0GA)
# Use filename_stem as key for reference_scores
reference_scores[filename_stem] = reference_score

View File

@@ -14,11 +14,6 @@ from rdkit.Chem.Descriptors import MolWt
from pathlib import Path
import logging
import json
import sys
import os
# Add the parent directory to the path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Setup logging
logging.basicConfig(level=logging.INFO)
@@ -133,7 +128,7 @@ def calculate_qed_for_poses_all(base_dir, dataset_name):
'filename': sdf_file.name,
'qed': qed_value,
'molecular_weight': mol_weight,
'vina_scores': str(vina_scores) # 添加Vina得分列表
'vina_scores': vina_scores # 添加Vina得分列表
})
except Exception as e:
logger.warning(f"Failed to calculate QED for {sdf_file}: {e}")
@@ -222,7 +217,7 @@ def main():
Main function to calculate QED values for all molecules
"""
# Define base directories
result_dir = Path("../result")
result_dir = Path("result")
# Process both datasets (fgbar and trpe) separately
datasets = ["fgbar", "trpe"]

View File

@@ -5,11 +5,6 @@
Example usage of the analyze_qed_mw_distribution API
"""
import sys
import os
# Add the scripts directory to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from analyze_qed_mw_distribution import main_api
print("Running analysis examples...")

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
"""
过滤 TRPE 分子数据脚本
根据分子量 < 800 和 Vina 得分 < 6.5 进行过滤,并按得分升序排列
然后根据 QED 值排名选择前 100 个分子
"""
import pandas as pd
import ast
import sys
import os
def filter_trpe_data(input_file, output_file, top_n=100):
"""
过滤 TRPE 数据
:param input_file: 输入 CSV 文件路径
:param output_file: 输出 CSV 文件路径
:param top_n: 选取前 N 个分子(按 QED 排名)
"""
# 读取数据
df = pd.read_csv(input_file)
# 解析 vina_scores 字符串为列表
df['vina_scores'] = df['vina_scores'].apply(ast.literal_eval)
# 获取每个分子的最小 Vina 得分(最负值)
df['min_vina_score'] = df['vina_scores'].apply(min)
# 应用过滤条件:分子量 < 800 且最小 Vina 得分 < -6.5
# 注意Vina 得分为负值,所以小于 -6.5 实际上是更好的结合能
filtered_df = df[(df['molecular_weight'] < 800) & (df['min_vina_score'] < -6.5)]
# 按照 QED 值降序排列并选择前 top_n 个分子
top_qed_df = filtered_df.sort_values('qed', ascending=False).head(top_n)
# 再按照最小 Vina 得分升序排列
final_df = top_qed_df.sort_values('min_vina_score', ascending=True)
# 保存结果到新的 CSV 文件
final_df.to_csv(output_file, index=False)
print(f"过滤完成:")
print(f" 原始数据: {len(df)} 条记录")
print(f" 分子量<800且Vina得分<-6.5: {len(filtered_df)} 条记录")
print(f" 按QED排名前{top_n}并按Vina得分排序: {len(final_df)} 条记录")
print(f" 输出文件: {output_file}")
if __name__ == "__main__":
# 设置输入和输出文件路径
input_csv = os.path.join(os.path.dirname(__file__), "qed_values_trpe.csv")
output_csv = os.path.join(os.path.dirname(__file__), "filtered_qed_trpe_top100.csv")
filter_trpe_data(input_csv, output_csv, top_n=100)