# Tanimoto fingerprint clustering utilities.
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
import math
|
||
import random
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Optional, Dict, Any, Tuple
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from rdkit import Chem, RDLogger
|
||
from rdkit.Chem import rdMolDescriptors
|
||
from rdkit import DataStructs
|
||
from rdkit.ML.Cluster import Butina
|
||
|
||
from sklearn.cluster import DBSCAN, AgglomerativeClustering
|
||
|
||
try:
|
||
from scipy.cluster.hierarchy import linkage, fcluster
|
||
except Exception:
|
||
linkage = None
|
||
fcluster = None
|
||
|
||
# Silence RDKit logging
|
||
RDLogger.DisableLog('rdApp.*')
|
||
|
||
|
||
# ---------- Fingerprint & similarity helpers ----------
|
||
@dataclass
class FPConfig:
    """Morgan/ECFP fingerprint settings.

    radius 2 corresponds to ECFP4; n_bits is the folded bit-vector
    length; use_count switches to unfolded count fingerprints (bit
    vectors are the usual choice for Tanimoto similarity).
    """
    radius: int = 2          # ECFP radius (2 == ECFP4)
    n_bits: int = 2048       # folded length (1024 or 2048 are common)
    use_count: bool = False  # True: count FP; False: bit vector
def smiles_to_mols(smiles: List[str]) -> Tuple[List[Chem.Mol], List[int]]:
    """Parse SMILES strings, keeping only the ones RDKit accepts.

    Returns (mols, kept_indices) where kept_indices maps each parsed
    molecule back to its position in the input list.
    """
    parsed: List[Chem.Mol] = []
    kept: List[int] = []
    for position, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(str(smi))
        if mol is None:
            continue  # unparsable SMILES are silently dropped
        parsed.append(mol)
        kept.append(position)
    return parsed, kept
def mols_to_ecfp(mols: List[Chem.Mol], cfg: FPConfig):
    """Compute a Morgan fingerprint for every molecule.

    With cfg.use_count the unfolded count fingerprint is returned
    (cfg.n_bits is not used on that path); otherwise a bit vector
    folded to cfg.n_bits.
    """
    if cfg.use_count:
        return [rdMolDescriptors.GetMorganFingerprint(m, cfg.radius) for m in mols]
    return [
        rdMolDescriptors.GetMorganFingerprintAsBitVect(m, cfg.radius, nBits=cfg.n_bits)
        for m in mols
    ]
def tanimoto(a, b) -> float:
    """Tanimoto similarity between two RDKit fingerprints."""
    sim = DataStructs.TanimotoSimilarity(a, b)
    return sim
def bulk_tanimoto_to_many(fp, fps) -> List[float]:
    """Tanimoto similarity of one fingerprint against many.

    Uses RDKit's bulk routine, which is much faster than looping over
    pairs in Python.
    """
    return DataStructs.BulkTanimotoSimilarity(fp, fps)
# ---------- Stable, scalable clustering (preferred) ----------
|
||
def cluster_butina(fps, sim_cutoff: float = 0.6) -> np.ndarray:
    """Cluster fingerprints with RDKit's Butina algorithm.

    fps: list of RDKit fingerprints
    sim_cutoff: Tanimoto similarity threshold (e.g. 0.6-0.8); pairs at
        least this similar may end up in the same cluster
    Returns an integer label per fingerprint (0..K-1).
    """
    n = len(fps)

    # Condensed lower-triangle distance list, the layout ClusterData expects.
    dists: List[float] = []
    for i in range(1, n):
        row = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1.0 - s for s in row)

    clusters = Butina.ClusterData(dists, n, 1.0 - sim_cutoff, isDistData=True)

    # Convert cluster membership tuples into a flat label array.
    labels = np.full(n, -1, dtype=int)
    for cluster_id, members in enumerate(clusters):
        labels[list(members)] = cluster_id

    # Defensive: promote any point left unassigned to its own cluster
    # (keeps downstream code simple).
    unassigned = np.where(labels == -1)[0]
    if unassigned.size:
        next_id = int(labels.max()) + 1 if labels.max() >= 0 else 0
        for pos in unassigned:
            labels[pos] = next_id
            next_id += 1

    return labels
