# -*- coding: utf-8 -*-
"""Tanimoto-similarity clustering helpers for molecules given as SMILES.

The module turns SMILES into RDKit Morgan fingerprints, clusters them with
Butina, a thresholded DBSCAN, scikit-learn agglomerative clustering or a
SciPy linkage cut, picks cluster representatives, and offers a small grid
search that ranks configurations by how balanced the clusters are.
"""
from __future__ import annotations

import math
import random
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger
from rdkit.Chem import rdMolDescriptors
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina
from sklearn.cluster import DBSCAN, AgglomerativeClustering

# SciPy is optional. The guard is deliberately broad: a binary-incompatible
# SciPy build can fail with errors other than ImportError, and the module
# should still load without the linkage-based method in that case.
try:
    from scipy.cluster.hierarchy import linkage, fcluster
except Exception:
    linkage = None
    fcluster = None

# Silence RDKit logging.
RDLogger.DisableLog('rdApp.*')


# ---------- Fingerprint & similarity helpers ----------
@dataclass
class FPConfig:
    """Settings for the Morgan (ECFP-style) fingerprints built by `mols_to_ecfp`."""

    radius: int = 2  # Morgan radius (2 corresponds to ECFP4)
    n_bits: int = 2048  # bit-vector length; not used when use_count is True
    use_count: bool = False  # True: count fingerprint; False: bit vector


def smiles_to_mols(smiles: List[str]) -> Tuple[List[Chem.Mol], List[int]]:
    """Parse SMILES and keep only the usable molecules.

    Args:
        smiles: Input SMILES; every entry is passed through ``str()`` first.

    Returns:
        ``(mols, idx)`` where ``idx[k]`` is the position in ``smiles`` that
        produced ``mols[k]``.
    """
    mols, idx = [], []
    for i, smi in enumerate(smiles):
        m = Chem.MolFromSmiles(str(smi))
        # An empty string parses to a molecule with zero atoms rather than
        # None; such entries carry no structure and are dropped as invalid.
        if m is not None and m.GetNumAtoms() > 0:
            mols.append(m)
            idx.append(i)
    return mols, idx


def mols_to_ecfp(mols: List[Chem.Mol], cfg: FPConfig):
    """Compute one Morgan fingerprint per molecule.

    Args:
        mols: RDKit molecules.
        cfg: Fingerprint settings. With ``cfg.use_count`` the fingerprint
            comes from ``GetMorganFingerprint`` and ``cfg.n_bits`` is not
            passed; otherwise a bit vector of ``cfg.n_bits`` bits is built.

    Returns:
        A list of RDKit fingerprint objects, in the order of ``mols``.
    """
    if cfg.use_count:
        return [rdMolDescriptors.GetMorganFingerprint(m, cfg.radius) for m in mols]
    return [
        rdMolDescriptors.GetMorganFingerprintAsBitVect(m, cfg.radius, nBits=cfg.n_bits)
        for m in mols
    ]


def tanimoto(a, b) -> float:
    """Return the Tanimoto similarity of two RDKit fingerprints."""
    return DataStructs.TanimotoSimilarity(a, b)


def bulk_tanimoto_to_many(fp, fps) -> List[float]:
    """Return the Tanimoto similarity of ``fp`` to every fingerprint in ``fps``.

    Uses RDKit's bulk call, intended to be faster than one call per pair.
    """
    return DataStructs.BulkTanimotoSimilarity(fp, fps)


