# -*- coding: utf-8 -*-
r"""
聚类并从每个簇选择 vina_scores 最优分子(支持 Butina 或 Hierarchical[scipy linkage])。

- 正确解析 vina_scores(字符串列表 -> list[float]),默认取最小值(更负更好)
- 稳健处理无效 SMILES:只在有效子集上聚类与分组,最后导出原始行
- 可选择方法:--method butina / hierarchical / both
- hierarchical 采用 scipy.cluster.hierarchy.linkage(method='average') + fcluster 按距离阈值切分
- 注意:--cutoff 对两种方法含义不同:Butina 将其作为相似度阈值(sim_cutoff),
  Hierarchical 将其直接作为距离阈值 t(t = 1 - 相似度阈值)。因此 --method both
  时两种方法实际使用的并不是同一个相似度阈值(仅当取 0.5 时二者一致)。
- 输出文件会在 --out 的文件名后自动追加 _butina / _hierarchical 后缀,
  因此 --out 只需给出基础文件名。

示例(两种方法都跑):

    python scripts/cluster_best_vina.py \
        --csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
        --smiles-col smiles \
        --method both \
        --cutoff 0.6 \
        --radius 3 --n-bits 1024 \
        --out result/cluster_best_vina.csv

只跑 Butina(推荐起步):

    python scripts/cluster_best_vina.py \
        --csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
        --smiles-col smiles \
        --radius 3 --n-bits 1024 \
        --method butina \
        --cutoff 0.6 \
        --out scripts/finally_data/fgbar_cluster_best_vina.csv
    # 产出:fgbar_cluster_best_vina_butina.csv

只跑 Hierarchical:

    python scripts/cluster_best_vina.py \
        --csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
        --smiles-col smiles \
        --radius 3 --n-bits 1024 \
        --method hierarchical \
        --cutoff 0.6 \
        --out scripts/finally_data/fgbar_cluster_best_vina.csv
    # 产出:fgbar_cluster_best_vina_hierarchical.csv

两种方法都跑,便于对比:

    python scripts/cluster_best_vina.py \
        --csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
        --smiles-col smiles \
        --radius 3 --n-bits 1024 \
        --method both \
        --cutoff 0.6 \
        --out scripts/finally_data/fgbar_cluster_best_vina.csv
    # 产出两份对比文件:fgbar_cluster_best_vina_butina.csv 与 fgbar_cluster_best_vina_hierarchical.csv
"""
|
||
import argparse
|
||
import ast
|
||
import sys, os
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# 让项目根目录的 utils.chem_cluster 可用
|
||
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
|
||
from utils.chem_cluster import TanimotoClusterer, FPConfig
|
||
|
||
def parse_vina_scores(v):
    """Parse one ``vina_scores`` cell into a list of floats.

    Accepted inputs:
      * a string holding a Python literal, e.g. ``'[-6.9, -6.1]'`` or ``'-7.2'``;
      * an already-parsed list / tuple / numpy array of numbers;
      * a bare number (pandas yields a float, or NaN, for purely numeric cells).

    Args:
        v: the raw cell value.

    Returns:
        A non-empty ``list[float]``. ``[nan]`` is returned when the value cannot
        be interpreted as numbers or is an empty sequence, so callers can always
        reduce the result (e.g. take its minimum) without special-casing.
    """
    if isinstance(v, str):
        try:
            # literal_eval only accepts Python literals (no code execution);
            # strip() tolerates leading whitespace on Python versions < 3.10.
            seq = ast.literal_eval(v.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return [np.nan]
    else:
        # Non-string cells: ast.literal_eval rejects them, which previously
        # turned perfectly valid numeric scores into NaN.
        seq = v

    # bool is an int subclass; treat True/False and None as "no score".
    if seq is None or isinstance(seq, bool):
        return [np.nan]

    try:
        if isinstance(seq, (list, tuple, np.ndarray)):
            values = [float(x) for x in seq]
        else:
            values = [float(seq)]
    except (TypeError, ValueError):
        return [np.nan]

    # An empty list would make a later min() raise; report it as "no score".
    return values if values else [np.nan]
|
||
|
||
def _nan_safe_min(scores):
    """Return the smallest non-NaN value in ``scores``, or NaN if there is none."""
    finite = [s for s in scores if not np.isnan(s)]
    return min(finite) if finite else np.nan


def select_best_each_cluster(df_valid: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
    """Pick the best-docked row of every cluster.

    ``df_valid`` must already be aligned with ``labels`` (one label per row, in
    row order), i.e. restricted to the rows whose SMILES were clustered.

    A copy of ``df_valid`` gets three extra columns:
      * ``cluster_id`` - the cluster label of each row;
      * ``vina_list``  - ``vina_scores`` parsed by ``parse_vina_scores``;
      * ``vina_best``  - the smallest non-NaN score of the row (NaN if none).

    Returns:
        One row per cluster: the row with the lowest ``vina_best``. Ties are
        resolved by original row order, and a cluster with no parseable score
        keeps its first row. The result is sorted by ``cluster_id`` and has a
        fresh RangeIndex.

    Raises:
        ValueError: if ``labels`` and ``df_valid`` differ in length.
    """
    if len(labels) != len(df_valid):
        raise ValueError(
            f"labels 长度 ({len(labels)}) 与 df_valid 行数 ({len(df_valid)}) 不一致"
        )

    out = df_valid.copy()
    out["cluster_id"] = labels
    out["vina_list"] = out["vina_scores"].apply(parse_vina_scores)
    # Lower (more negative) docking scores are better; switch to a max if the
    # convention is reversed. A plain min() returns NaN whenever NaN happens to
    # be the first element, so NaNs are skipped explicitly.
    out["vina_best"] = out["vina_list"].apply(_nan_safe_min)

    # Sort, then keep the first row per cluster. This replaces
    # groupby().idxmin(), which fails on an all-NaN cluster, and .loc[...] on
    # index labels, which can return several rows when df_valid's index has
    # duplicates. NaN scores sort last. pandas' multi-key sort is stable, so
    # ties keep their original order.
    picked = (
        out.sort_values(["cluster_id", "vina_best"], na_position="last")
        .drop_duplicates(subset="cluster_id", keep="first")
        .reset_index(drop=True)
    )
    return picked
|
||
|
||
def run_one_method(df: pd.DataFrame, smiles_col: str, fp_radius: int, fp_n_bits: int,
                   method: str, cutoff: float, out_path: Path) -> pd.DataFrame:
    """Cluster ``df`` with one method, keep the best-scoring row per cluster, and save it.

    Args:
        df: input table; must contain ``smiles_col`` and a ``vina_scores`` column.
        smiles_col: name of the SMILES column.
        fp_radius: fingerprint radius passed to ``FPConfig``.
        fp_n_bits: fingerprint length passed to ``FPConfig``.
        method: ``'butina'`` or ``'hierarchical'``.
        cutoff: for ``'butina'`` it is forwarded as ``sim_cutoff`` (a similarity
            threshold); for ``'hierarchical'`` it is used directly as the
            ``fcluster`` distance threshold ``t`` (= 1 - similarity threshold).
        out_path: destination CSV (``Path`` or ``str``); missing parent
            directories are created.

    Returns:
        The selected rows (one per cluster) that were written to ``out_path``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Resolve the clusterer backend first, so an invalid method fails before any work.
    if method == "butina":
        backend = "butina"
        method_kwargs = {"sim_cutoff": float(cutoff)}
    elif method == "hierarchical":
        # scipy linkage (average) + fcluster cut by distance
        backend = "scipy_linkage"
        method_kwargs = {"method": "average",
                         "t": float(cutoff),
                         "criterion": "distance"}
    else:
        raise ValueError("method ∈ {'butina','hierarchical'}")

    out_path = Path(out_path)

    # NOTE(review): astype(str) maps missing cells to the text "nan"; assumed to be
    # rejected as an invalid SMILES by the clusterer - confirm.
    smiles_all = df[smiles_col].astype(str).tolist()

    api = TanimotoClusterer(fp_cfg=FPConfig(radius=fp_radius, n_bits=fp_n_bits))
    res = api.fit_from_smiles(smiles_all, method=backend, method_kwargs=method_kwargs)

    labels = res["labels"]      # one label per valid SMILES only
    keep_idx = res["keep_idx"]  # positions of the valid SMILES in smiles_all (used with iloc)

    print(f"[i] {method} 聚类完成:{len(np.unique(labels))} 个簇,有效分子数 {len(keep_idx)}")

    # Pick representatives on the valid subset only.
    df_valid = df.iloc[keep_idx].copy()
    picked_valid = select_best_each_cluster(df_valid, labels)

    # Save to disk; the output directory may not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    picked_valid.to_csv(out_path, index=False)
    print(f"[+] 已保存:{out_path.as_posix()}({len(picked_valid)} 条,每簇 1 条)")

    return picked_valid
|
||
|
||
def _with_method_suffix(out_base: Path, tag: str) -> Path:
    """Return ``out_base`` with ``_<tag>`` appended to its stem, without doubling it.

    ``x.csv`` becomes ``x_<tag>.csv``. ``x_<tag>.csv`` is returned unchanged, so
    ``--out x_butina.csv --method butina`` writes ``x_butina.csv`` instead of
    ``x_butina_butina.csv``.
    """
    suffix = f"_{tag}"
    stem = out_base.stem if out_base.stem.endswith(suffix) else out_base.stem + suffix
    return out_base.with_name(stem + out_base.suffix)


def main(argv=None):
    """Command-line entry point.

    Parses the arguments, reads the input CSV, and checks that the required
    columns exist. It then runs the selected clustering method(s); each method
    writes its own CSV named after ``--out`` plus a ``_butina`` /
    ``_hierarchical`` suffix.

    Args:
        argv: optional argument list (useful for testing); ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if the SMILES column or the ``vina_scores`` column is missing.
    """
    parser = argparse.ArgumentParser(description="聚类并从每簇选择 vina_scores 最优分子(支持 Butina / Hierarchical)")
    parser.add_argument("--csv", required=True, help="输入 CSV 文件路径")
    parser.add_argument("--smiles-col", default="smiles", help="SMILES 列名")
    parser.add_argument("--radius", type=int, default=3, help="ECFP 半径(2=ECFP4, 3=ECFP6)")
    parser.add_argument("--n-bits", type=int, default=1024, help="指纹位数(1024/2048)")
    parser.add_argument("--method", choices=["butina", "hierarchical", "both"], default="butina",
                        help="选择聚类方法")
    parser.add_argument("--cutoff", type=float, default=0.6,
                        help="阈值(Butina: sim_cutoff;Hierarchical: 距离阈值 t)")
    parser.add_argument("--out", default="cluster_best_vina.csv", help="输出文件名(会自动加后缀 _butina/_hierarchical)")
    args = parser.parse_args(argv)

    # Tanimoto similarity and distance (1 - similarity) both lie in [0, 1], so a
    # threshold outside that range cannot be meaningful for either method.
    if not 0.0 <= args.cutoff <= 1.0:
        parser.error(f"--cutoff 必须在 [0, 1] 范围内,当前为 {args.cutoff}")
    if args.radius < 0 or args.n_bits <= 0:
        parser.error("--radius 需 >= 0,--n-bits 需 > 0")

    df = pd.read_csv(args.csv)
    if args.smiles_col not in df.columns:
        raise ValueError(f"找不到 SMILES 列: {args.smiles_col}")
    # The per-cluster selection reads this column; fail before any clustering runs.
    if "vina_scores" not in df.columns:
        raise ValueError("找不到 vina_scores 列")

    out_base = Path(args.out)
    methods = ("butina", "hierarchical") if args.method == "both" else (args.method,)
    for method in methods:
        run_one_method(
            df, args.smiles_col, args.radius, args.n_bits,
            method=method, cutoff=args.cutoff,
            out_path=_with_method_suffix(out_base, method),
        )


if __name__ == "__main__":
    main()
|