def cluster_dbscan_threshold(fps, sim_cutoff: float = 0.6, eps: float = 0.5, min_samples: int = 5) -> np.ndarray:
    """Threshold Tanimoto distances, then run DBSCAN on the result.

    Distances above (1 - sim_cutoff) stay clamped at 1.0, so only
    sufficiently similar pairs can count as neighbours. DBSCAN's -1
    noise label is returned as-is; remap it outside if needed.
    """
    n = len(fps)
    dist = np.full((n, n), 1.0, dtype=np.float32)
    np.fill_diagonal(dist, 0.0)
    threshold = 1.0 - sim_cutoff
    for row in range(n):
        sims = bulk_tanimoto_to_many(fps[row], fps)
        d = 1.0 - np.asarray(sims, dtype=np.float32)
        keep = d <= threshold
        dist[row, keep] = d[keep]
    model = DBSCAN(metric='precomputed', eps=eps, min_samples=min_samples, n_jobs=-1)
    return model.fit_predict(dist)
# ---------- Small-data option: full distance matrix ----------
|
||
def distance_matrix_from_fps(fps) -> np.ndarray:
    """Dense symmetric Tanimoto-distance matrix.

    O(n^2) memory and time — intended for small data sets only.
    """
    n = len(fps)
    dm = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        sims = bulk_tanimoto_to_many(fps[i], fps[i + 1:])
        if not sims:
            continue
        row = 1.0 - np.asarray(sims, dtype=np.float32)
        dm[i, i + 1:] = row
        dm[i + 1:, i] = row  # mirror into the lower triangle
    return dm
def cluster_agglomerative_precomputed(fps, n_clusters=5, linkage_method='average') -> np.ndarray:
    """Agglomerative clustering on a precomputed Tanimoto-distance matrix."""
    model = AgglomerativeClustering(
        n_clusters=n_clusters, metric='precomputed', linkage=linkage_method
    )
    return model.fit_predict(distance_matrix_from_fps(fps))
def cluster_scipy_linkage_cut(fps, method='average', t=0.7, criterion='distance') -> np.ndarray:
    """Hierarchical clustering via scipy linkage, cut with fcluster.

    Raises RuntimeError when scipy is not installed.
    """
    if linkage is None:
        raise RuntimeError("scipy 未安装,无法使用 linkage。")
    dm = distance_matrix_from_fps(fps)
    # Extract the condensed upper triangle in the order linkage expects.
    upper = np.triu_indices(dm.shape[0], k=1)
    Z = linkage(dm[upper], method=method)
    # fcluster labels start at 1; shift to 0-based like the other methods.
    return fcluster(Z, t=t, criterion=criterion) - 1
# ---------- Representative molecules ----------
|
||
def pick_medoid_indices(idx: List[int], fps) -> int:
    """Return the medoid of the cluster given by idx.

    All-pairs computation, so suitable only for modestly sized
    clusters. The medoid minimizes the summed Tanimoto distance to all
    other members.
    """
    if len(idx) == 1:
        return idx[0]
    totals = []
    for i in idx:
        sims = np.array([tanimoto(fps[i], fps[j]) for j in idx])
        totals.append(float(np.sum(1.0 - sims)))
    return idx[int(np.argmin(totals))]
def pick_maxmin_indices(idx: List[int], fps, k=1) -> List[int]:
    """Greedy MaxMin diversity pick of up to k members from idx.

    The first member of idx seeds the selection; each round adds the
    candidate whose distance to its nearest already-chosen member is
    largest.
    """
    if not idx:
        return []
    chosen = [idx[0]]
    target = min(k, len(idx))
    while len(chosen) < target:
        best_candidate, best_score = None, -1.0
        for candidate in idx:
            if candidate in chosen:
                continue
            # Distance from candidate to its nearest already-chosen member,
            # capped at 1.0 (Tanimoto distance range).
            nearest = min(1.0 - tanimoto(fps[c], fps[candidate]) for c in chosen)
            nearest = min(nearest, 1.0)
            if nearest > best_score:
                best_score, best_candidate = nearest, candidate
        chosen.append(best_candidate)
    return chosen
||
def select_representatives(labels: np.ndarray, fps, per_cluster: int = 1,
                           strategy: str = "medoid") -> List[int]:
    """Pick representative fingerprint indices, one group per cluster.

    strategy "medoid": the cluster medoid (per_cluster is ignored);
    strategy "maxmin": up to per_cluster diverse members via MaxMin.
    Returns sorted, de-duplicated global indices.
    Raises ValueError for an unknown strategy.
    """
    picked: List[int] = []
    for cluster_id in np.unique(labels):
        members = np.where(labels == cluster_id)[0].tolist()
        if not members:
            continue
        if strategy == "medoid":
            picked.append(pick_medoid_indices(members, fps))
        elif strategy == "maxmin":
            picked.extend(pick_maxmin_indices(members, fps, k=per_cluster))
        else:
            raise ValueError("strategy ∈ {'medoid','maxmin'}")
    return sorted(set(picked))
