过滤结果保存

This commit is contained in:
2025-08-12 18:45:33 +08:00
parent df72a0b9c0
commit e58f90cd1e
6 changed files with 25790 additions and 151 deletions

85
scripts/filter_data.py Normal file
View File

@@ -0,0 +1,85 @@
import pandas as pd
import os
import json
from pathlib import Path
def validate_vina_scores(df):
invalid_files = []
min_length = float('inf')
for idx, row in df.iterrows():
scores = row['vina_scores']
if isinstance(scores, str):
try:
scores = json.loads(scores.replace("'", '"'))
except json.JSONDecodeError:
scores = []
elif isinstance(scores, list):
pass
else:
scores = []
length = len(scores) if isinstance(scores, list) else 0
if length < 20:
invalid_files.append(row['filename'])
if length < min_length and length > 0:
min_length = length
return invalid_files, min_length
def main():
# 创建结果目录
result_dir = Path("/Users/lingyuzeng/Downloads/211.69.141.180/202508021824/vina/result/filtered_results")
result_dir.mkdir(parents=True, exist_ok=True)
for dataset in ["trpe", "fgbar"]:
# 读取数据
input_path = f"/Users/lingyuzeng/Downloads/211.69.141.180/202508021824/vina/scripts/finally_data/qed_values_poses_{dataset}_all.csv"
df = pd.read_csv(input_path)
# 转换vina_scores列
df['vina_scores'] = df['vina_scores'].apply(
lambda x: json.loads(x.replace("'", '"')) if isinstance(x, str) else x
)
# 验证vina_scores
invalid_files, min_length = validate_vina_scores(df)
if invalid_files:
print(f"\n发现 {len(invalid_files)} 个文件的vina_scores少于20个构象 ({dataset}):")
for fname in invalid_files[:5]:
print(f"- {fname}")
if len(invalid_files) > 5:
print(f"...共 {len(invalid_files)} 个文件")
print(f"所有分子的最小构象数: {min_length}")
# 应用过滤条件
if dataset == "trpe":
# TRPE过滤条件: MW < 800, Vina < -6.5
df_filtered = df[
(df['molecular_weight'] < 800) &
(df['vina_scores'].apply(lambda x: x[0] < -6.5 if isinstance(x, list) and len(x) > 0 else False))
]
vina_threshold = -6.5
else:
# FGBAR过滤条件: QED > 0.5
df_filtered = df[
(df['qed'] > 0.5) &
(df['vina_scores'].apply(lambda x: x[0] < -5.2 if isinstance(x, list) and len(x) > 0 else False))
]
vina_threshold = -5.2
# 生成统计信息
print(f"\n{dataset.upper()} 数据统计:")
print(f"原始数据总数: {len(df)}")
print(f"仅QED过滤后数据总数: {len(df[df['qed'] > 0.5])}")
print(f"仅Vina得分过滤后数据总数: {len(df[df['vina_scores'].apply(lambda x: x[0] < vina_threshold if isinstance(x, list) and len(x) > 0 else False)])}")
print(f"同时满足QED和Vina得分条件的数据总数: {len(df_filtered)}")
# 保存过滤结果
df_filtered.to_csv(result_dir / f"qed_values_{dataset}_combined_filtered.csv", index=False)
# 生成前100个分子
df_top100 = df_filtered.sort_values('qed', ascending=False).head(100)
df_top100.to_csv(result_dir / f"qed_values_{dataset}_top100.csv", index=False)
print(f"前100个分子数据已保存")
if __name__ == "__main__":
main()