Files
vina_docking_batch/scripts/cluster_best_vina.py
2025-08-15 18:14:30 +08:00

161 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
聚类并从每个簇选择 vina_scores 最优分子(支持 Butina 或 Hierarchical[scipy linkage])
- 正确解析 vina_scores(字符串列表 -> list[float]),默认取最小值(更负更好)
- 稳健处理无效 SMILES:只在有效子集上聚类与分组,最后导出原始行
- 可选择方法:--method butina / hierarchical / both
- hierarchical 采用 scipy.cluster.hierarchy.linkage(method='average') + fcluster 按距离阈值切分
示例:
python scripts/cluster_best_vina.py \
--csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
--smiles-col smiles \
--method both \
--cutoff 0.6 \
--radius 3 --n-bits 1024 \
--out result/cluster_best_vina.csv
只跑 Butina(推荐起步):
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method butina \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina_butina.csv
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method hierarchical \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina_hierarchical.csv
# 产出fgbar_cluster_best_vina_hierarchical.csv
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method both \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina.csv
# 产出两份对比文件
"""
import argparse
import ast
import sys, os
from pathlib import Path
import numpy as np
import pandas as pd
# 让项目根目录的 utils.chem_cluster 可用
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
from utils.chem_cluster import TanimotoClusterer, FPConfig
def parse_vina_scores(v):
    """Parse one ``vina_scores`` cell into a non-empty list of floats.

    Accepted inputs:
      * a string holding a list/tuple literal, e.g. ``'[-6.9, -6.1]'``;
      * a string holding a single number, e.g. ``'-6.9'``;
      * a numeric scalar (pandas yields floats when the column only
        contains bare numbers);
      * an in-memory list, tuple or 1-D ndarray of numbers.

    Returns:
        list[float]: the parsed scores. ``[nan]`` is returned when the value
        is missing, malformed, or an empty sequence, so callers can always
        take ``min`` of the result without it raising.
    """
    try:
        if isinstance(v, str):
            parsed = ast.literal_eval(v)
            if isinstance(parsed, (list, tuple)):
                scores = [float(x) for x in parsed]
            else:
                # Not a sequence literal: let float() judge the raw text.
                scores = [float(v)]
        elif isinstance(v, (list, tuple, np.ndarray)):
            scores = [float(x) for x in v]
        else:
            # ast.literal_eval rejects non-string scalars, so convert directly.
            scores = [float(v)]
    except (ValueError, TypeError, SyntaxError, OverflowError,
            MemoryError, RecursionError):
        return [np.nan]
    # An empty sequence would make a later min() raise; report "no score".
    return scores or [np.nan]
def _best_score(scores):
    """Return the smallest non-NaN score, or NaN when there is none."""
    usable = [s for s in scores if not np.isnan(s)]
    return min(usable) if usable else np.nan


def select_best_each_cluster(df_valid: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
    """Pick the best-scoring row of every cluster.

    Args:
        df_valid: the rows that were clustered, in the same order as
            ``labels``; must contain a ``vina_scores`` column.
        labels: one cluster id per row of ``df_valid``.

    Returns:
        A new frame with exactly one row per cluster, ordered by
        ``cluster_id``, holding the original columns plus ``cluster_id``,
        ``vina_list`` (parsed scores) and ``vina_best`` (their minimum).
        The input frame is not modified.

    Raises:
        KeyError: if ``df_valid`` has no ``vina_scores`` column.
    """
    out = df_valid.copy()
    # np.asarray forces positional assignment even if a Series is passed in.
    out["cluster_id"] = np.asarray(labels)
    out["vina_list"] = out["vina_scores"].apply(parse_vina_scores)
    # Lower (more negative) docking score is better; switch to max if the
    # scoring convention is the opposite.
    out["vina_best"] = out["vina_list"].apply(_best_score)

    # Rank rows by score with NaN last, then keep the first row per cluster.
    # The stable sort keeps the earliest row on ties, and a cluster whose
    # scores are all NaN still gets a representative instead of crashing the
    # idxmin/.loc lookup.
    ranked = out.sort_values("vina_best", kind="mergesort", na_position="last")
    picked = (
        ranked.drop_duplicates(subset="cluster_id", keep="first")
        .sort_values("cluster_id", kind="mergesort")
        .reset_index(drop=True)
    )
    return picked
def run_one_method(df: pd.DataFrame, smiles_col: str, fp_radius: int, fp_n_bits: int,
                   method: str, cutoff: float, out_path: Path) -> pd.DataFrame:
    """Cluster the molecules of ``df`` with one method and save the best row per cluster.

    Args:
        df: input table; must contain ``smiles_col`` and the ``vina_scores``
            column read by ``select_best_each_cluster``.
        smiles_col: name of the SMILES column.
        fp_radius: fingerprint radius forwarded to ``FPConfig``.
        fp_n_bits: fingerprint bit count forwarded to ``FPConfig``.
        method: ``'butina'`` or ``'hierarchical'``.
        cutoff: meaning depends on ``method``:
            - butina: forwarded as ``sim_cutoff`` (similarity threshold);
            - hierarchical: forwarded as the distance threshold ``t``
              (i.e. 1 - similarity) with ``criterion='distance'``.
        out_path: CSV file to write; missing parent directories are created.

    Returns:
        The per-cluster representatives that were written to ``out_path``.

    Raises:
        ValueError: if ``method`` is not one of the two supported names.
    """
    smiles_all = df[smiles_col].astype(str).tolist()
    api = TanimotoClusterer(fp_cfg=FPConfig(radius=fp_radius, n_bits=fp_n_bits))

    if method == "butina":
        res = api.fit_from_smiles(smiles_all, method="butina",
                                  method_kwargs={"sim_cutoff": float(cutoff)})
    elif method == "hierarchical":
        # scipy linkage (average) followed by an fcluster cut at distance t.
        res = api.fit_from_smiles(smiles_all, method="scipy_linkage",
                                  method_kwargs={"method": "average",
                                                 "t": float(cutoff),
                                                 "criterion": "distance"})
    else:
        raise ValueError("method ∈ {'butina','hierarchical'}")

    labels = res["labels"]      # labels exist only for the valid SMILES
    keep_idx = res["keep_idx"]  # row positions of the valid SMILES in df
    print(f"[i] {method} 聚类完成:{len(np.unique(labels))} 个簇,有效分子数 {len(keep_idx)}")

    # Choose representatives on the valid subset only.
    df_valid = df.iloc[keep_idx].copy()
    picked_valid = select_best_each_cluster(df_valid, labels)

    # Create the target directory first so the clustering work is not lost to
    # a missing-folder error at write time.
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    picked_valid.to_csv(out_path, index=False)
    # The opening parenthesis separates the path from the row count.
    print(f"[+] 已保存:{out_path.as_posix()}({len(picked_valid)} 条,每簇 1 条)")
    return picked_valid
def main():
    """Command-line entry point.

    Parses the arguments, loads the input CSV, validates the required
    columns and runs the selected clustering method(s). Each method writes
    its own file, named after ``--out`` with ``_butina`` or ``_hierarchical``
    inserted before the extension.

    Raises:
        ValueError: if the SMILES column or the ``vina_scores`` column is
            missing from the CSV.
    """
    parser = argparse.ArgumentParser(description="聚类并从每簇选择 vina_scores 最优分子(支持 Butina / Hierarchical")
    parser.add_argument("--csv", required=True, help="输入 CSV 文件路径")
    parser.add_argument("--smiles-col", default="smiles", help="SMILES 列名")
    parser.add_argument("--radius", type=int, default=3, help="ECFP 半径2=ECFP4, 3=ECFP6")
    parser.add_argument("--n-bits", type=int, default=1024, help="指纹位数1024/2048")
    parser.add_argument("--method", choices=["butina", "hierarchical", "both"], default="butina",
                        help="选择聚类方法")
    parser.add_argument("--cutoff", type=float, default=0.6,
                        help="阈值Butina: sim_cutoffHierarchical: 距离阈值 t")
    parser.add_argument("--out", default="cluster_best_vina.csv", help="输出文件名(会自动加后缀 _butina/_hierarchical")
    args = parser.parse_args()

    df = pd.read_csv(args.csv)
    if args.smiles_col not in df.columns:
        raise ValueError(f"找不到 SMILES 列: {args.smiles_col}")
    # Fail fast: the score column is only read after clustering, so a missing
    # column would otherwise surface after all the expensive work is done.
    if "vina_scores" not in df.columns:
        raise ValueError("找不到 vina_scores 列")

    out_base = Path(args.out)
    # "both" runs butina first, then hierarchical.
    if args.method == "both":
        selected = ("butina", "hierarchical")
    else:
        selected = (args.method,)
    for name in selected:
        run_one_method(
            df, args.smiles_col, args.radius, args.n_bits,
            method=name, cutoff=args.cutoff,
            out_path=out_base.with_name(f"{out_base.stem}_{name}{out_base.suffix}"),
        )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()