Files
vina_docking_batch/scripts/extract_top_molecules.py

64 lines
2.5 KiB
Python

import pandas as pd
import os
import ast
import argparse
def parse_vina_scores(vina_scores_str):
    """Parse a stringified list of Vina scores and return its first element.

    The ``vina_scores`` cell is expected to hold a Python-literal list such as
    ``"[-7.5, -6.9]"``; its first entry is used as the molecule's
    ``vina_score``.

    Args:
        vina_scores_str: Raw cell value, usually a ``str``. pandas may pass
            ``float('nan')`` for empty cells.

    Returns:
        The first element of the parsed list, or ``None`` when the value is
        not a non-empty list literal or cannot be parsed at all.
    """
    try:
        scores = ast.literal_eval(vina_scores_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Malformed text, NaN/None cells (ValueError) or pathological input.
        # A narrow catch (instead of a bare ``except:``) lets
        # KeyboardInterrupt / SystemExit propagate.
        return None
    if isinstance(scores, list) and scores:
        return scores[0]  # first value is taken as the vina_score
    return None
def _top_n(df, column, n, ascending):
    """Return the ``n`` best rows of ``df`` ranked by ``column``.

    Rows whose ``column`` is missing (NaN/None) are excluded so unscored
    molecules never end up in a "top" list. A stable sort (mergesort) keeps
    the input order among tied scores, so the output is reproducible.
    """
    ranked = df.dropna(subset=[column])
    return ranked.sort_values(column, ascending=ascending, kind="mergesort").head(n)


def extract_top_molecules(file_path, output_dir, dataset_name, *,
                          top_n=1000, vina_lower_is_better=True):
    """Write the top-``top_n`` molecules by karma_score_aligned and by vina_score.

    Reads ``file_path`` as CSV and derives a ``vina_score`` column from the
    stringified ``vina_scores`` list (its first element). It then writes two
    CSVs into ``output_dir``:
    ``{dataset_name}_karma_score_aligned_top{top_n}.csv`` and
    ``{dataset_name}_vina_score_top{top_n}.csv``.

    Args:
        file_path: Input CSV; must contain ``karma_score_aligned`` and
            ``vina_scores`` columns.
        output_dir: Directory to write the results to (must already exist).
        dataset_name: Prefix for the output file names and log lines.
        top_n: Number of rows kept per ranking (default 1000).
        vina_lower_is_better: If True (default), rank vina_score ascending,
            i.e. most negative first. AutoDock Vina reports affinity in
            kcal/mol, where more negative means stronger predicted binding.
            Pass False if the input scores were already negated.

    Returns:
        ``(df_karma_top, df_vina_top)``: the two DataFrames written to disk.
        Rows with a missing score are excluded from the respective ranking.
    """
    df = pd.read_csv(file_path)
    df['vina_score'] = df['vina_scores'].apply(parse_vina_scores)

    # karma_score_aligned: higher is ranked first.
    df_karma_top = _top_n(df, 'karma_score_aligned', top_n, ascending=False)
    # The previous descending sort picked the *weakest* Vina affinities.
    df_vina_top = _top_n(df, 'vina_score', top_n, ascending=vina_lower_is_better)

    karma_output_file = os.path.join(output_dir, f"{dataset_name}_karma_score_aligned_top{top_n}.csv")
    vina_output_file = os.path.join(output_dir, f"{dataset_name}_vina_score_top{top_n}.csv")
    df_karma_top.to_csv(karma_output_file, index=False)
    df_vina_top.to_csv(vina_output_file, index=False)

    print(f"{dataset_name} - karma_score_aligned前{top_n}分子保存到: {karma_output_file}")
    print(f"{dataset_name} - vina_score前{top_n}分子保存到: {vina_output_file}")
    print(f"{dataset_name} - karma_score_aligned前{top_n}分子数量: {len(df_karma_top)}")
    print(f"{dataset_name} - vina_score前{top_n}分子数量: {len(df_vina_top)}")
    return df_karma_top, df_vina_top
def main(argv=None):
    """Command-line entry point: extract top molecules from each input CSV.

    Each ``--input`` file is paired with the ``--dataset-names`` entry at the
    same position. Results are written into ``--output``, which is created
    if missing.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for
            testing). ``None`` keeps the normal CLI behavior.

    Raises:
        SystemExit: On invalid arguments, including when the number of input
            files differs from the number of dataset names.
    """
    parser = argparse.ArgumentParser(description='从CSV文件中提取karma_score_aligned和vina_score前1000的分子')
    parser.add_argument('--input', nargs='+', required=True,
                        help='输入CSV文件路径列表')
    parser.add_argument('--dataset-names', nargs='+', required=True,
                        help='数据集名称列表,与输入文件一一对应')
    parser.add_argument('--output', required=True,
                        help='输出目录')
    args = parser.parse_args(argv)

    # zip() would silently drop the unmatched tail; fail loudly instead,
    # before touching the filesystem.
    if len(args.input) != len(args.dataset_names):
        parser.error(
            f'--input ({len(args.input)} files) and --dataset-names '
            f'({len(args.dataset_names)} names) must have the same length')

    # Make sure the output directory exists.
    os.makedirs(args.output, exist_ok=True)
    for file_path, dataset_name in zip(args.input, args.dataset_names):
        print(f"Processing {dataset_name}...")
        extract_top_molecules(file_path, args.output, dataset_name)


if __name__ == "__main__":
    main()