add cluster func
This commit is contained in:
160
scripts/cluster_best_vina.py
Normal file
160
scripts/cluster_best_vina.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
聚类并从每个簇选择 vina_scores 最优分子(支持 Butina 或 Hierarchical[scipy linkage])
|
||||
- 正确解析 vina_scores(字符串列表 -> list[float]),默认取最小值(更负更好)
|
||||
- 稳健处理无效 SMILES:只在有效子集上聚类与分组,最后导出原始行
|
||||
- 可选择方法:--method butina / hierarchical / both
|
||||
- hierarchical 采用 scipy.cluster.hierarchy.linkage(method='average') + fcluster 按距离阈值切分
|
||||
|
||||
示例:
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
|
||||
--smiles-col smiles \
|
||||
--method both \
|
||||
--cutoff 0.6 \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--out result/cluster_best_vina.csv
|
||||
|
||||
只跑 Butina(推荐起步)
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method butina \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina_butina.csv
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method hierarchical \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina_hierarchical.csv
|
||||
# 产出:fgbar_cluster_best_vina_hierarchical.csv
|
||||
|
||||
python scripts/cluster_best_vina.py \
|
||||
--csv scripts/finally_data/qed_values_poses_fgbar_all.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 --n-bits 1024 \
|
||||
--method both \
|
||||
--cutoff 0.6 \
|
||||
--out scripts/finally_data/fgbar_cluster_best_vina.csv
|
||||
# 产出两份对比文件
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import ast
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# 让项目根目录的 utils.chem_cluster 可用
|
||||
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
|
||||
from utils.chem_cluster import TanimotoClusterer, FPConfig
|
||||
|
||||
def parse_vina_scores(v):
    """Parse a value like "[-6.9, -6.1, ...]" into a list[float].

    Accepts either a stringified sequence or a single scalar value.
    Any parsing failure returns [nan] so downstream min()/aggregation
    stays well-defined instead of raising.
    """
    try:
        parsed = ast.literal_eval(v)
    except Exception:
        # Not a Python literal at all (bad string, None, odd type, ...).
        return [np.nan]
    if isinstance(parsed, (list, tuple)):
        return [float(item) for item in parsed]
    # Scalar case: coerce the original value, falling back to nan.
    try:
        return [float(v)]
    except Exception:
        return [np.nan]
