过滤结果保存
This commit is contained in:
85
scripts/filter_data.py
Normal file
85
scripts/filter_data.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
def validate_vina_scores(df, min_poses=20):
    """Find rows whose ``vina_scores`` entry holds fewer than *min_poses* values.

    Args:
        df: DataFrame with a ``filename`` column and a ``vina_scores`` column.
            A ``vina_scores`` cell may be a list, or a string containing a
            list literal (single quotes are swapped for double quotes and the
            result is decoded as JSON). Any other value, and any string that
            fails to decode, is counted as having no scores.
        min_poses: Smallest number of scores a row needs in order not to be
            reported. Defaults to 20.

    Returns:
        A tuple ``(invalid_files, min_length)``. ``invalid_files`` lists the
        ``filename`` of every row with fewer than *min_poses* scores, in row
        order. ``min_length`` is the smallest non-zero score count seen over
        all rows, or ``0`` when no row has any scores (including an empty
        frame).

    Raises:
        KeyError: If *df* lacks the ``filename`` or ``vina_scores`` column.
    """
    invalid_files = []
    shortest = None  # smallest non-zero score count seen so far

    # Walk the two columns directly; iterrows() would build a Series per row.
    for filename, raw in zip(df['filename'], df['vina_scores']):
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw.replace("'", '"'))
            except json.JSONDecodeError:
                parsed = []
        else:
            parsed = raw

        # Anything that is not a list (NaN, a bare number, a dict, ...) counts
        # as zero scores.
        count = len(parsed) if isinstance(parsed, list) else 0

        if count < min_poses:
            invalid_files.append(filename)

        # Rows without any scores are deliberately left out of the minimum.
        if count > 0 and (shortest is None or count < shortest):
            shortest = count

    return invalid_files, (shortest if shortest is not None else 0)
|
||||
|
||||
def _parse_vina_cell(cell):
    """Turn one raw ``vina_scores`` CSV cell into a Python object.

    String cells are decoded as JSON after single quotes are swapped for
    double quotes. A string that cannot be decoded becomes an empty list, so
    one malformed row does not abort the whole run. Non-string cells are
    returned unchanged.
    """
    if not isinstance(cell, str):
        return cell
    try:
        return json.loads(cell.replace("'", '"'))
    except json.JSONDecodeError:
        return []


def _first_score_below(scores, threshold):
    """Return True when *scores* is a non-empty list whose first item is below *threshold*."""
    return isinstance(scores, list) and len(scores) > 0 and scores[0] < threshold


def main():
    """Filter the TRPE and FGBAR score tables and write the results to disk.

    For each dataset the function:

    1. reads ``qed_values_poses_<dataset>_all.csv``,
    2. parses the ``vina_scores`` column and reports rows with too few scores
       via ``validate_vina_scores``,
    3. keeps rows that pass the dataset's property rule and whose first Vina
       score is below the dataset's cutoff
       (TRPE: ``molecular_weight < 800`` and score ``< -6.5``;
       FGBAR: ``qed > 0.5`` and score ``< -5.2``),
    4. prints row counts and writes the filtered table plus its 100
       highest-``qed`` rows as CSV files in the result directory.
    """
    base_dir = Path("/Users/lingyuzeng/Downloads/211.69.141.180/202508021824/vina")

    # Create the output directory.
    result_dir = base_dir / "result" / "filtered_results"
    result_dir.mkdir(parents=True, exist_ok=True)

    # dataset -> (property column, label used in the printed report,
    #             property test, Vina score cutoff)
    rules = {
        "trpe": ("molecular_weight", "分子量", lambda col: col < 800, -6.5),
        "fgbar": ("qed", "QED", lambda col: col > 0.5, -5.2),
    }

    for dataset, (prop_column, prop_label, prop_test, vina_threshold) in rules.items():
        # Load the data.
        input_path = base_dir / "scripts" / "finally_data" / f"qed_values_poses_{dataset}_all.csv"
        df = pd.read_csv(input_path)

        # Convert the vina_scores column from text to Python objects.
        df['vina_scores'] = df['vina_scores'].apply(_parse_vina_cell)

        # Report rows with too few scores.
        # NOTE(review): the minimum-count line is assumed to belong inside
        # this `if`; confirm against the original indentation.
        invalid_files, min_length = validate_vina_scores(df)
        if invalid_files:
            print(f"\n发现 {len(invalid_files)} 个文件的vina_scores少于20个构象 ({dataset}):")
            for fname in invalid_files[:5]:
                print(f"- {fname}")
            if len(invalid_files) > 5:
                print(f"...共 {len(invalid_files)} 个文件")
            print(f"所有分子的最小构象数: {min_length}")

        # Build each mask once and reuse it for both filtering and reporting.
        prop_mask = prop_test(df[prop_column])
        vina_mask = df['vina_scores'].apply(
            _first_score_below, args=(vina_threshold,)
        ).astype(bool)
        df_filtered = df[prop_mask & vina_mask]

        # Print summary counts; the labels name the rule actually applied.
        print(f"\n{dataset.upper()} 数据统计:")
        print(f"原始数据总数: {len(df)}")
        print(f"仅{prop_label}过滤后数据总数: {int(prop_mask.sum())}")
        print(f"仅Vina得分过滤后数据总数: {int(vina_mask.sum())}")
        print(f"同时满足{prop_label}和Vina得分条件的数据总数: {len(df_filtered)}")

        # Save the filtered table.
        df_filtered.to_csv(result_dir / f"qed_values_{dataset}_combined_filtered.csv", index=False)

        # Save the 100 rows with the highest qed.
        df_top100 = df_filtered.sort_values('qed', ascending=False).head(100)
        df_top100.to_csv(result_dir / f"qed_values_{dataset}_top100.csv", index=False)
        print("前100个分子数据已保存")
|
||||
|
||||
# Run the filtering pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user