聚类方法,聚类后选择打分最高那个分子,并对 karamadock 的结果求交集
This commit is contained in:
27
utils/chem_cluster/__init__.py
Normal file
27
utils/chem_cluster/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# utils/chem_cluster/__init__.py
|
||||
|
||||
from .tanimoto_cluster_api import (
|
||||
TanimotoClusteringAPI as TanimotoClusterer,
|
||||
SearchConfig,
|
||||
search_best_config,
|
||||
ClusterStats,
|
||||
FPConfig,
|
||||
cluster_butina,
|
||||
cluster_dbscan_threshold,
|
||||
cluster_agglomerative_precomputed,
|
||||
cluster_scipy_linkage_cut,
|
||||
select_representatives,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TanimotoClusterer",
|
||||
"SearchConfig",
|
||||
"search_best_config",
|
||||
"ClusterStats",
|
||||
"FPConfig",
|
||||
"cluster_butina",
|
||||
"cluster_dbscan_threshold",
|
||||
"cluster_agglomerative_precomputed",
|
||||
"cluster_scipy_linkage_cut",
|
||||
"select_representatives",
|
||||
]
|
||||
384
utils/chem_cluster/tanimoto_cluster_api.py
Normal file
384
utils/chem_cluster/tanimoto_cluster_api.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from rdkit import Chem, RDLogger
|
||||
from rdkit.Chem import rdMolDescriptors
|
||||
from rdkit import DataStructs
|
||||
from rdkit.ML.Cluster import Butina
|
||||
|
||||
from sklearn.cluster import DBSCAN, AgglomerativeClustering
|
||||
|
||||
try:
|
||||
from scipy.cluster.hierarchy import linkage, fcluster
|
||||
except Exception:
|
||||
linkage = None
|
||||
fcluster = None
|
||||
|
||||
# 静音 RDKit
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
# ---------- 指纹 & 相似度工具 ----------
|
||||
@dataclass
|
||||
class FPConfig:
|
||||
radius: int = 2 # ECFP 半径(2=ECFP4)
|
||||
n_bits: int = 2048 # 位数(1024 或 2048 常用)
|
||||
use_count: bool = False # True: 计数 FP;False: bit 向量(推荐用于 Tanimoto)
|
||||
|
||||
|
||||
def smiles_to_mols(smiles: List[str]) -> Tuple[List[Chem.Mol], List[int]]:
|
||||
mols, idx = [], []
|
||||
for i, smi in enumerate(smiles):
|
||||
m = Chem.MolFromSmiles(str(smi))
|
||||
if m is not None:
|
||||
mols.append(m)
|
||||
idx.append(i)
|
||||
return mols, idx
|
||||
|
||||
|
||||
def mols_to_ecfp(mols: List[Chem.Mol], cfg: FPConfig):
|
||||
fps = []
|
||||
if cfg.use_count:
|
||||
gen = rdMolDescriptors.GetMorganFingerprint
|
||||
for m in mols:
|
||||
fps.append(gen(m, cfg.radius))
|
||||
else:
|
||||
for m in mols:
|
||||
fps.append(rdMolDescriptors.GetMorganFingerprintAsBitVect(m, cfg.radius, nBits=cfg.n_bits))
|
||||
return fps
|
||||
|
||||
|
||||
def tanimoto(a, b) -> float:
|
||||
return DataStructs.TanimotoSimilarity(a, b)
|
||||
|
||||
|
||||
def bulk_tanimoto_to_many(fp, fps) -> List[float]:
|
||||
# 对大量 pair 更快
|
||||
return DataStructs.BulkTanimotoSimilarity(fp, fps)
|
||||
|
||||
|
||||
# ---------- 稳定可扩展的聚类(首选) ----------
|
||||
def cluster_butina(fps, sim_cutoff: float = 0.6) -> np.ndarray:
|
||||
"""
|
||||
RDKit Butina 聚类
|
||||
fps: RDKit Fingerprint list
|
||||
sim_cutoff: 相似度阈值(如 0.6~0.8)
|
||||
返回:簇标签 (0..K-1)
|
||||
"""
|
||||
n = len(fps)
|
||||
# 构建完整的压缩距离矩阵
|
||||
dists = []
|
||||
for i in range(1, n):
|
||||
sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
|
||||
dists.extend([1.0 - s for s in sims]) # 全部保存
|
||||
|
||||
# 调用 Butina 聚类
|
||||
cs = Butina.ClusterData(dists, n, 1.0 - sim_cutoff, isDistData=True)
|
||||
|
||||
# 转换成标签数组
|
||||
labels = np.full(n, -1, dtype=int)
|
||||
for cid, members in enumerate(cs):
|
||||
for m in members:
|
||||
labels[m] = cid
|
||||
|
||||
# 把孤立点分配为独立簇(便于后续处理)
|
||||
if np.any(labels == -1):
|
||||
max_id = labels.max() if labels.max() >= 0 else -1
|
||||
for i in np.where(labels == -1)[0]:
|
||||
max_id += 1
|
||||
labels[i] = max_id
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def cluster_dbscan_threshold(fps, sim_cutoff: float = 0.6, eps: float = 0.5, min_samples: int = 5) -> np.ndarray:
|
||||
"""
|
||||
用相似度阈值构稀疏邻接,再交给 DBSCAN(metric='precomputed')。
|
||||
"""
|
||||
n = len(fps)
|
||||
dm = np.full((n, n), 1.0, dtype=np.float32)
|
||||
np.fill_diagonal(dm, 0.0)
|
||||
for i in range(n):
|
||||
sims = bulk_tanimoto_to_many(fps[i], fps)
|
||||
d = 1.0 - np.array(sims, dtype=np.float32)
|
||||
near = d <= (1.0 - sim_cutoff)
|
||||
dm[i, near] = d[near]
|
||||
clt = DBSCAN(metric='precomputed', eps=eps, min_samples=min_samples, n_jobs=-1)
|
||||
labels = clt.fit_predict(dm)
|
||||
# 保持 DBSCAN 的 -1 为噪声;如需改为独立簇可在外面处理
|
||||
return labels
|
||||
|
||||
|
||||
# ---------- 小数据可选:完整距离矩阵 ----------
|
||||
def distance_matrix_from_fps(fps) -> np.ndarray:
|
||||
n = len(fps)
|
||||
dm = np.zeros((n, n), dtype=np.float32)
|
||||
for i in range(n):
|
||||
sims = bulk_tanimoto_to_many(fps[i], fps[i+1:])
|
||||
if len(sims):
|
||||
d = 1.0 - np.array(sims, dtype=np.float32)
|
||||
dm[i, i+1:] = d
|
||||
dm[i+1:, i] = d
|
||||
return dm
|
||||
|
||||
|
||||
def cluster_agglomerative_precomputed(fps, n_clusters=5, linkage_method='average') -> np.ndarray:
|
||||
dm = distance_matrix_from_fps(fps)
|
||||
clt = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage=linkage_method)
|
||||
return clt.fit_predict(dm)
|
||||
|
||||
|
||||
def cluster_scipy_linkage_cut(fps, method='average', t=0.7, criterion='distance') -> np.ndarray:
|
||||
if linkage is None:
|
||||
raise RuntimeError("scipy 未安装,无法使用 linkage。")
|
||||
dm = distance_matrix_from_fps(fps)
|
||||
n = dm.shape[0]
|
||||
iu = np.triu_indices(n, k=1)
|
||||
condensed = dm[iu]
|
||||
Z = linkage(condensed, method=method)
|
||||
labels = fcluster(Z, t=t, criterion=criterion) - 1
|
||||
return labels
|
||||
|
||||
|
||||
# ---------- 代表分子 ----------
|
||||
def pick_medoid_indices(idx: List[int], fps) -> int:
|
||||
"""返回簇内 medoid 的下标(全对全,适合簇不大时)"""
|
||||
if len(idx) == 1:
|
||||
return idx[0]
|
||||
sub = idx
|
||||
n = len(sub)
|
||||
sums = []
|
||||
for i in range(n):
|
||||
sims = [tanimoto(fps[sub[i]], fps[sub[j]]) for j in range(n)]
|
||||
dsum = float(np.sum(1.0 - np.array(sims)))
|
||||
sums.append(dsum)
|
||||
return sub[int(np.argmin(sums))]
|
||||
|
||||
|
||||
def pick_maxmin_indices(idx: List[int], fps, k=1) -> List[int]:
|
||||
"""在指定索引集合里挑 k 个多样性最大的(MaxMin)"""
|
||||
if not idx:
|
||||
return []
|
||||
chosen = [idx[0]]
|
||||
while len(chosen) < min(k, len(idx)):
|
||||
best_j, best_min_d = None, -1.0
|
||||
for j in idx:
|
||||
if j in chosen:
|
||||
continue
|
||||
mind = 1.0
|
||||
for i in chosen:
|
||||
d = 1.0 - tanimoto(fps[i], fps[j])
|
||||
if d < mind:
|
||||
mind = d
|
||||
if mind > best_min_d:
|
||||
best_min_d, best_j = mind, j
|
||||
chosen.append(best_j)
|
||||
return chosen
|
||||
|
||||
|
||||
def select_representatives(labels: np.ndarray, fps, per_cluster: int = 1,
|
||||
strategy: str = "medoid") -> List[int]:
|
||||
reps = []
|
||||
for c in np.unique(labels):
|
||||
members = np.where(labels == c)[0].tolist()
|
||||
if not members:
|
||||
continue
|
||||
if strategy == "medoid":
|
||||
reps.append(pick_medoid_indices(members, fps))
|
||||
elif strategy == "maxmin":
|
||||
reps.extend(pick_maxmin_indices(members, fps, k=per_cluster))
|
||||
else:
|
||||
raise ValueError("strategy ∈ {'medoid','maxmin'}")
|
||||
return sorted(set(reps))
|
||||
|
||||
|
||||
# ---------- 评估指标 ----------
|
||||
@dataclass
|
||||
class ClusterStats:
|
||||
n_samples: int
|
||||
n_clusters: int
|
||||
largest_cluster_size: int
|
||||
largest_cluster_ratio: float
|
||||
sizes: Dict[int, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
def cluster_stats(labels: np.ndarray) -> ClusterStats:
|
||||
n = len(labels)
|
||||
vals, counts = np.unique(labels, return_counts=True)
|
||||
largest = int(counts.max())
|
||||
return ClusterStats(
|
||||
n_samples=int(n),
|
||||
n_clusters=int(len(vals)),
|
||||
largest_cluster_size=largest,
|
||||
largest_cluster_ratio=float(largest / n),
|
||||
sizes={int(v): int(c) for v, c in zip(vals, counts)}
|
||||
)
|
||||
|
||||
|
||||
# ---------- 对外 API ----------
|
||||
@dataclass
|
||||
class TanimotoClusteringAPI:
|
||||
fp_cfg: FPConfig = field(default_factory=FPConfig)
|
||||
params: dict = field(default_factory=dict)
|
||||
|
||||
def fit_from_smiles(self, smiles: List[str], method: str = "butina",
|
||||
method_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
返回:{ 'labels': np.ndarray, 'keep_idx': List[int], 'fps': List[FP], 'stats': ClusterStats }
|
||||
"""
|
||||
method_kwargs = method_kwargs or {}
|
||||
mols, keep_idx = smiles_to_mols(smiles)
|
||||
if len(mols) == 0:
|
||||
raise ValueError("没有有效的 SMILES。")
|
||||
fps = mols_to_ecfp(mols, self.fp_cfg)
|
||||
|
||||
if method == "butina":
|
||||
labels = cluster_butina(fps, **method_kwargs)
|
||||
elif method == "dbscan_threshold":
|
||||
labels = cluster_dbscan_threshold(fps, **method_kwargs)
|
||||
elif method == "agglomerative":
|
||||
labels = cluster_agglomerative_precomputed(fps, **method_kwargs)
|
||||
elif method == "scipy_linkage":
|
||||
labels = cluster_scipy_linkage_cut(fps, **method_kwargs)
|
||||
else:
|
||||
raise ValueError("未知 method")
|
||||
|
||||
stats = cluster_stats(labels)
|
||||
return {"labels": labels, "keep_idx": keep_idx, "fps": fps, "stats": stats}
|
||||
|
||||
# ChemPlot 可视化(可选)
|
||||
def chemplot_visualize(self, smiles_valid: List[str], labels: np.ndarray,
|
||||
target: Optional[pd.Series] = None,
|
||||
target_type: Optional[str] = None,
|
||||
random_state: int = 42):
|
||||
from chemplot import Plotter
|
||||
import matplotlib.pyplot as plt
|
||||
cp = Plotter.from_smiles(smiles_valid, sim_type="structural",
|
||||
target=target, target_type=target_type)
|
||||
cp.umap(random_state=random_state)
|
||||
cp._Plotter__df_2_components["clusters"] = labels
|
||||
fig = cp.visualize_plot(clusters=True)
|
||||
return cp, fig
|
||||
|
||||
# 网格探索(大样本优先用 Butina / 稀疏 DBSCAN)
|
||||
def explore(self, smiles: List[str],
|
||||
butina_cuts=(0.5, 0.6, 0.7, 0.75, 0.8),
|
||||
dbscan_params=((0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5))
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
返回每个配置的 stats 列表,便于挑选“最大簇占比最小”的方案。
|
||||
"""
|
||||
results = []
|
||||
mols, keep_idx = smiles_to_mols(smiles)
|
||||
fps = mols_to_ecfp(mols, self.fp_cfg)
|
||||
|
||||
for cut in butina_cuts:
|
||||
labels = cluster_butina(fps, sim_cutoff=cut)
|
||||
stats = cluster_stats(labels)
|
||||
results.append({"method": "butina", "params": {"sim_cutoff": cut}, "stats": stats})
|
||||
|
||||
for sim_cut, eps, min_samples in dbscan_params:
|
||||
labels = cluster_dbscan_threshold(fps, sim_cutoff=sim_cut, eps=eps, min_samples=min_samples)
|
||||
stats = cluster_stats(labels)
|
||||
results.append({"method": "dbscan_threshold",
|
||||
"params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": min_samples},
|
||||
"stats": stats})
|
||||
return results
|
||||
|
||||
|
||||
# ---------- 自动搜索封装(对外 API 保持一致) ----------
|
||||
@dataclass
|
||||
class SearchConfig:
|
||||
# 指纹配置
|
||||
radii: List[int] = field(default_factory=lambda: [2, 3])
|
||||
n_bits_list: List[int] = field(default_factory=lambda: [1024, 2048])
|
||||
# 方法搜索空间
|
||||
butina_cuts: List[float] = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.75, 0.8])
|
||||
dbscan_params: List[Tuple[float, float, int]] = field(default_factory=lambda: [(0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5)])
|
||||
# 小样本时可选
|
||||
try_agglomerative: bool = True
|
||||
agglo_k_list: List[int] = field(default_factory=lambda: [5, 8, 10])
|
||||
try_scipy: bool = False
|
||||
scipy_t_list: List[float] = field(default_factory=lambda: [0.6, 0.7])
|
||||
linkage: str = "average"
|
||||
|
||||
def _score(stats: ClusterStats) -> float:
|
||||
# 越大越好:更均衡(最大簇比例小)+ 簇数量>1的奖励
|
||||
balance = 1.0 - stats.largest_cluster_ratio
|
||||
bonus = 0.1 if stats.n_clusters > 1 else 0.0
|
||||
return balance + bonus
|
||||
|
||||
def search_best_config(smiles: List[str], cfg: Optional[SearchConfig] = None) -> Tuple[TanimotoClusteringAPI, ClusterStats, pd.DataFrame]:
|
||||
"""
|
||||
自动尝试不同 FP 半径/位数 + (Butina/DBSCAN/可选Agglo/SciPy),
|
||||
以 1-最大簇比例 为主的评分,返回:最佳 API(已设置好 fp_cfg 与 method/参数)、其 stats、以及历史对比表。
|
||||
"""
|
||||
cfg = cfg or SearchConfig()
|
||||
trials: List[Dict[str, Any]] = []
|
||||
|
||||
for r in cfg.radii:
|
||||
for nb in cfg.n_bits_list:
|
||||
api = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=r, n_bits=nb))
|
||||
# 1) Butina
|
||||
for cut in cfg.butina_cuts:
|
||||
res = api.fit_from_smiles(smiles, method="butina", method_kwargs={"sim_cutoff": cut})
|
||||
sc = _score(res["stats"])
|
||||
trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "butina",
|
||||
"params": {"sim_cutoff": cut}, "stats": res["stats"], "score": sc})
|
||||
# 2) 稀疏 DBSCAN
|
||||
for (sim_cut, eps, ms) in cfg.dbscan_params:
|
||||
res = api.fit_from_smiles(smiles, method="dbscan_threshold",
|
||||
method_kwargs={"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms})
|
||||
sc = _score(res["stats"])
|
||||
trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "dbscan_threshold",
|
||||
"params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms},
|
||||
"stats": res["stats"], "score": sc})
|
||||
# 3) (可选)Agglomerative(小样本或探索时)
|
||||
if cfg.try_agglomerative:
|
||||
for k in cfg.agglo_k_list:
|
||||
res = api.fit_from_smiles(smiles, method="agglomerative",
|
||||
method_kwargs={"n_clusters": k, "linkage_method": cfg.linkage})
|
||||
sc = _score(res["stats"])
|
||||
trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "agglomerative",
|
||||
"params": {"n_clusters": k, "linkage_method": cfg.linkage},
|
||||
"stats": res["stats"], "score": sc})
|
||||
# 4) (可选)SciPy linkage(需小样本 & 安装 scipy)
|
||||
if cfg.try_scipy and linkage is not None:
|
||||
for t in cfg.scipy_t_list:
|
||||
res = api.fit_from_smiles(smiles, method="scipy_linkage",
|
||||
method_kwargs={"method": cfg.linkage, "t": t, "criterion": "distance"})
|
||||
sc = _score(res["stats"])
|
||||
trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "scipy_linkage",
|
||||
"params": {"t": t, "method": cfg.linkage, "criterion": "distance"},
|
||||
"stats": res["stats"], "score": sc})
|
||||
|
||||
# 排序选最优
|
||||
rows = []
|
||||
for t in trials:
|
||||
s: ClusterStats = t["stats"]
|
||||
rows.append({
|
||||
"fp_radius": t["fp_radius"],
|
||||
"fp_n_bits": t["fp_n_bits"],
|
||||
"method": t["method"],
|
||||
"params": t["params"],
|
||||
"n_samples": s.n_samples,
|
||||
"n_clusters": s.n_clusters,
|
||||
"largest_cluster_ratio": s.largest_cluster_ratio,
|
||||
"sizes": s.sizes,
|
||||
"score": t["score"],
|
||||
})
|
||||
hist = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)
|
||||
|
||||
best = hist.iloc[0]
|
||||
# 组装一个“已配置好的” API 返回
|
||||
api_best = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=int(best["fp_radius"]), n_bits=int(best["fp_n_bits"])))
|
||||
# 触发一次拟合以便调用方马上拿 labels/keep_idx
|
||||
res = api_best.fit_from_smiles(smiles, method=best["method"], method_kwargs=best["params"])
|
||||
return api_best, res["stats"], hist
|
||||
Reference in New Issue
Block a user