|
||||
|
||||
def select_best_each_cluster(df_valid: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
    """Pick one representative row per cluster: the best docking score.

    Operates on `df_valid` (rows with valid SMILES only), already aligned
    index-for-index with `labels`. Adds columns:
      - cluster_id: cluster label per row
      - vina_list:  parsed list[float] of vina scores
      - vina_best:  min of vina_list (more negative = better docking)
    Returns one row per cluster, sorted by (cluster_id, vina_best).
    """
    def _as_float_list(raw):
        # Same contract as parse_vina_scores: "[...]" -> list[float], else [nan].
        try:
            parsed = ast.literal_eval(raw)
            if isinstance(parsed, (list, tuple)):
                return [float(x) for x in parsed]
            return [float(raw)]
        except Exception:
            return [np.nan]

    annotated = df_valid.copy()
    annotated["cluster_id"] = labels
    annotated["vina_list"] = annotated["vina_scores"].apply(_as_float_list)
    # Docking convention: lower (more negative) is better; use max if reversed.
    annotated["vina_best"] = annotated["vina_list"].apply(min)

    best_row_per_cluster = annotated.groupby("cluster_id")["vina_best"].idxmin()
    return (
        annotated.loc[best_row_per_cluster]
        .sort_values(["cluster_id", "vina_best"])
        .reset_index(drop=True)
    )
|
||||
|
||||
def run_one_method(df: pd.DataFrame, smiles_col: str, fp_radius: int, fp_n_bits: int,
                   method: str, cutoff: float, out_path: Path) -> pd.DataFrame:
    """
    Cluster the molecules in `df` and save one best-scoring row per cluster.

    Args:
        df: full input table; must contain `smiles_col` and a 'vina_scores' column.
        smiles_col: name of the SMILES column.
        fp_radius, fp_n_bits: Morgan fingerprint radius / bit count.
        method: 'butina' or 'hierarchical'.
        cutoff:
          - butina -> sim_cutoff (similarity threshold)
          - hierarchical -> distance threshold t (= 1 - similarity threshold);
            --cutoff is used directly as t.
        out_path: CSV path for the per-cluster representatives.

    Returns:
        DataFrame with one row per cluster (also written to `out_path`).

    Raises:
        ValueError: if `method` is not one of the two supported names.
    """
    smiles_all = df[smiles_col].astype(str).tolist()

    # Project-local clusterer; fingerprints configured via FPConfig.
    api = TanimotoClusterer(fp_cfg=FPConfig(radius=fp_radius, n_bits=fp_n_bits))

    if method == "butina":
        res = api.fit_from_smiles(smiles_all, method="butina",
                                  method_kwargs={"sim_cutoff": float(cutoff)})
    elif method == "hierarchical":
        # scipy linkage + fcluster split (average linkage, distance criterion).
        res = api.fit_from_smiles(smiles_all, method="scipy_linkage",
                                  method_kwargs={"method": "average",
                                                 "t": float(cutoff),
                                                 "criterion": "distance"})
    else:
        raise ValueError("method ∈ {'butina','hierarchical'}")

    labels = res["labels"]      # labels exist only for valid SMILES
    keep_idx = res["keep_idx"]  # row indices of valid SMILES in the original df

    print(f"[i] {method} 聚类完成:{len(np.unique(labels))} 个簇,有效分子数 {len(keep_idx)}")

    # Pick representatives on the valid subset only (keeps labels aligned).
    df_valid = df.iloc[keep_idx].copy()
    picked_valid = select_best_each_cluster(df_valid, labels)

    # Persist the per-cluster winners.
    picked_valid.to_csv(out_path, index=False)
    print(f"[+] 已保存:{out_path.as_posix()}({len(picked_valid)} 条,每簇 1 条)")

    return picked_valid
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, load the CSV, run the requested method(s)."""
    ap = argparse.ArgumentParser(description="聚类并从每簇选择 vina_scores 最优分子(支持 Butina / Hierarchical)")
    ap.add_argument("--csv", required=True, help="输入 CSV 文件路径")
    ap.add_argument("--smiles-col", default="smiles", help="SMILES 列名")
    ap.add_argument("--radius", type=int, default=3, help="ECFP 半径(2=ECFP4, 3=ECFP6)")
    ap.add_argument("--n-bits", type=int, default=1024, help="指纹位数(1024/2048)")
    ap.add_argument("--method", choices=["butina", "hierarchical", "both"], default="butina",
                    help="选择聚类方法")
    ap.add_argument("--cutoff", type=float, default=0.6,
                    help="阈值(Butina: sim_cutoff;Hierarchical: 距离阈值 t)")
    ap.add_argument("--out", default="cluster_best_vina.csv", help="输出文件名(会自动加后缀 _butina/_hierarchical)")
    args = ap.parse_args()

    frame = pd.read_csv(args.csv)
    if args.smiles_col not in frame.columns:
        raise ValueError(f"找不到 SMILES 列: {args.smiles_col}")

    base = Path(args.out)

    # 'both' runs butina first, then hierarchical; each gets its own suffix.
    selected = ["butina", "hierarchical"] if args.method == "both" else [args.method]
    for m in selected:
        run_one_method(
            frame, args.smiles_col, args.radius, args.n_bits,
            method=m, cutoff=args.cutoff,
            out_path=base.with_name(f"{base.stem}_{m}{base.suffix}"),
        )
|
||||
|
||||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
103
scripts/cluster_granularity_scan.py
Normal file
103
scripts/cluster_granularity_scan.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
聚类粒度扫描脚本
|
||||
运行示例:
|
||||
python scripts/cluster_granularity_scan.py \
|
||||
--csv result/filtered_results/qed_values_trpe_combined_filtered.csv \
|
||||
--smiles-col smiles \
|
||||
--radius 3 \
|
||||
--n-bits 1024
|
||||
"""
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from rdkit import Chem, DataStructs
|
||||
from rdkit.Chem import AllChem
|
||||
from sklearn.cluster import AgglomerativeClustering, KMeans, DBSCAN
|
||||
|
||||
# 把项目根目录加入 sys.path
|
||||
sys.path.append(Path(os.path.abspath(__file__)).parent.parent.as_posix())
|
||||
from utils.chem_cluster import TanimotoClusterer, FPConfig
|
||||
|
||||
def tanimoto_matrix(smiles, radius=3, n_bits=1024):
    """Build the full (n, n) Tanimoto similarity matrix for a list of SMILES.

    Args:
        smiles: sequence of SMILES strings; every entry must be parseable.
        radius: Morgan fingerprint radius (3 -> ECFP6).
        n_bits: fingerprint bit length.

    Returns:
        np.ndarray of shape (n, n) with pairwise Tanimoto similarities
        (diagonal entries are 1.0).

    Raises:
        ValueError: if a SMILES cannot be parsed. Previously an invalid
            SMILES made Chem.MolFromSmiles return None and the code crashed
            inside RDKit with an opaque ArgumentError; now the offending
            string is reported explicitly.
    """
    fps = []
    for s in smiles:
        mol = Chem.MolFromSmiles(s)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {s!r}")
        fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits))
    n = len(fps)
    sim_mat = np.zeros((n, n))
    for i in range(n):
        # BulkTanimotoSimilarity fills one row at C speed.
        sim_mat[i, :] = DataStructs.BulkTanimotoSimilarity(fps[i], fps)
    return sim_mat