# ---------- Evaluation metrics ----------
|
||
@dataclass
class ClusterStats:
    """Summary of one clustering result.

    largest_cluster_ratio == largest_cluster_size / n_samples;
    sizes maps cluster label -> member count.
    """
    n_samples: int
    n_clusters: int
    largest_cluster_size: int
    largest_cluster_ratio: float
    sizes: Dict[int, int] = field(default_factory=dict)
def cluster_stats(labels: np.ndarray) -> ClusterStats:
    """Summarize an array of cluster labels into a ClusterStats record.

    labels: one integer label per sample (a noise label such as -1
        counts as its own cluster here).
    Returns a ClusterStats; an empty input yields all-zero stats
    instead of raising (previously `counts.max()` on an empty array
    crashed, followed by a division by zero).
    """
    n = len(labels)
    if n == 0:
        # Robustness fix: handle empty label arrays gracefully.
        return ClusterStats(n_samples=0, n_clusters=0,
                            largest_cluster_size=0,
                            largest_cluster_ratio=0.0, sizes={})
    vals, counts = np.unique(labels, return_counts=True)
    largest = int(counts.max())
    return ClusterStats(
        n_samples=int(n),
        n_clusters=int(len(vals)),
        largest_cluster_size=largest,
        largest_cluster_ratio=float(largest / n),
        sizes={int(v): int(c) for v, c in zip(vals, counts)}
    )
