# Source: vina_docking_batch/utils/chem_cluster/tanimoto_cluster_api.py
# (Web-viewer header removed: "Files", "385 lines", "15 KiB", "Python",
#  "Raw Permalink Blame History", and an ambiguous-Unicode warning banner.)
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import random
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger
from rdkit.Chem import rdMolDescriptors
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina
from sklearn.cluster import DBSCAN, AgglomerativeClustering
# Optional SciPy dependency: the hierarchical-clustering helpers degrade to
# None when scipy is missing; availability is checked in cluster_scipy_linkage_cut.
try:
    from scipy.cluster.hierarchy import linkage, fcluster
except Exception:
    linkage = None
    fcluster = None

# Silence RDKit logging.
RDLogger.DisableLog('rdApp.*')
# ---------- Fingerprint & similarity utilities ----------
@dataclass
class FPConfig:
    """Morgan/ECFP fingerprint settings."""
    radius: int = 2        # ECFP radius; radius 2 == ECFP4
    n_bits: int = 2048     # bit-vector length (1024 or 2048 are common)
    use_count: bool = False  # True: count FP; False: bit vector (recommended for Tanimoto)
def smiles_to_mols(smiles: List[str]) -> Tuple[List[Chem.Mol], List[int]]:
    """Parse SMILES strings into RDKit Mol objects.

    Unparsable entries are silently dropped; the second return value maps
    each kept Mol back to its position in the input list.
    """
    parsed: List[Chem.Mol] = []
    kept_positions: List[int] = []
    for pos, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(str(smi))
        if mol is None:
            continue
        parsed.append(mol)
        kept_positions.append(pos)
    return parsed, kept_positions
def mols_to_ecfp(mols: List[Chem.Mol], cfg: FPConfig):
    """Compute Morgan fingerprints for a list of Mols.

    cfg.use_count selects count fingerprints; otherwise fixed-length bit
    vectors of cfg.n_bits are produced.
    """
    if cfg.use_count:
        return [rdMolDescriptors.GetMorganFingerprint(mol, cfg.radius) for mol in mols]
    return [rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, cfg.radius, nBits=cfg.n_bits)
            for mol in mols]
def tanimoto(a, b) -> float:
    """Tanimoto similarity between two RDKit fingerprints."""
    return DataStructs.TanimotoSimilarity(a, b)
def bulk_tanimoto_to_many(fp, fps) -> List[float]:
    """Tanimoto similarity of `fp` against every fingerprint in `fps`.

    Single C-level call; much faster than per-pair calls for many pairs.
    """
    return DataStructs.BulkTanimotoSimilarity(fp, fps)
# ---------- Stable, scalable clustering (preferred) ----------
def cluster_butina(fps, sim_cutoff: float = 0.6) -> np.ndarray:
    """
    RDKit Butina clustering.

    Args:
        fps: RDKit fingerprint list.
        sim_cutoff: similarity threshold (e.g. 0.6–0.8); pairs above it are
            considered neighbors.

    Returns:
        Cluster labels (0..K-1); any point Butina leaves unassigned becomes
        its own singleton cluster.
    """
    n = len(fps)
    # Condensed lower-triangle distance list, row by row.
    condensed = []
    for row in range(1, n):
        sims = DataStructs.BulkTanimotoSimilarity(fps[row], fps[:row])
        condensed.extend(1.0 - s for s in sims)
    clusters = Butina.ClusterData(condensed, n, 1.0 - sim_cutoff, isDistData=True)
    labels = np.full(n, -1, dtype=int)
    for cluster_id, members in enumerate(clusters):
        labels[list(members)] = cluster_id
    # Promote any leftover unassigned points to fresh singleton clusters.
    strays = np.where(labels == -1)[0]
    if strays.size:
        next_id = int(labels.max()) + 1 if labels.max() >= 0 else 0
        for s in strays:
            labels[s] = next_id
            next_id += 1
    return labels
def cluster_dbscan_threshold(fps, sim_cutoff: float = 0.6, eps: float = 0.5, min_samples: int = 5) -> np.ndarray:
    """
    Threshold the Tanimoto-distance matrix, then run DBSCAN on it with
    metric='precomputed'.

    Entries whose similarity falls below `sim_cutoff` keep the maximum
    distance 1.0; near pairs get their true distance 1 - similarity.

    Args:
        fps: RDKit fingerprint list.
        sim_cutoff: similarity threshold used to sparsify the matrix.
        eps, min_samples: standard DBSCAN parameters (eps is a distance).

    Returns:
        DBSCAN labels; -1 marks noise (kept as-is — promote to singleton
        clusters in the caller if needed).

    Perf: Tanimoto distance is symmetric, so only the upper triangle is
    computed and mirrored — half the work of a full per-row scan.
    """
    n = len(fps)
    dm = np.full((n, n), 1.0, dtype=np.float32)
    np.fill_diagonal(dm, 0.0)
    max_d = 1.0 - sim_cutoff
    for i in range(n - 1):
        sims = bulk_tanimoto_to_many(fps[i], fps[i + 1:])
        d = 1.0 - np.array(sims, dtype=np.float32)
        near = d <= max_d
        cols = np.nonzero(near)[0] + i + 1
        vals = d[near]
        dm[i, cols] = vals
        dm[cols, i] = vals
    clt = DBSCAN(metric='precomputed', eps=eps, min_samples=min_samples, n_jobs=-1)
    return clt.fit_predict(dm)
