add cluster func

This commit is contained in:
2025-08-15 18:14:30 +08:00
parent e58f90cd1e
commit b85b02b5c3
7 changed files with 25782 additions and 1 deletions

View File

@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
"""
聚类并从每个簇选择 vina_scores 最优分子(支持 Butina 或 Hierarchical[scipy linkage]
- 正确解析 vina_scores字符串列表 -> list[float]),默认取最小值(更负更好)
- 稳健处理无效 SMILES只在有效子集上聚类与分组最后导出原始行
- 可选择方法:--method butina / hierarchical / both
- hierarchical 采用 scipy.cluster.hierarchy.linkage(method='average') + fcluster 按距离阈值切分
示例:
python scripts/cluster_best_vina.py \
--csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
--smiles-col smiles \
--method both \
--cutoff 0.6 \
--radius 3 --n-bits 1024 \
--out result/cluster_best_vina.csv
只跑 Butina推荐起步
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method butina \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina_butina.csv
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method hierarchical \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina_hierarchical.csv
# 产出fgbar_cluster_best_vina_hierarchical.csv
python scripts/cluster_best_vina.py \
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
--smiles-col smiles \
--radius 3 --n-bits 1024 \
--method both \
--cutoff 0.6 \
--out scripts/finally_data/fgbar_cluster_best_vina.csv
# 产出两份对比文件
"""
import argparse
import ast
import sys, os
from pathlib import Path
import numpy as np
import pandas as pd
# 让项目根目录的 utils.chem_cluster 可用
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
from utils.chem_cluster import TanimotoClusterer, FPConfig
def parse_vina_scores(v):
"""'[-6.9, -6.1, ...]' 解析为 list[float];出错返回 [nan]"""
try:
seq = ast.literal_eval(v)
if isinstance(seq, (list, tuple)):
return [float(x) for x in seq]
return [float(v)]
except Exception:
return [np.nan]
def select_best_each_cluster(df_valid: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
"""
在已经对齐的 df_valid仅有效 SMILES 行)上:
- 计算 df_valid['vina_best'] = min(vina_scores)
- 每簇取 vina_best 最小的那一行
"""
out = df_valid.copy()
out["cluster_id"] = labels
out["vina_list"] = out["vina_scores"].apply(parse_vina_scores)
# 默认对接分数越小越好(更负越好);如相反改成 max
out["vina_best"] = out["vina_list"].apply(min)
# groupby 每簇取 vina_best 最小的那条
picked = (
out.loc[out.groupby("cluster_id")["vina_best"].idxmin()]
.sort_values(["cluster_id", "vina_best"])
.reset_index(drop=True)
)
return picked
def run_one_method(df: pd.DataFrame, smiles_col: str, fp_radius: int, fp_n_bits: int,
method: str, cutoff: float, out_path: Path) -> pd.DataFrame:
"""
method: 'butina''hierarchical'
cutoff:
- butina -> sim_cutoff (相似度阈值)
- hierarchical -> 距离阈值 t= 1 - 相似度阈值),我们直接把 --cutoff 当做 t 使用
"""
smiles_all = df[smiles_col].astype(str).tolist()
api = TanimotoClusterer(fp_cfg=FPConfig(radius=fp_radius, n_bits=fp_n_bits))
if method == "butina":
res = api.fit_from_smiles(smiles_all, method="butina",
method_kwargs={"sim_cutoff": float(cutoff)})
elif method == "hierarchical":
# 用 scipy linkage + fcluster 切分average
res = api.fit_from_smiles(smiles_all, method="scipy_linkage",
method_kwargs={"method": "average",
"t": float(cutoff),
"criterion": "distance"})
else:
raise ValueError("method ∈ {'butina','hierarchical'}")
labels = res["labels"] # 仅对“有效 SMILES”产生标签
keep_idx = res["keep_idx"] # 有效 SMILES 在原 df 中的行号
print(f"[i] {method} 聚类完成:{len(np.unique(labels))} 个簇,有效分子数 {len(keep_idx)}")
# 仅在有效子集上挑代表
df_valid = df.iloc[keep_idx].copy()
picked_valid = select_best_each_cluster(df_valid, labels)
# 存盘
picked_valid.to_csv(out_path, index=False)
print(f"[+] 已保存:{out_path.as_posix()}{len(picked_valid)} 条,每簇 1 条)")
return picked_valid
def main():
parser = argparse.ArgumentParser(description="聚类并从每簇选择 vina_scores 最优分子(支持 Butina / Hierarchical")
parser.add_argument("--csv", required=True, help="输入 CSV 文件路径")
parser.add_argument("--smiles-col", default="smiles", help="SMILES 列名")
parser.add_argument("--radius", type=int, default=3, help="ECFP 半径2=ECFP4, 3=ECFP6")
parser.add_argument("--n-bits", type=int, default=1024, help="指纹位数1024/2048")
parser.add_argument("--method", choices=["butina", "hierarchical", "both"], default="butina",
help="选择聚类方法")
parser.add_argument("--cutoff", type=float, default=0.6,
help="阈值Butina: sim_cutoffHierarchical: 距离阈值 t")
parser.add_argument("--out", default="cluster_best_vina.csv", help="输出文件名(会自动加后缀 _butina/_hierarchical")
args = parser.parse_args()
df = pd.read_csv(args.csv)
if args.smiles_col not in df.columns:
raise ValueError(f"找不到 SMILES 列: {args.smiles_col}")
out_base = Path(args.out)
if args.method in ("butina", "both"):
run_one_method(
df, args.smiles_col, args.radius, args.n_bits,
method="butina", cutoff=args.cutoff,
out_path=out_base.with_name(f"{out_base.stem}_butina{out_base.suffix}")
)
if args.method in ("hierarchical", "both"):
run_one_method(
df, args.smiles_col, args.radius, args.n_bits,
method="hierarchical", cutoff=args.cutoff,
out_path=out_base.with_name(f"{out_base.stem}_hierarchical{out_base.suffix}")
)
if __name__ == "__main__":
main()