# ---------- Scalable clustering (preferred) ----------
def cluster_butina(fps, sim_cutoff: float = 0.6) -> np.ndarray:
    """Cluster fingerprints with RDKit's Butina algorithm.

    Args:
        fps: List of RDKit fingerprints.
        sim_cutoff: Similarity threshold (e.g. 0.6-0.8); it is passed to
            Butina as the distance threshold ``1 - sim_cutoff``.

    Returns:
        Integer label per fingerprint, numbered from 0.
    """
    n = len(fps)
    if n == 0:
        return np.empty(0, dtype=int)

    # Condensed lower-triangular distances: row i holds d(i, 0..i-1).
    dists = []
    for i in range(1, n):
        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1.0 - s for s in sims)

    cs = Butina.ClusterData(dists, n, 1.0 - sim_cutoff, isDistData=True)

    labels = np.full(n, -1, dtype=int)
    for cid, members in enumerate(cs):
        for m in members:
            labels[m] = cid

    # Safety net: anything left unassigned becomes its own singleton cluster
    # so downstream code never sees -1 from this function.
    unassigned = np.flatnonzero(labels == -1)
    if unassigned.size:
        start = int(labels.max()) + 1
        labels[unassigned] = np.arange(start, start + unassigned.size)
    return labels


def cluster_dbscan_threshold(fps, sim_cutoff: float = 0.6, eps: float = 0.5,
                             min_samples: int = 5) -> np.ndarray:
    """Run DBSCAN on a similarity-thresholded, precomputed distance matrix.

    A dense ``n x n`` float32 matrix is filled with 1.0; entries whose
    distance ``1 - similarity`` is at most ``1 - sim_cutoff`` are replaced by
    the real distance. The matrix is then handed to
    ``DBSCAN(metric='precomputed')``.

    Args:
        fps: List of RDKit fingerprints.
        sim_cutoff: Pairs less similar than this keep the distance 1.0.
        eps: DBSCAN neighbourhood radius (in distance units).
        min_samples: DBSCAN ``min_samples``.

    Returns:
        DBSCAN labels; -1 (noise) is kept as-is and can be remapped by the
        caller if independent clusters are wanted.
    """
    n = len(fps)
    dm = np.full((n, n), 1.0, dtype=np.float32)
    np.fill_diagonal(dm, 0.0)
    for i in range(n):
        sims = bulk_tanimoto_to_many(fps[i], fps)
        d = 1.0 - np.array(sims, dtype=np.float32)
        near = d <= (1.0 - sim_cutoff)
        dm[i, near] = d[near]

    clt = DBSCAN(metric='precomputed', eps=eps, min_samples=min_samples, n_jobs=-1)
    return clt.fit_predict(dm)


# ---------- Optional for small data: full distance matrix ----------
def distance_matrix_from_fps(fps) -> np.ndarray:
    """Return the symmetric ``n x n`` float32 matrix of ``1 - Tanimoto`` distances."""
    n = len(fps)
    dm = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        # Only the upper triangle is computed; it is mirrored below.
        sims = bulk_tanimoto_to_many(fps[i], fps[i + 1:])
        if len(sims):
            d = 1.0 - np.array(sims, dtype=np.float32)
            dm[i, i + 1:] = d
            dm[i + 1:, i] = d
    return dm


def cluster_agglomerative_precomputed(fps, n_clusters=5,
                                      linkage_method='average') -> np.ndarray:
    """Cluster with scikit-learn ``AgglomerativeClustering`` on Tanimoto distances.

    Args:
        fps: List of RDKit fingerprints.
        n_clusters: Number of clusters to extract.
        linkage_method: Passed as ``linkage=`` to ``AgglomerativeClustering``.

    Returns:
        Cluster label per fingerprint.
    """
    dm = distance_matrix_from_fps(fps)
    clt = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed',
                                  linkage=linkage_method)
    return clt.fit_predict(dm)


def cluster_scipy_linkage_cut(fps, method='average', t=0.7,
                              criterion='distance') -> np.ndarray:
    """Cluster by cutting a SciPy hierarchical linkage.

    Args:
        fps: List of RDKit fingerprints.
        method: SciPy ``linkage`` method.
        t: Threshold passed to ``fcluster``.
        criterion: ``fcluster`` criterion.

    Returns:
        Zero-based cluster labels (``fcluster`` output minus one).

    Raises:
        RuntimeError: If SciPy could not be imported.
    """
    if linkage is None:
        raise RuntimeError("scipy 未安装,无法使用 linkage。")
    dm = distance_matrix_from_fps(fps)
    n = dm.shape[0]
    iu = np.triu_indices(n, k=1)
    condensed = dm[iu]  # upper triangle in the condensed form linkage expects
    Z = linkage(condensed, method=method)
    return fcluster(Z, t=t, criterion=criterion) - 1