# ---------- Small-data option: full distance matrix ----------
def distance_matrix_from_fps(fps) -> np.ndarray:
    """Full symmetric Tanimoto-distance matrix (O(n^2) memory; small data only)."""
    n = len(fps)
    dm = np.zeros((n, n), dtype=np.float32)
    for row in range(n):
        tail = fps[row + 1:]
        if not tail:
            continue
        dists = 1.0 - np.asarray(bulk_tanimoto_to_many(fps[row], tail), dtype=np.float32)
        dm[row, row + 1:] = dists
        dm[row + 1:, row] = dists
    return dm
def cluster_agglomerative_precomputed(fps, n_clusters=5, linkage_method='average') -> np.ndarray:
    """Agglomerative clustering on a precomputed Tanimoto-distance matrix.

    Builds the full O(n^2) matrix first — suitable for small datasets only.
    Returns labels 0..n_clusters-1.
    """
    dm = distance_matrix_from_fps(fps)
    clt = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage=linkage_method)
    return clt.fit_predict(dm)
def cluster_scipy_linkage_cut(fps, method='average', t=0.7, criterion='distance') -> np.ndarray:
    """Hierarchical clustering via scipy linkage, cut with fcluster.

    Raises RuntimeError when scipy is not installed. Labels are shifted to
    start at 0 (fcluster returns 1-based ids).
    """
    if linkage is None:
        raise RuntimeError("scipy 未安装,无法使用 linkage。")
    dm = distance_matrix_from_fps(fps)
    upper = np.triu_indices(dm.shape[0], k=1)
    tree = linkage(dm[upper], method=method)
    return fcluster(tree, t=t, criterion=criterion) - 1
# ---------- Representative molecules ----------
def pick_medoid_indices(idx: List[int], fps) -> int:
    """Return the medoid index of a cluster (all-pairs; fine for small clusters).

    The medoid minimizes the summed Tanimoto distance to all other members.
    Perf: uses one BulkTanimotoSimilarity call per member (C-level row)
    instead of a Python-level pairwise loop; the self-similarity term is 1
    (distance 0) for every row, so it does not affect the argmin.
    """
    if len(idx) == 1:
        return idx[0]
    member_fps = [fps[i] for i in idx]
    dist_sums = []
    for fp in member_fps:
        sims = DataStructs.BulkTanimotoSimilarity(fp, member_fps)
        dist_sums.append(float(np.sum(1.0 - np.array(sims))))
    return idx[int(np.argmin(dist_sums))]
def pick_maxmin_indices(idx: List[int], fps, k: int = 1) -> List[int]:
    """Pick up to k maximally diverse members from `idx` (MaxMin algorithm).

    Greedy: seed with the first index, then repeatedly add the candidate
    whose minimum distance to the chosen set is largest.

    Fixes vs. previous version:
      - k <= 0 now returns [] (previously still returned one element);
      - membership test uses a set (O(1)) instead of scanning the list.
    """
    if not idx or k <= 0:
        return []
    chosen = [idx[0]]
    chosen_set = {idx[0]}
    target = min(k, len(idx))
    while len(chosen) < target:
        best_j, best_min_d = None, -1.0
        for j in idx:
            if j in chosen_set:
                continue
            # Distance from candidate j to its nearest already-chosen member.
            min_d = min(1.0 - tanimoto(fps[i], fps[j]) for i in chosen)
            if min_d > best_min_d:
                best_min_d, best_j = min_d, j
        chosen.append(best_j)
        chosen_set.add(best_j)
    return chosen
def select_representatives(labels: np.ndarray, fps, per_cluster: int = 1,
                           strategy: str = "medoid") -> List[int]:
    """Pick representative indices for every cluster.

    strategy 'medoid' takes one medoid per cluster; 'maxmin' takes up to
    `per_cluster` diverse members. Returns a sorted, de-duplicated list.
    """
    picked: List[int] = []
    for cluster_id in np.unique(labels):
        members = np.where(labels == cluster_id)[0].tolist()
        if not members:
            continue
        if strategy == "medoid":
            picked.append(pick_medoid_indices(members, fps))
        elif strategy == "maxmin":
            picked.extend(pick_maxmin_indices(members, fps, k=per_cluster))
        else:
            raise ValueError("strategy ∈ {'medoid','maxmin'}")
    return sorted(set(picked))