|
||||
|
||||
def avg_intra_cluster_similarity(labels, sim_mat):
    """Mean over clusters of the average pairwise intra-cluster similarity.

    `labels` is a per-item cluster label array aligned with the rows of the
    square similarity matrix `sim_mat`. Singleton clusters are skipped;
    returns 0 when no cluster has two or more members.
    """
    per_cluster = []
    for cluster in np.unique(labels):
        members = np.flatnonzero(labels == cluster)
        if members.size < 2:
            continue  # singleton: no pairs to average
        block = sim_mat[np.ix_(members, members)]
        # Strict lower triangle = each unordered pair exactly once.
        lower = np.tril_indices_from(block, k=-1)
        per_cluster.append(block[lower].mean())
    if not per_cluster:
        return 0
    return np.mean(per_cluster)
|
||||
|
||||
def scan(args):
    """Run a grid of clustering settings and print a comparison table.

    For each method/parameter combination, reports the number of clusters,
    the mean cluster size, and the mean intra-cluster Tanimoto similarity.

    Args:
        args: argparse namespace with csv, smiles_col, radius, n_bits.
    """
    df = pd.read_csv(args.csv)
    smiles = df[args.smiles_col].astype(str).tolist()

    # Precompute the similarity matrix once; every method reuses it.
    print("计算 Tanimoto 相似度矩阵...")
    sim_mat = tanimoto_matrix(smiles, radius=args.radius, n_bits=args.n_bits)
    dist_mat = 1 - sim_mat  # clustering operates on distances

    results = []
    n = len(smiles)

    # 1. Butina clustering.
    # BUGFIX: Butina.ClusterData with isDistData=True expects the *condensed*
    # lower-triangle distance sequence (d(1,0), d(2,0), d(2,1), ...), not a
    # square matrix. np.tril_indices yields exactly that row-major order.
    from rdkit.ML.Cluster import Butina
    condensed = dist_mat[np.tril_indices(n, k=-1)].tolist()
    for cutoff in np.linspace(0.4, 0.8, 5):
        cluster_res = list(Butina.ClusterData(condensed, n, cutoff, isDistData=True))
        labels = np.zeros(n, dtype=int)
        for cid, members in enumerate(cluster_res):
            labels[list(members)] = cid
        avg_sim = avg_intra_cluster_similarity(labels, sim_mat)
        results.append(("Butina", {"cutoff": round(cutoff, 2)}, len(set(labels)),
                        np.mean(np.bincount(labels)), avg_sim))

    # 2. Hierarchical clustering (average linkage, distance threshold).
    for thresh in [0.3, 0.4, 0.5]:
        model = AgglomerativeClustering(n_clusters=None, metric='precomputed',
                                        linkage='average', distance_threshold=thresh)
        labels = model.fit_predict(dist_mat)
        avg_sim = avg_intra_cluster_similarity(labels, sim_mat)
        results.append(("Hierarchical", {"threshold": thresh}, len(set(labels)),
                        np.mean(np.bincount(labels)), avg_sim))

    # 3. DBSCAN (label -1 marks noise; noise is excluded from the stats).
    for eps in [0.2, 0.3, 0.4]:
        model = DBSCAN(eps=eps, min_samples=2, metric="precomputed")
        labels = model.fit_predict(dist_mat)
        core = labels != -1
        n_clusters = len(set(labels[core]))
        # BUGFIX: the filtered labels must be paired with the matching
        # sub-matrix of sim_mat; passing the full matrix misaligns indices.
        core_idx = np.flatnonzero(core)
        avg_sim = avg_intra_cluster_similarity(
            labels[core], sim_mat[np.ix_(core_idx, core_idx)])
        results.append(("DBSCAN", {"eps": eps}, n_clusters,
                        np.mean(np.bincount(labels[core])) if n_clusters > 0 else 0,
                        avg_sim))

    # 4. KMeans on a PCA embedding of the similarity matrix.
    from sklearn.decomposition import PCA
    coords = PCA(n_components=10).fit_transform(sim_mat)
    for k in [10, 20, 50]:
        model = KMeans(n_clusters=k, random_state=42)
        labels = model.fit_predict(coords)
        avg_sim = avg_intra_cluster_similarity(labels, sim_mat)
        results.append(("KMeans", {"k": k}, len(set(labels)),
                        np.mean(np.bincount(labels)), avg_sim))

    # Comparison table.
    print(f"{'Method':<15} {'Params':<25} {'#Clusters':<10} {'AvgSize':<10} {'AvgIntraSim':<10}")
    for r in results:
        print(f"{r[0]:<15} {str(r[1]):<25} {r[2]:<10} {r[3]:<10.2f} {r[4]:<10.3f}")
|
||||
|
||||
if __name__ == "__main__":
    # Thin CLI wrapper: all the work happens inside scan().
    cli = argparse.ArgumentParser(description="聚类粒度扫描")
    cli.add_argument("--csv", type=str, required=True, help="输入 CSV 文件路径")
    cli.add_argument("--smiles-col", type=str, required=True, help="SMILES 列名")
    cli.add_argument("--radius", type=int, default=3, help="Morgan 指纹半径")
    cli.add_argument("--n-bits", type=int, default=1024, help="指纹位数")
    scan(cli.parse_args())
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user