# ---------- Representative molecules ----------
def pick_medoid_indices(idx: List[int], fps) -> int:
    """Return the medoid of a cluster.

    The medoid is the member with the smallest summed Tanimoto distance to
    all members. The comparison is all-against-all, so this suits clusters
    that are not too large.

    Args:
        idx: Indices into ``fps`` that make up the cluster.
        fps: Fingerprints of all molecules.

    Returns:
        The element of ``idx`` that is the medoid (first one on ties).

    Raises:
        ValueError: If ``idx`` is empty.
    """
    if not idx:
        raise ValueError("idx must contain at least one index")
    if len(idx) == 1:
        return idx[0]
    member_fps = [fps[i] for i in idx]
    sums = [
        float(np.sum(1.0 - np.asarray(bulk_tanimoto_to_many(fp, member_fps))))
        for fp in member_fps
    ]
    return idx[int(np.argmin(sums))]


def pick_maxmin_indices(idx: List[int], fps, k=1) -> List[int]:
    """Pick up to ``k`` mutually diverse members with the MaxMin heuristic.

    The first element of ``idx`` seeds the selection; each further pick is
    the candidate whose smallest distance to the already chosen members is
    largest (first candidate in ``idx`` order on ties).

    Args:
        idx: Candidate indices into ``fps``.
        fps: Fingerprints of all molecules.
        k: Number of members wanted.

    Returns:
        Chosen indices in pick order; empty when ``idx`` is empty or
        ``k <= 0``; never longer than the number of distinct candidates.
    """
    if not idx or k <= 0:
        return []

    chosen = [idx[0]]
    # De-duplicate while preserving order so a repeated index can neither be
    # picked twice nor exhaust the candidate pool unexpectedly.
    remaining = [j for j in dict.fromkeys(idx) if j != idx[0]]
    # Running minimum distance from each candidate to the chosen set.
    min_dist = {j: 1.0 for j in remaining}

    target = min(k, len(remaining) + 1)
    while len(chosen) < target:
        # Only the most recent pick can lower a candidate's minimum distance.
        sims = bulk_tanimoto_to_many(fps[chosen[-1]], [fps[j] for j in remaining])
        for j, s in zip(remaining, sims):
            d = 1.0 - s
            if d < min_dist[j]:
                min_dist[j] = d
        best_j = max(remaining, key=min_dist.__getitem__)
        chosen.append(best_j)
        remaining.remove(best_j)
    return chosen


def select_representatives(labels: np.ndarray, fps, per_cluster: int = 1,
                           strategy: str = "medoid") -> List[int]:
    """Select representative molecule indices for every cluster label.

    Every distinct value in ``labels`` is treated as a cluster, including -1
    if present.

    Args:
        labels: Cluster label per fingerprint.
        fps: Fingerprints aligned with ``labels``.
        per_cluster: Number of picks per cluster; only used by ``"maxmin"``
            (``"medoid"`` always yields one per cluster).
        strategy: ``"medoid"`` or ``"maxmin"``.

    Returns:
        Sorted, de-duplicated list of representative indices.

    Raises:
        ValueError: If ``strategy`` is not one of the supported values.
    """
    reps = []
    for c in np.unique(labels):
        members = np.where(labels == c)[0].tolist()
        if not members:
            continue
        if strategy == "medoid":
            reps.append(pick_medoid_indices(members, fps))
        elif strategy == "maxmin":
            reps.extend(pick_maxmin_indices(members, fps, k=per_cluster))
        else:
            raise ValueError("strategy ∈ {'medoid','maxmin'}")
    return sorted(set(reps))