# ---------- Evaluation metrics ----------
@dataclass
class ClusterStats:
    """Summary statistics of one clustering result."""
    n_samples: int                 # number of labelled items
    n_clusters: int                # distinct labels (DBSCAN noise -1 counts as one)
    largest_cluster_size: int      # member count of the biggest cluster
    largest_cluster_ratio: float   # largest_cluster_size / n_samples
    sizes: Dict[int, int] = field(default_factory=dict)  # label -> member count


def cluster_stats(labels: np.ndarray) -> ClusterStats:
    """Compute size statistics for a label vector.

    Robustness fix: an empty `labels` array now yields an all-zero
    ClusterStats instead of crashing on `counts.max()` / dividing by zero.
    """
    n = len(labels)
    if n == 0:
        return ClusterStats(0, 0, 0, 0.0, {})
    vals, counts = np.unique(labels, return_counts=True)
    largest = int(counts.max())
    return ClusterStats(
        n_samples=int(n),
        n_clusters=int(len(vals)),
        largest_cluster_size=largest,
        largest_cluster_ratio=float(largest / n),
        sizes={int(v): int(c) for v, c in zip(vals, counts)}
    )
# ---------- Public API ----------
@dataclass
class TanimotoClusteringAPI:
    """Public facade: SMILES -> ECFP fingerprints -> Tanimoto clustering."""
    # Fingerprint settings used by fit_from_smiles / explore.
    fp_cfg: FPConfig = field(default_factory=FPConfig)
    # Free-form parameter bag; not read by the methods below.
    params: dict = field(default_factory=dict)

    def fit_from_smiles(self, smiles: List[str], method: str = "butina",
                        method_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Cluster a list of SMILES with the chosen method.

        Returns: { 'labels': np.ndarray, 'keep_idx': List[int], 'fps': List[FP], 'stats': ClusterStats }
        `keep_idx` maps each fingerprint/label back to its position in the
        original `smiles` list (unparsable SMILES are dropped).

        Raises:
            ValueError: when no SMILES parses, or `method` is unknown.
        """
        method_kwargs = method_kwargs or {}
        mols, keep_idx = smiles_to_mols(smiles)
        if len(mols) == 0:
            raise ValueError("没有有效的 SMILES。")
        fps = mols_to_ecfp(mols, self.fp_cfg)
        if method == "butina":
            labels = cluster_butina(fps, **method_kwargs)
        elif method == "dbscan_threshold":
            labels = cluster_dbscan_threshold(fps, **method_kwargs)
        elif method == "agglomerative":
            labels = cluster_agglomerative_precomputed(fps, **method_kwargs)
        elif method == "scipy_linkage":
            labels = cluster_scipy_linkage_cut(fps, **method_kwargs)
        else:
            raise ValueError("未知 method")
        stats = cluster_stats(labels)
        return {"labels": labels, "keep_idx": keep_idx, "fps": fps, "stats": stats}

    # Optional ChemPlot visualization (requires the `chemplot` package).
    def chemplot_visualize(self, smiles_valid: List[str], labels: np.ndarray,
                           target: Optional[pd.Series] = None,
                           target_type: Optional[str] = None,
                           random_state: int = 42):
        """
        UMAP plot of the structural space, colored by cluster labels.
        `smiles_valid` must align 1:1 with `labels` (i.e. the kept SMILES).
        """
        from chemplot import Plotter
        import matplotlib.pyplot as plt
        cp = Plotter.from_smiles(smiles_valid, sim_type="structural",
                                 target=target, target_type=target_type)
        cp.umap(random_state=random_state)
        # NOTE(review): pokes chemplot's private name-mangled dataframe to
        # inject cluster ids — fragile across chemplot versions; verify.
        cp._Plotter__df_2_components["clusters"] = labels
        fig = cp.visualize_plot(clusters=True)
        return cp, fig

    # Grid exploration (prefer Butina / sparse DBSCAN for large samples).
    def explore(self, smiles: List[str],
                butina_cuts=(0.5, 0.6, 0.7, 0.75, 0.8),
                dbscan_params=((0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5))
                ) -> List[Dict[str, Any]]:
        """
        Run each configuration once and return a list of stats dicts, to help
        pick the setting with the smallest largest-cluster share.
        """
        results = []
        mols, keep_idx = smiles_to_mols(smiles)
        fps = mols_to_ecfp(mols, self.fp_cfg)
        for cut in butina_cuts:
            labels = cluster_butina(fps, sim_cutoff=cut)
            stats = cluster_stats(labels)
            results.append({"method": "butina", "params": {"sim_cutoff": cut}, "stats": stats})
        for sim_cut, eps, min_samples in dbscan_params:
            labels = cluster_dbscan_threshold(fps, sim_cutoff=sim_cut, eps=eps, min_samples=min_samples)
            stats = cluster_stats(labels)
            results.append({"method": "dbscan_threshold",
                            "params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": min_samples},
                            "stats": stats})
        return results
# ---------- Automatic search wrapper (public API kept consistent) ----------
@dataclass
class SearchConfig:
    """Search space for `search_best_config`."""
    # Fingerprint settings to sweep.
    radii: List[int] = field(default_factory=lambda: [2, 3])
    n_bits_list: List[int] = field(default_factory=lambda: [1024, 2048])
    # Method search space.
    butina_cuts: List[float] = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.75, 0.8])
    dbscan_params: List[Tuple[float, float, int]] = field(default_factory=lambda: [(0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5)])
    # Optional methods — small samples only (full O(n^2) distance matrix).
    try_agglomerative: bool = True
    agglo_k_list: List[int] = field(default_factory=lambda: [5, 8, 10])
    try_scipy: bool = False
    scipy_t_list: List[float] = field(default_factory=lambda: [0.6, 0.7])
    linkage: str = "average"  # linkage method shared by agglomerative/scipy trials
def _score(stats: ClusterStats) -> float:
    """Higher is better: balance (small largest-cluster share) plus a small
    bonus for producing more than one cluster."""
    bonus = 0.1 if stats.n_clusters > 1 else 0.0
    return (1.0 - stats.largest_cluster_ratio) + bonus
def search_best_config(smiles: List[str], cfg: Optional[SearchConfig] = None) -> Tuple[TanimotoClusteringAPI, ClusterStats, pd.DataFrame]:
    """
    Try every FP radius/bit-count combination with Butina, thresholded
    DBSCAN, and (optionally) Agglomerative / SciPy linkage; score each trial
    mainly by 1 - largest_cluster_ratio.

    Returns:
        (best API with its fp_cfg already set, that trial's ClusterStats,
         and a history DataFrame of all trials sorted by score descending).
    """
    cfg = cfg or SearchConfig()
    trials: List[Dict[str, Any]] = []
    for r in cfg.radii:
        for nb in cfg.n_bits_list:
            api = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=r, n_bits=nb))
            # 1) Butina
            for cut in cfg.butina_cuts:
                res = api.fit_from_smiles(smiles, method="butina", method_kwargs={"sim_cutoff": cut})
                sc = _score(res["stats"])
                trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "butina",
                               "params": {"sim_cutoff": cut}, "stats": res["stats"], "score": sc})
            # 2) Thresholded ("sparse") DBSCAN
            for (sim_cut, eps, ms) in cfg.dbscan_params:
                res = api.fit_from_smiles(smiles, method="dbscan_threshold",
                                          method_kwargs={"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms})
                sc = _score(res["stats"])
                trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "dbscan_threshold",
                               "params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms},
                               "stats": res["stats"], "score": sc})
            # 3) Optional: Agglomerative (small samples / exploration)
            if cfg.try_agglomerative:
                for k in cfg.agglo_k_list:
                    res = api.fit_from_smiles(smiles, method="agglomerative",
                                              method_kwargs={"n_clusters": k, "linkage_method": cfg.linkage})
                    sc = _score(res["stats"])
                    trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "agglomerative",
                                   "params": {"n_clusters": k, "linkage_method": cfg.linkage},
                                   "stats": res["stats"], "score": sc})
            # 4) Optional: SciPy linkage (small samples & scipy installed)
            if cfg.try_scipy and linkage is not None:
                for t in cfg.scipy_t_list:
                    res = api.fit_from_smiles(smiles, method="scipy_linkage",
                                              method_kwargs={"method": cfg.linkage, "t": t, "criterion": "distance"})
                    sc = _score(res["stats"])
                    trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "scipy_linkage",
                                   "params": {"t": t, "method": cfg.linkage, "criterion": "distance"},
                                   "stats": res["stats"], "score": sc})
    # Rank the trials and pick the best one.
    rows = []
    for t in trials:
        s: ClusterStats = t["stats"]
        rows.append({
            "fp_radius": t["fp_radius"],
            "fp_n_bits": t["fp_n_bits"],
            "method": t["method"],
            "params": t["params"],
            "n_samples": s.n_samples,
            "n_clusters": s.n_clusters,
            "largest_cluster_ratio": s.largest_cluster_ratio,
            "sizes": s.sizes,
            "score": t["score"],
        })
    hist = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)
    best = hist.iloc[0]
    # Assemble a pre-configured API to return.
    api_best = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=int(best["fp_radius"]), n_bits=int(best["fp_n_bits"])))
    # Fit once so the caller immediately has labels/keep_idx available.
    res = api_best.fit_from_smiles(smiles, method=best["method"], method_kwargs=best["params"])
    return api_best, res["stats"], hist