# ---------- Public API ----------
|
||
@dataclass
class TanimotoClusteringAPI:
    """Facade: SMILES -> ECFP fingerprints -> Tanimoto clustering."""
    fp_cfg: FPConfig = field(default_factory=FPConfig)
    params: dict = field(default_factory=dict)

    def fit_from_smiles(self, smiles: List[str], method: str = "butina",
                        method_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Cluster a list of SMILES strings.

        method: "butina" | "dbscan_threshold" | "agglomerative" |
            "scipy_linkage"; method_kwargs are forwarded verbatim.
        Returns {'labels', 'keep_idx', 'fps', 'stats'}; keep_idx maps
        each fingerprint back to its position in the input list.
        Raises ValueError when no SMILES parse or method is unknown.
        """
        kwargs = method_kwargs or {}
        mols, keep_idx = smiles_to_mols(smiles)
        if not mols:
            raise ValueError("没有有效的 SMILES。")
        fps = mols_to_ecfp(mols, self.fp_cfg)

        dispatch = {
            "butina": cluster_butina,
            "dbscan_threshold": cluster_dbscan_threshold,
            "agglomerative": cluster_agglomerative_precomputed,
            "scipy_linkage": cluster_scipy_linkage_cut,
        }
        if method not in dispatch:
            raise ValueError("未知 method")
        labels = dispatch[method](fps, **kwargs)

        return {"labels": labels, "keep_idx": keep_idx, "fps": fps,
                "stats": cluster_stats(labels)}

    # Optional ChemPlot visualization
    def chemplot_visualize(self, smiles_valid: List[str], labels: np.ndarray,
                           target: Optional[pd.Series] = None,
                           target_type: Optional[str] = None,
                           random_state: int = 42):
        """UMAP plot via ChemPlot, colored by cluster labels.

        NOTE(review): injects labels into ChemPlot's private,
        name-mangled 2-component frame before plotting — fragile
        across chemplot versions.
        """
        from chemplot import Plotter
        import matplotlib.pyplot as plt  # noqa: F401 - chemplot draws via pyplot
        plotter = Plotter.from_smiles(smiles_valid, sim_type="structural",
                                      target=target, target_type=target_type)
        plotter.umap(random_state=random_state)
        plotter._Plotter__df_2_components["clusters"] = labels
        figure = plotter.visualize_plot(clusters=True)
        return plotter, figure

    # Parameter grid exploration (for big data prefer Butina / sparse DBSCAN)
    def explore(self, smiles: List[str],
                butina_cuts=(0.5, 0.6, 0.7, 0.75, 0.8),
                dbscan_params=((0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5))
                ) -> List[Dict[str, Any]]:
        """Grid-scan Butina cutoffs and sparse-DBSCAN settings.

        Returns one stats record per configuration, so the caller can
        pick e.g. the run whose largest-cluster ratio is smallest.
        """
        mols, _ = smiles_to_mols(smiles)
        fps = mols_to_ecfp(mols, self.fp_cfg)

        records: List[Dict[str, Any]] = []
        for cut in butina_cuts:
            stats = cluster_stats(cluster_butina(fps, sim_cutoff=cut))
            records.append({"method": "butina",
                            "params": {"sim_cutoff": cut},
                            "stats": stats})

        for sim_cut, eps, min_samples in dbscan_params:
            labels = cluster_dbscan_threshold(fps, sim_cutoff=sim_cut,
                                              eps=eps, min_samples=min_samples)
            records.append({"method": "dbscan_threshold",
                            "params": {"sim_cutoff": sim_cut, "eps": eps,
                                       "min_samples": min_samples},
                            "stats": cluster_stats(labels)})
        return records
# ---------- Automatic search wrapper (same public API style) ----------
|
||
@dataclass
class SearchConfig:
    """Search space for search_best_config.

    The fingerprint grid (radii x n_bits_list) is crossed with every
    clustering configuration below.
    """
    # fingerprint grid
    radii: List[int] = field(default_factory=lambda: [2, 3])
    n_bits_list: List[int] = field(default_factory=lambda: [1024, 2048])
    # clustering search space
    butina_cuts: List[float] = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.75, 0.8])
    dbscan_params: List[Tuple[float, float, int]] = field(
        default_factory=lambda: [(0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5)])
    # optional extras for small data sets
    try_agglomerative: bool = True
    agglo_k_list: List[int] = field(default_factory=lambda: [5, 8, 10])
    try_scipy: bool = False
    scipy_t_list: List[float] = field(default_factory=lambda: [0.6, 0.7])
    linkage: str = "average"
def _score(stats: ClusterStats) -> float:
|
||
# 越大越好:更均衡(最大簇比例小)+ 簇数量>1的奖励
|
||
balance = 1.0 - stats.largest_cluster_ratio
|
||
bonus = 0.1 if stats.n_clusters > 1 else 0.0
|
||
return balance + bonus
|
||
|
||
def search_best_config(smiles: List[str], cfg: Optional[SearchConfig] = None) -> Tuple[TanimotoClusteringAPI, ClusterStats, pd.DataFrame]:
    """Grid-search fingerprint settings x clustering methods.

    Every (radius, n_bits) pair is tried with Butina, sparse DBSCAN
    and, optionally, agglomerative / scipy linkage. Runs are scored by
    _score (mainly 1 - largest-cluster ratio). Returns the best,
    already-configured API (fitted once so labels/keep_idx are fresh),
    its ClusterStats, and a DataFrame of all trials sorted by score.
    """
    cfg = cfg or SearchConfig()
    trials: List[Dict[str, Any]] = []

    def record(api_obj, radius, bits, method, params):
        # Run a single configuration and append its trial record.
        res = api_obj.fit_from_smiles(smiles, method=method, method_kwargs=params)
        trials.append({"fp_radius": radius, "fp_n_bits": bits,
                       "method": method, "params": params,
                       "stats": res["stats"], "score": _score(res["stats"])})

    for r in cfg.radii:
        for nb in cfg.n_bits_list:
            api = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=r, n_bits=nb))
            # 1) Butina
            for cut in cfg.butina_cuts:
                record(api, r, nb, "butina", {"sim_cutoff": cut})
            # 2) sparse DBSCAN
            for sim_cut, eps, ms in cfg.dbscan_params:
                record(api, r, nb, "dbscan_threshold",
                       {"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms})
            # 3) optional agglomerative (small data / exploration)
            if cfg.try_agglomerative:
                for k in cfg.agglo_k_list:
                    record(api, r, nb, "agglomerative",
                           {"n_clusters": k, "linkage_method": cfg.linkage})
            # 4) optional scipy linkage (small data & scipy installed)
            if cfg.try_scipy and linkage is not None:
                for t in cfg.scipy_t_list:
                    record(api, r, nb, "scipy_linkage",
                           {"t": t, "method": cfg.linkage, "criterion": "distance"})

    # Flatten trials and rank by score, best first.
    rows = [{
        "fp_radius": trial["fp_radius"],
        "fp_n_bits": trial["fp_n_bits"],
        "method": trial["method"],
        "params": trial["params"],
        "n_samples": trial["stats"].n_samples,
        "n_clusters": trial["stats"].n_clusters,
        "largest_cluster_ratio": trial["stats"].largest_cluster_ratio,
        "sizes": trial["stats"].sizes,
        "score": trial["score"],
    } for trial in trials]
    hist = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)

    best = hist.iloc[0]
    # Assemble a ready-to-use API with the winning fingerprint settings.
    api_best = TanimotoClusteringAPI(
        fp_cfg=FPConfig(radius=int(best["fp_radius"]), n_bits=int(best["fp_n_bits"])))
    # Fit once so the caller immediately has labels / keep_idx available.
    res = api_best.fit_from_smiles(smiles, method=best["method"], method_kwargs=best["params"])
    return api_best, res["stats"], hist