# ---------- Evaluation metrics ----------
@dataclass
class ClusterStats:
    """Size summary of one clustering result."""

    n_samples: int  # number of labelled items
    n_clusters: int  # number of distinct label values (a -1 label counts too)
    largest_cluster_size: int  # size of the most frequent label
    largest_cluster_ratio: float  # largest_cluster_size / n_samples
    sizes: Dict[int, int] = field(default_factory=dict)  # label -> count


def cluster_stats(labels: np.ndarray) -> ClusterStats:
    """Summarise cluster sizes for a label array.

    Args:
        labels: Cluster label per item.

    Returns:
        A `ClusterStats`; all-zero with an empty ``sizes`` for empty input.
    """
    n = len(labels)
    if n == 0:
        # Avoid max() on an empty array and a division by zero.
        return ClusterStats(n_samples=0, n_clusters=0, largest_cluster_size=0,
                            largest_cluster_ratio=0.0, sizes={})
    vals, counts = np.unique(labels, return_counts=True)
    largest = int(counts.max())
    return ClusterStats(
        n_samples=int(n),
        n_clusters=int(len(vals)),
        largest_cluster_size=largest,
        largest_cluster_ratio=float(largest / n),
        sizes={int(v): int(c) for v, c in zip(vals, counts)},
    )


def _run_clustering(fps, method: str, method_kwargs: Dict[str, Any]) -> np.ndarray:
    """Dispatch to the clustering function registered under ``method``."""
    clusterers = {
        "butina": cluster_butina,
        "dbscan_threshold": cluster_dbscan_threshold,
        "agglomerative": cluster_agglomerative_precomputed,
        "scipy_linkage": cluster_scipy_linkage_cut,
    }
    clusterer = clusterers.get(method)
    if clusterer is None:
        raise ValueError("未知 method")
    return clusterer(fps, **method_kwargs)


