add cluster func
This commit is contained in:
160
scripts/cluster_best_vina.py
Normal file
160
scripts/cluster_best_vina.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
聚类并从每个簇选择 vina_scores 最优分子(支持 Butina 或 Hierarchical[scipy linkage])
|
||||
- 正确解析 vina_scores(字符串列表 -> list[float]),默认取最小值(更负更好)
|
||||
- 稳健处理无效 SMILES:只在有效子集上聚类与分组,最后导出原始行
|
||||
- 可选择方法:--method butina / hierarchical / both
|
||||
- hierarchical 采用 scipy.cluster.hierarchy.linkage(method='average') + fcluster 按距离阈值切分
|
||||
|
||||
示例:
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
|
||||
--smiles-col smiles \
|
||||
--method both \
|
||||
--cutoff 0.6 \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--out result/cluster_best_vina.csv
|
||||
|
||||
只跑 Butina(推荐起步)
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method butina \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina_butina.csv
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method hierarchical \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina_hierarchical.csv
|
||||
# 产出:fgbar_cluster_best_vina_hierarchical.csv
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method both \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina.csv
|
||||
# 产出两份对比文件
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import ast
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# 让项目根目录的 utils.chem_cluster 可用
|
||||
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
|
||||
from utils.chem_cluster import TanimotoClusterer, FPConfig
|
||||
|
||||
def parse_vina_scores(v):
|
||||
"""把 '[-6.9, -6.1, ...]' 解析为 list[float];出错返回 [nan]"""
|
||||
try:
|
||||
seq = ast.literal_eval(v)
|
||||
if isinstance(seq, (list, tuple)):
|
||||
return [float(x) for x in seq]
|
||||
return [float(v)]
|
||||
except Exception:
|
||||
return [np.nan]
|
||||
|
||||
def select_best_each_cluster(df_valid: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
|
||||
"""
|
||||
在已经对齐的 df_valid(仅有效 SMILES 行)上:
|
||||
- 计算 df_valid['vina_best'] = min(vina_scores)
|
||||
- 每簇取 vina_best 最小的那一行
|
||||
"""
|
||||
out = df_valid.copy()
|
||||
out["cluster_id"] = labels
|
||||
out["vina_list"] = out["vina_scores"].apply(parse_vina_scores)
|
||||
# 默认对接分数越小越好(更负越好);如相反改成 max
|
||||
out["vina_best"] = out["vina_list"].apply(min)
|
||||
# groupby 每簇取 vina_best 最小的那条
|
||||
picked = (
|
||||
out.loc[out.groupby("cluster_id")["vina_best"].idxmin()]
|
||||
.sort_values(["cluster_id", "vina_best"])
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
return picked
|
||||
|
||||
def run_one_method(df: pd.DataFrame, smiles_col: str, fp_radius: int, fp_n_bits: int,
|
||||
method: str, cutoff: float, out_path: Path) -> pd.DataFrame:
|
||||
"""
|
||||
method: 'butina' 或 'hierarchical'
|
||||
cutoff:
|
||||
- butina -> sim_cutoff (相似度阈值)
|
||||
- hierarchical -> 距离阈值 t(= 1 - 相似度阈值),我们直接把 --cutoff 当做 t 使用
|
||||
"""
|
||||
smiles_all = df[smiles_col].astype(str).tolist()
|
||||
|
||||
api = TanimotoClusterer(fp_cfg=FPConfig(radius=fp_radius, n_bits=fp_n_bits))
|
||||
|
||||
if method == "butina":
|
||||
res = api.fit_from_smiles(smiles_all, method="butina",
|
||||
method_kwargs={"sim_cutoff": float(cutoff)})
|
||||
elif method == "hierarchical":
|
||||
# 用 scipy linkage + fcluster 切分(average)
|
||||
res = api.fit_from_smiles(smiles_all, method="scipy_linkage",
|
||||
method_kwargs={"method": "average",
|
||||
"t": float(cutoff),
|
||||
"criterion": "distance"})
|
||||
else:
|
||||
raise ValueError("method ∈ {'butina','hierarchical'}")
|
||||
|
||||
labels = res["labels"] # 仅对“有效 SMILES”产生标签
|
||||
keep_idx = res["keep_idx"] # 有效 SMILES 在原 df 中的行号
|
||||
|
||||
print(f"[i] {method} 聚类完成:{len(np.unique(labels))} 个簇,有效分子数 {len(keep_idx)}")
|
||||
|
||||
# 仅在有效子集上挑代表
|
||||
df_valid = df.iloc[keep_idx].copy()
|
||||
picked_valid = select_best_each_cluster(df_valid, labels)
|
||||
|
||||
# 存盘
|
||||
picked_valid.to_csv(out_path, index=False)
|
||||
print(f"[+] 已保存:{out_path.as_posix()}({len(picked_valid)} 条,每簇 1 条)")
|
||||
|
||||
return picked_valid
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="聚类并从每簇选择 vina_scores 最优分子(支持 Butina / Hierarchical)")
|
||||
parser.add_argument("--csv", required=True, help="输入 CSV 文件路径")
|
||||
parser.add_argument("--smiles-col", default="smiles", help="SMILES 列名")
|
||||
parser.add_argument("--radius", type=int, default=3, help="ECFP 半径(2=ECFP4, 3=ECFP6)")
|
||||
parser.add_argument("--n-bits", type=int, default=1024, help="指纹位数(1024/2048)")
|
||||
parser.add_argument("--method", choices=["butina", "hierarchical", "both"], default="butina",
|
||||
help="选择聚类方法")
|
||||
parser.add_argument("--cutoff", type=float, default=0.6,
|
||||
help="阈值(Butina: sim_cutoff;Hierarchical: 距离阈值 t)")
|
||||
parser.add_argument("--out", default="cluster_best_vina.csv", help="输出文件名(会自动加后缀 _butina/_hierarchical)")
|
||||
args = parser.parse_args()
|
||||
|
||||
df = pd.read_csv(args.csv)
|
||||
if args.smiles_col not in df.columns:
|
||||
raise ValueError(f"找不到 SMILES 列: {args.smiles_col}")
|
||||
|
||||
out_base = Path(args.out)
|
||||
|
||||
if args.method in ("butina", "both"):
|
||||
run_one_method(
|
||||
df, args.smiles_col, args.radius, args.n_bits,
|
||||
method="butina", cutoff=args.cutoff,
|
||||
out_path=out_base.with_name(f"{out_base.stem}_butina{out_base.suffix}")
|
||||
)
|
||||
|
||||
if args.method in ("hierarchical", "both"):
|
||||
run_one_method(
|
||||
df, args.smiles_col, args.radius, args.n_bits,
|
||||
method="hierarchical", cutoff=args.cutoff,
|
||||
out_path=out_base.with_name(f"{out_base.stem}_hierarchical{out_base.suffix}")
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user