# ---------- Public API ----------
@dataclass
class TanimotoClusteringAPI:
    """Facade bundling fingerprinting, clustering and exploration.

    Attributes are documented inline; ``params`` is free-form and is filled
    by `search_best_config` with the winning method and its keyword
    arguments.
    """

    fp_cfg: FPConfig = field(default_factory=FPConfig)  # fingerprint settings
    params: dict = field(default_factory=dict)  # optional method configuration

    def fit_from_smiles(self, smiles: List[str], method: str = "butina",
                        method_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Parse, fingerprint and cluster a list of SMILES.

        Args:
            smiles: Input SMILES.
            method: ``"butina"``, ``"dbscan_threshold"``, ``"agglomerative"``
                or ``"scipy_linkage"``.
            method_kwargs: Keyword arguments forwarded to the chosen
                clustering function.

        Returns:
            ``{'labels': np.ndarray, 'keep_idx': List[int], 'fps': list,
            'stats': ClusterStats}`` where ``keep_idx`` maps each label back
            to its position in ``smiles``.

        Raises:
            ValueError: If no SMILES is valid or ``method`` is unknown.
        """
        method_kwargs = method_kwargs or {}
        mols, keep_idx = smiles_to_mols(smiles)
        if len(mols) == 0:
            raise ValueError("没有有效的 SMILES。")

        fps = mols_to_ecfp(mols, self.fp_cfg)
        labels = _run_clustering(fps, method, method_kwargs)
        stats = cluster_stats(labels)
        return {"labels": labels, "keep_idx": keep_idx, "fps": fps, "stats": stats}

    # ChemPlot visualisation (optional)
    def chemplot_visualize(self, smiles_valid: List[str], labels: np.ndarray,
                           target: Optional[pd.Series] = None,
                           target_type: Optional[str] = None,
                           random_state: int = 42):
        """Plot a UMAP embedding with ChemPlot, coloured by cluster label.

        ``chemplot`` is imported lazily so it stays an optional dependency.

        Args:
            smiles_valid: SMILES to plot.
            labels: Cluster labels written into ChemPlot's component table.
            target: Optional target values forwarded to ChemPlot.
            target_type: Optional target type forwarded to ChemPlot.
            random_state: Seed for the UMAP projection.

        Returns:
            ``(plotter, figure)``.
        """
        from chemplot import Plotter

        cp = Plotter.from_smiles(smiles_valid, sim_type="structural",
                                 target=target, target_type=target_type)
        cp.umap(random_state=random_state)
        # NOTE(review): this writes to a name-mangled private attribute of
        # Plotter and presumably requires `labels` to align row-for-row with
        # the molecules ChemPlot kept — confirm against the installed version.
        cp._Plotter__df_2_components["clusters"] = labels
        fig = cp.visualize_plot(clusters=True)
        return cp, fig

    # Grid exploration (for large samples prefer Butina / thresholded DBSCAN)
    def explore(self, smiles: List[str],
                butina_cuts=(0.5, 0.6, 0.7, 0.75, 0.8),
                dbscan_params=((0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5))
                ) -> List[Dict[str, Any]]:
        """Run a small Butina / DBSCAN grid and collect the stats of each run.

        The result makes it easy to pick, for example, the configuration
        with the smallest largest-cluster ratio.

        Args:
            smiles: Input SMILES.
            butina_cuts: Similarity cut-offs to try with Butina.
            dbscan_params: ``(sim_cutoff, eps, min_samples)`` triples to try
                with the thresholded DBSCAN.

        Returns:
            One ``{'method', 'params', 'stats'}`` dict per configuration.

        Raises:
            ValueError: If no SMILES is valid.
        """
        mols, _ = smiles_to_mols(smiles)
        if len(mols) == 0:
            # Same guard and message as fit_from_smiles.
            raise ValueError("没有有效的 SMILES。")
        fps = mols_to_ecfp(mols, self.fp_cfg)

        results = []
        for cut in butina_cuts:
            labels = cluster_butina(fps, sim_cutoff=cut)
            results.append({"method": "butina",
                            "params": {"sim_cutoff": cut},
                            "stats": cluster_stats(labels)})

        for sim_cut, eps, min_samples in dbscan_params:
            labels = cluster_dbscan_threshold(fps, sim_cutoff=sim_cut, eps=eps,
                                              min_samples=min_samples)
            results.append({"method": "dbscan_threshold",
                            "params": {"sim_cutoff": sim_cut, "eps": eps,
                                       "min_samples": min_samples},
                            "stats": cluster_stats(labels)})
        return results


# ---------- Automatic search wrapper (public API unchanged) ----------
@dataclass
class SearchConfig:
    """Search space used by `search_best_config`."""

    # Fingerprint settings
    radii: List[int] = field(default_factory=lambda: [2, 3])
    n_bits_list: List[int] = field(default_factory=lambda: [1024, 2048])
    # Method search space
    butina_cuts: List[float] = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.75, 0.8])
    # (sim_cutoff, eps, min_samples) triples
    dbscan_params: List[Tuple[float, float, int]] = field(
        default_factory=lambda: [(0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5)])
    # Optional methods intended for small samples
    try_agglomerative: bool = True
    agglo_k_list: List[int] = field(default_factory=lambda: [5, 8, 10])
    try_scipy: bool = False
    scipy_t_list: List[float] = field(default_factory=lambda: [0.6, 0.7])
    linkage: str = "average"  # linkage method for agglomerative and SciPy runs


def _score(stats: ClusterStats) -> float:
    """Score a clustering; higher is better.

    The score is ``1 - largest_cluster_ratio`` (more balanced is better)
    plus a 0.1 bonus when there is more than one cluster.
    """
    balance = 1.0 - stats.largest_cluster_ratio
    bonus = 0.1 if stats.n_clusters > 1 else 0.0
    return balance + bonus


def _search_candidates(cfg: SearchConfig, n_valid: int) -> List[Tuple[str, Dict[str, Any]]]:
    """List the ``(method, kwargs)`` runs to try for ``n_valid`` molecules."""
    runs: List[Tuple[str, Dict[str, Any]]] = []

    # 1) Butina
    runs.extend(("butina", {"sim_cutoff": cut}) for cut in cfg.butina_cuts)

    # 2) Thresholded DBSCAN
    runs.extend(
        ("dbscan_threshold", {"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms})
        for sim_cut, eps, ms in cfg.dbscan_params
    )

    # 3) Optional agglomerative clustering. Requests for more clusters than
    #    molecules (or fewer than two molecules) are skipped so that one
    #    infeasible setting does not abort the whole search.
    if cfg.try_agglomerative and n_valid >= 2:
        runs.extend(
            ("agglomerative", {"n_clusters": k, "linkage_method": cfg.linkage})
            for k in cfg.agglo_k_list if k <= n_valid
        )

    # 4) Optional SciPy linkage; needs SciPy and at least one pairwise distance.
    if cfg.try_scipy and linkage is not None and n_valid >= 2:
        runs.extend(
            ("scipy_linkage", {"t": t, "method": cfg.linkage, "criterion": "distance"})
            for t in cfg.scipy_t_list
        )
    return runs


def search_best_config(smiles: List[str], cfg: Optional[SearchConfig] = None
                       ) -> Tuple[TanimotoClusteringAPI, ClusterStats, pd.DataFrame]:
    """Grid-search fingerprint settings and clustering methods.

    Every combination of fingerprint radius / bit length is tried with
    Butina, the thresholded DBSCAN and, when enabled, agglomerative
    clustering and a SciPy linkage cut. Runs are ranked by `_score`.

    SMILES are parsed once and fingerprints are computed once per
    fingerprint configuration, then reused by every method.

    Args:
        smiles: Input SMILES.
        cfg: Search space; defaults to ``SearchConfig()``.

    Returns:
        ``(api, stats, history)``: an API whose ``fp_cfg`` matches the best
        run and whose ``params`` holds ``{'method', 'method_kwargs'}`` for
        it, the `ClusterStats` of that run, and a DataFrame of all runs
        sorted by descending score (ties keep trial order).

    Raises:
        ValueError: If no SMILES is valid or the search space yields no run.
    """
    cfg = cfg or SearchConfig()

    mols, _ = smiles_to_mols(smiles)
    if not mols:
        raise ValueError("没有有效的 SMILES。")
    n_valid = len(mols)

    trials: List[Dict[str, Any]] = []
    for r in cfg.radii:
        for nb in cfg.n_bits_list:
            fps = mols_to_ecfp(mols, FPConfig(radius=r, n_bits=nb))
            for method, kwargs in _search_candidates(cfg, n_valid):
                stats = cluster_stats(_run_clustering(fps, method, kwargs))
                trials.append({"fp_radius": r, "fp_n_bits": nb,
                               "method": method, "params": kwargs,
                               "stats": stats, "score": _score(stats)})

    if not trials:
        raise ValueError("search space is empty: no configuration to evaluate")

    # sorted() is stable, so equally scored runs keep their trial order.
    ranked = sorted(trials, key=lambda t: t["score"], reverse=True)
    hist = pd.DataFrame([
        {
            "fp_radius": t["fp_radius"],
            "fp_n_bits": t["fp_n_bits"],
            "method": t["method"],
            "params": t["params"],
            "n_samples": t["stats"].n_samples,
            "n_clusters": t["stats"].n_clusters,
            "largest_cluster_ratio": t["stats"].largest_cluster_ratio,
            "sizes": t["stats"].sizes,
            "score": t["score"],
        }
        for t in ranked
    ])

    best = ranked[0]
    api_best = TanimotoClusteringAPI(
        fp_cfg=FPConfig(radius=int(best["fp_radius"]), n_bits=int(best["fp_n_bits"])),
        params={"method": best["method"], "method_kwargs": dict(best["params"])},
    )
    # The clustering paths are deterministic, so the winning trial's stats
    # are returned directly instead of fitting the same configuration again.
    return api_best, best["stats"], hist