聚类方法,聚类后选择打分最高那个分子,并对 karamadock 的结果求交集
This commit is contained in:
1001
result/top_molecules/fgbar_karma_score_aligned_top1000.csv
Normal file
1001
result/top_molecules/fgbar_karma_score_aligned_top1000.csv
Normal file
File diff suppressed because it is too large
Load Diff
1001
result/top_molecules/fgbar_vina_score_top1000.csv
Normal file
1001
result/top_molecules/fgbar_vina_score_top1000.csv
Normal file
File diff suppressed because it is too large
Load Diff
1001
result/top_molecules/trpe_karma_score_aligned_top1000.csv
Normal file
1001
result/top_molecules/trpe_karma_score_aligned_top1000.csv
Normal file
File diff suppressed because it is too large
Load Diff
1001
result/top_molecules/trpe_vina_score_top1000.csv
Normal file
1001
result/top_molecules/trpe_vina_score_top1000.csv
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,40 @@ python scripts/cluster_granularity_scan.py \
|
||||
--smiles-col smiles \
|
||||
--radius 3 \
|
||||
--n-bits 1024
|
||||
|
||||
Method Params #Clusters AvgSize AvgIntraSim
|
||||
Butina {'cutoff': 0.4} 8960 2.10 0.706
|
||||
Butina {'cutoff': 0.5} 6720 2.80 0.625
|
||||
Butina {'cutoff': 0.6} 4648 4.04 0.548
|
||||
Butina {'cutoff': 0.7} 2783 6.75 0.463
|
||||
Butina {'cutoff': 0.8} 958 19.61 0.333
|
||||
Hierarchical {'threshold': 0.3} 12235 1.54 0.814
|
||||
Hierarchical {'threshold': 0.4} 9603 1.96 0.739
|
||||
Hierarchical {'threshold': 0.5} 7300 2.57 0.664
|
||||
DBSCAN {'eps': 0.2} 2050 3.18 0.106
|
||||
DBSCAN {'eps': 0.3} 2275 4.61 0.113
|
||||
DBSCAN {'eps': 0.4} 2014 6.65 0.113
|
||||
KMeans {'k': 10} 10 1878.70 0.204
|
||||
KMeans {'k': 20} 20 939.35 0.200
|
||||
KMeans {'k': 50} 50 375.74 0.233
|
||||
| 列名 | 含义 |
|
||||
| --------------- | --------------------------------- |
|
||||
| **#Clusters** | 聚类后得到的簇数量(独立 cluster 数) |
|
||||
| **AvgSize** | 每个簇平均包含的分子个数 = 样本总数 / 簇数 |
|
||||
| **AvgIntraSim** | 每个簇内部分子两两之间的平均相似度(越接近 1 代表簇内部更相似) |
|
||||
现在的数据:
|
||||
|
||||
Butina 在 cutoff=0.4 时 AvgIntraSim=0.706(簇内结构还算比较接近,但簇数非常多)。
|
||||
|
||||
Hierarchical 阈值 0.3 时 AvgIntraSim=0.814(更紧密,但簇数更多)。
|
||||
|
||||
DBSCAN 和 KMeans 的簇内相似度都低,说明它们在 Tanimoto 上可能不太适合你这个任务。
|
||||
|
||||
聚类:用 Butina cutoff ≈ 0.6–0.7 或 Hierarchical 阈值 ≈ 0.5–0.6(保持簇内差异可控,簇数不要太多)。
|
||||
|
||||
选代表:每个簇取 1 个中心分子(簇内与其他成员平均相似度最高的那个)。
|
||||
|
||||
如果仍想增强多样性,可以在代表集中再跑一次 MaxMin picking。
|
||||
"""
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -11,20 +11,20 @@ print("Running analysis examples...")
|
||||
|
||||
# Example 1: Basic usage
|
||||
print("\nExample 1: Basic usage")
|
||||
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'])
|
||||
main_api(['finally_data/qed_values_poses_fgbar_all.csv', 'finally_data/qed_values_poses_trpe_all.csv'], ['fgbar', 'trpe'])
|
||||
|
||||
# Example 2: With custom reference scores
|
||||
print("\nExample 2: With custom reference scores")
|
||||
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'],
|
||||
main_api(['finally_data/qed_values_poses_fgbar_all.csv', 'finally_data/qed_values_poses_trpe_all.csv'], ['fgbar', 'trpe'],
|
||||
reference_scores={'fgbar': {'9NY': -5.268}, 'trpe': {'0GA': -6.531}})
|
||||
|
||||
# Example 3: With specific conformation rank
|
||||
print("\nExample 3: With specific conformation rank")
|
||||
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'], rank=0)
|
||||
main_api(['finally_data/qed_values_poses_fgbar_all.csv', 'finally_data/qed_values_poses_trpe_all.csv'], ['fgbar', 'trpe'], rank=0)
|
||||
|
||||
# Example 4: With both custom reference scores and specific conformation rank
|
||||
print("\nExample 4: With both custom reference scores and specific conformation rank")
|
||||
main_api(['qed_values_fgbar.csv', 'qed_values_trpe.csv'], ['fgbar', 'trpe'],
|
||||
main_api(['finally_data/qed_values_poses_fgbar_all.csv', 'finally_data/qed_values_poses_trpe_all.csv'], ['fgbar', 'trpe'],
|
||||
reference_scores={'fgbar': {'9NY': -5.268}, 'trpe': {'0GA': -6.531}}, rank=0)
|
||||
|
||||
print("\nAnalysis complete! Check the generated PNG files.")
|
||||
59
scripts/extract_and_intersect.py
Normal file
59
scripts/extract_and_intersect.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
def process_cluster_file(cluster_file, score_file, output_file):
    """Intersect a cluster-representative table with a docking-score table.

    The cluster file's ``filename`` column is reduced to its stem (the part
    of the base name before ``'_out'``) and matched against the score file's
    ``pdb_id`` column with an inner merge.  The merged rows are written to
    ``output_file``.

    Returns:
        int: the number of matched (intersection) rows.

    Raises:
        FileNotFoundError: when either input file does not exist.
    """
    # Validate inputs up front so the caller gets a precise error message.
    if not os.path.exists(cluster_file):
        raise FileNotFoundError(f"聚类文件不存在: {cluster_file}")
    if not os.path.exists(score_file):
        raise FileNotFoundError(f"评分文件不存在: {score_file}")

    clusters = pd.read_csv(cluster_file)

    # e.g. 'abc123_out.pdbqt' -> 'abc123'
    clusters['filename_stem'] = [
        Path(name).stem.split('_out')[0] for name in clusters['filename']
    ]

    scores = pd.read_csv(score_file)

    # Keep only molecules present in both tables.
    matched = clusters.merge(
        scores,
        left_on='filename_stem',
        right_on='pdb_id',
        how='inner',
    )

    matched.to_csv(output_file, index=False)
    return len(matched)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用绝对路径确保文件位置正确
|
||||
base_dir = "/Users/lingyuzeng/Downloads/211.69.141.180/202508021824/vina"
|
||||
|
||||
# 处理fgbar数据
|
||||
fgbar_count = process_cluster_file(
|
||||
f"{base_dir}/scripts/finally_data/cluster_best/fgbar_cluster_best_vina_butina_butina.csv",
|
||||
f"{base_dir}/result/karamadock/FgBar1_score.csv",
|
||||
f"{base_dir}/scripts/finally_data/cluster_best/fgbar_intersection.csv"
|
||||
)
|
||||
|
||||
# 处理trpe数据
|
||||
trpe_count = process_cluster_file(
|
||||
f"{base_dir}/scripts/finally_data/cluster_best/trpe_cluster_best_vina_butina_butina.csv",
|
||||
f"{base_dir}/result/karamadock/TrpE_score.csv",
|
||||
f"{base_dir}/scripts/finally_data/cluster_best/trpe_intersection.csv"
|
||||
)
|
||||
|
||||
print(f"fgbar交集数量: {fgbar_count}")
|
||||
print(f"trpe交集数量: {trpe_count}")
|
||||
|
||||
# 验证输出文件是否生成
|
||||
print("脚本执行完成")
|
||||
64
scripts/extract_top_molecules.py
Normal file
64
scripts/extract_top_molecules.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
def parse_vina_scores(vina_scores_str):
    """Parse a ``vina_scores`` cell (a Python list literal as a string).

    Returns the first score in the list (the representative pose score),
    or ``None`` when the cell is empty, malformed, or not a non-empty list.
    """
    try:
        scores = ast.literal_eval(vina_scores_str)
    except (ValueError, SyntaxError, TypeError):
        # Malformed literal, or a non-string cell (e.g. NaN): no usable score.
        # Narrowed from a bare `except:` which also swallowed KeyboardInterrupt.
        return None
    if isinstance(scores, list) and scores:
        return scores[0]  # first entry is taken as the molecule's vina_score
    return None
|
||||
|
||||
def extract_top_molecules(file_path, output_dir, dataset_name):
    """Save the top-1000 molecules ranked by karma_score_aligned and by vina_score.

    Reads *file_path*, derives a scalar ``vina_score`` column from the
    list-valued ``vina_scores`` column, writes one CSV per ranking into
    *output_dir*, and returns the two top-1000 DataFrames.
    """
    df = pd.read_csv(file_path)

    # Collapse the 'vina_scores' list literal into a single score per row.
    df['vina_score'] = df['vina_scores'].apply(parse_vina_scores)

    # NOTE(review): ascending=False keeps the numerically largest values;
    # confirm "higher is better" holds for both metrics (raw vina binding
    # energies are usually more negative = better).
    top_by_karma = df.sort_values('karma_score_aligned', ascending=False).head(1000)
    top_by_vina = df.sort_values('vina_score', ascending=False).head(1000)

    karma_output_file = os.path.join(output_dir, f"{dataset_name}_karma_score_aligned_top1000.csv")
    vina_output_file = os.path.join(output_dir, f"{dataset_name}_vina_score_top1000.csv")

    top_by_karma.to_csv(karma_output_file, index=False)
    top_by_vina.to_csv(vina_output_file, index=False)

    print(f"{dataset_name} - karma_score_aligned前1000分子保存到: {karma_output_file}")
    print(f"{dataset_name} - vina_score前1000分子保存到: {vina_output_file}")
    print(f"{dataset_name} - karma_score_aligned前1000分子数量: {len(top_by_karma)}")
    print(f"{dataset_name} - vina_score前1000分子数量: {len(top_by_vina)}")

    return top_by_karma, top_by_vina
|
||||
|
||||
def main():
    """CLI entry point: extract the top-1000 tables for each input dataset."""
    parser = argparse.ArgumentParser(description='从CSV文件中提取karma_score_aligned和vina_score前1000的分子')
    parser.add_argument('--input', nargs='+', required=True,
                        help='输入CSV文件路径列表')
    parser.add_argument('--dataset-names', nargs='+', required=True,
                        help='数据集名称列表,与输入文件一一对应')
    parser.add_argument('--output', required=True,
                        help='输出目录')
    args = parser.parse_args()

    # Create the output directory when missing.
    os.makedirs(args.output, exist_ok=True)

    # Inputs and dataset names are paired positionally.
    for csv_path, name in zip(args.input, args.dataset_names):
        print(f"Processing {name}...")
        extract_top_molecules(csv_path, args.output, name)


if __name__ == "__main__":
    main()
|
||||
4777
scripts/finally_data/cluster_best/fgbar_intersection.csv
Normal file
4777
scripts/finally_data/cluster_best/fgbar_intersection.csv
Normal file
File diff suppressed because it is too large
Load Diff
8693
scripts/finally_data/cluster_best/trpe_intersection.csv
Normal file
8693
scripts/finally_data/cluster_best/trpe_intersection.csv
Normal file
File diff suppressed because it is too large
Load Diff
27
utils/chem_cluster/__init__.py
Normal file
27
utils/chem_cluster/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# utils/chem_cluster/__init__.py
# Public surface of the Tanimoto clustering utilities.
# ``TanimotoClusteringAPI`` is re-exported under the shorter alias
# ``TanimotoClusterer``.

from .tanimoto_cluster_api import (
    TanimotoClusteringAPI as TanimotoClusterer,
    SearchConfig,
    search_best_config,
    ClusterStats,
    FPConfig,
    cluster_butina,
    cluster_dbscan_threshold,
    cluster_agglomerative_precomputed,
    cluster_scipy_linkage_cut,
    select_representatives,
)

# Names exported by ``from utils.chem_cluster import *``.
__all__ = [
    "TanimotoClusterer",
    "SearchConfig",
    "search_best_config",
    "ClusterStats",
    "FPConfig",
    "cluster_butina",
    "cluster_dbscan_threshold",
    "cluster_agglomerative_precomputed",
    "cluster_scipy_linkage_cut",
    "select_representatives",
]
|
||||
384
utils/chem_cluster/tanimoto_cluster_api.py
Normal file
384
utils/chem_cluster/tanimoto_cluster_api.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from rdkit import Chem, RDLogger
|
||||
from rdkit.Chem import rdMolDescriptors
|
||||
from rdkit import DataStructs
|
||||
from rdkit.ML.Cluster import Butina
|
||||
|
||||
from sklearn.cluster import DBSCAN, AgglomerativeClustering
|
||||
|
||||
try:
|
||||
from scipy.cluster.hierarchy import linkage, fcluster
|
||||
except Exception:
|
||||
linkage = None
|
||||
fcluster = None
|
||||
|
||||
# 静音 RDKit
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
|
||||
|
||||
# ---------- Fingerprint & similarity helpers ----------
@dataclass
class FPConfig:
    """Morgan/ECFP fingerprint configuration."""
    radius: int = 2       # ECFP radius (2 = ECFP4)
    n_bits: int = 2048    # bit-vector length (1024 or 2048 are common)
    use_count: bool = False  # True: count fingerprint; False: bit vector (recommended for Tanimoto)
|
||||
|
||||
|
||||
def smiles_to_mols(smiles: List[str]) -> Tuple[List[Chem.Mol], List[int]]:
    """Parse SMILES strings, silently dropping invalid entries.

    Returns:
        (mols, kept): parsed molecules and the positions (into the input
        list) of the entries that parsed successfully.
    """
    mols: List[Chem.Mol] = []
    kept: List[int] = []
    for position, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(str(smi))
        if mol is None:
            continue  # unparseable SMILES are skipped, not raised
        mols.append(mol)
        kept.append(position)
    return mols, kept
|
||||
|
||||
|
||||
def mols_to_ecfp(mols: List[Chem.Mol], cfg: FPConfig):
    """Compute Morgan fingerprints for *mols* according to *cfg*.

    Count fingerprints when ``cfg.use_count`` is set, otherwise fixed-length
    bit vectors of ``cfg.n_bits`` bits (the usual choice for Tanimoto).
    """
    if cfg.use_count:
        return [rdMolDescriptors.GetMorganFingerprint(m, cfg.radius) for m in mols]
    return [
        rdMolDescriptors.GetMorganFingerprintAsBitVect(m, cfg.radius, nBits=cfg.n_bits)
        for m in mols
    ]
|
||||
|
||||
|
||||
def tanimoto(a, b) -> float:
    """Tanimoto similarity between two RDKit fingerprints."""
    return DataStructs.TanimotoSimilarity(a, b)
|
||||
|
||||
|
||||
def bulk_tanimoto_to_many(fp, fps) -> List[float]:
    """Tanimoto similarity of *fp* against each fingerprint in *fps*."""
    # RDKit's bulk routine is much faster than many pairwise calls.
    return DataStructs.BulkTanimotoSimilarity(fp, fps)
|
||||
|
||||
|
||||
# ---------- Robust, scalable clustering (preferred) ----------
def cluster_butina(fps, sim_cutoff: float = 0.6) -> np.ndarray:
    """Cluster fingerprints with RDKit's Butina algorithm.

    Args:
        fps: list of RDKit fingerprints.
        sim_cutoff: Tanimoto similarity threshold (e.g. 0.6-0.8); pairs at
            least this similar may share a cluster.

    Returns:
        np.ndarray: integer cluster labels in ``0..K-1``, one per fingerprint.
    """
    n = len(fps)

    # Condensed lower-triangle distance list in the order Butina.ClusterData
    # expects: row i vs. columns 0..i-1.
    dists: list = []
    for i in range(1, n):
        row_sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1.0 - s for s in row_sims)

    clusters = Butina.ClusterData(dists, n, 1.0 - sim_cutoff, isDistData=True)

    # Flatten the cluster member tuples into a label array.
    labels = np.full(n, -1, dtype=int)
    for cluster_id, members in enumerate(clusters):
        labels[list(members)] = cluster_id

    # Defensive: promote any unassigned point to its own singleton cluster
    # so downstream code never sees -1.
    unassigned = np.where(labels == -1)[0]
    if unassigned.size:
        next_id = int(labels.max()) + 1 if labels.max() >= 0 else 0
        for offset, point in enumerate(unassigned):
            labels[point] = next_id + offset

    return labels
|
||||
|
||||
|
||||
def cluster_dbscan_threshold(fps, sim_cutoff: float = 0.6, eps: float = 0.5, min_samples: int = 5) -> np.ndarray:
    """DBSCAN on a similarity-thresholded Tanimoto distance matrix.

    Distances above ``1 - sim_cutoff`` are clamped to 1.0, so only pairs at
    least ``sim_cutoff`` similar look "near" to DBSCAN
    (``metric='precomputed'``).  Noise points keep label -1; callers may
    re-label them as singleton clusters if desired.
    """
    n = len(fps)
    max_dist = 1.0 - sim_cutoff

    dist_matrix = np.full((n, n), 1.0, dtype=np.float32)
    np.fill_diagonal(dist_matrix, 0.0)
    for row in range(n):
        sims = bulk_tanimoto_to_many(fps[row], fps)
        dists = 1.0 - np.asarray(sims, dtype=np.float32)
        keep = dists <= max_dist
        dist_matrix[row, keep] = dists[keep]

    model = DBSCAN(metric='precomputed', eps=eps, min_samples=min_samples, n_jobs=-1)
    return model.fit_predict(dist_matrix)
|
||||
|
||||
|
||||
# ---------- Optional for small data: full distance matrix ----------
def distance_matrix_from_fps(fps) -> np.ndarray:
    """Full symmetric Tanimoto-distance matrix (O(n^2) memory; small n only)."""
    n = len(fps)
    dm = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        # Only compute the upper triangle and mirror it.
        upper = bulk_tanimoto_to_many(fps[i], fps[i + 1:])
        if not upper:
            continue
        row = 1.0 - np.asarray(upper, dtype=np.float32)
        dm[i, i + 1:] = row
        dm[i + 1:, i] = row
    return dm
|
||||
|
||||
|
||||
def cluster_agglomerative_precomputed(fps, n_clusters=5, linkage_method='average') -> np.ndarray:
    """Agglomerative clustering on the precomputed Tanimoto distance matrix."""
    model = AgglomerativeClustering(
        n_clusters=n_clusters, metric='precomputed', linkage=linkage_method
    )
    return model.fit_predict(distance_matrix_from_fps(fps))
|
||||
|
||||
|
||||
def cluster_scipy_linkage_cut(fps, method='average', t=0.7, criterion='distance') -> np.ndarray:
    """Hierarchical clustering via SciPy linkage, cut at threshold *t*.

    Raises:
        RuntimeError: when SciPy is not installed.
    """
    if linkage is None:
        raise RuntimeError("scipy 未安装,无法使用 linkage。")
    dm = distance_matrix_from_fps(fps)
    # SciPy wants the condensed (upper-triangle) form of the square matrix.
    rows, cols = np.triu_indices(dm.shape[0], k=1)
    Z = linkage(dm[rows, cols], method=method)
    # fcluster labels start at 1; shift to 0-based for consistency.
    return fcluster(Z, t=t, criterion=criterion) - 1
|
||||
|
||||
|
||||
# ---------- Representative molecules ----------
def pick_medoid_indices(idx: List[int], fps) -> int:
    """Return the medoid of the cluster whose member indices are *idx*.

    The medoid is the member with the smallest total Tanimoto distance to
    the others (all-pairs, so intended for modest cluster sizes).  The
    self-distance (always 0) is included for every candidate, which does
    not affect the argmin.
    """
    if len(idx) == 1:
        return idx[0]
    totals = []
    for candidate in idx:
        sims = np.array([tanimoto(fps[candidate], fps[other]) for other in idx])
        totals.append(float(np.sum(1.0 - sims)))
    return idx[int(np.argmin(totals))]
|
||||
|
||||
|
||||
def pick_maxmin_indices(idx: List[int], fps, k=1) -> List[int]:
    """Pick up to *k* maximally diverse members of *idx* (MaxMin picking).

    Greedy: start from the first member, then repeatedly add the candidate
    whose minimum distance to the already-chosen set is largest.  Returns
    fewer than *k* indices when the cluster is smaller than *k*; empty
    input yields an empty list.
    """
    if not idx:
        return []
    chosen: List[int] = [idx[0]]
    chosen_set = {idx[0]}  # O(1) membership instead of a list scan per candidate
    target = min(k, len(idx))
    while len(chosen) < target:
        best_j, best_min_d = None, -1.0
        for j in idx:
            if j in chosen_set:
                continue
            # Distance from j to its nearest already-chosen member
            # (Tanimoto distance is in [0, 1], so min() needs no seed).
            mind = min(1.0 - tanimoto(fps[i], fps[j]) for i in chosen)
            if mind > best_min_d:
                best_min_d, best_j = mind, j
        chosen.append(best_j)
        chosen_set.add(best_j)
    return chosen
|
||||
|
||||
|
||||
def select_representatives(labels: np.ndarray, fps, per_cluster: int = 1,
                           strategy: str = "medoid") -> List[int]:
    """Choose representative fingerprint indices for each cluster.

    strategy 'medoid': one medoid per cluster (``per_cluster`` is ignored);
    strategy 'maxmin': up to ``per_cluster`` diverse members per cluster.
    Returns sorted, de-duplicated indices into *fps*.
    """
    picked: List[int] = []
    for cluster_id in np.unique(labels):
        members = np.flatnonzero(labels == cluster_id).tolist()
        if not members:
            continue
        if strategy == "medoid":
            picked.append(pick_medoid_indices(members, fps))
        elif strategy == "maxmin":
            picked.extend(pick_maxmin_indices(members, fps, k=per_cluster))
        else:
            raise ValueError("strategy ∈ {'medoid','maxmin'}")
    return sorted(set(picked))
|
||||
|
||||
|
||||
# ---------- Evaluation metrics ----------
@dataclass
class ClusterStats:
    """Summary statistics of one clustering result."""
    n_samples: int                # number of labelled samples
    n_clusters: int               # number of distinct cluster labels
    largest_cluster_size: int     # size of the biggest cluster
    largest_cluster_ratio: float  # largest_cluster_size / n_samples
    sizes: Dict[int, int] = field(default_factory=dict)  # label -> member count
|
||||
|
||||
|
||||
def cluster_stats(labels: np.ndarray) -> ClusterStats:
    """Summarise a label array into a ``ClusterStats`` record."""
    total = len(labels)
    unique_labels, counts = np.unique(labels, return_counts=True)
    biggest = int(counts.max())
    return ClusterStats(
        n_samples=int(total),
        n_clusters=int(unique_labels.size),
        largest_cluster_size=biggest,
        largest_cluster_ratio=float(biggest / total),
        sizes=dict(zip((int(v) for v in unique_labels),
                       (int(c) for c in counts))),
    )
|
||||
|
||||
|
||||
# ---------- Public API ----------
@dataclass
class TanimotoClusteringAPI:
    """Facade combining Morgan fingerprints with the clustering backends.

    Attributes (documented below each field):
    """
    fp_cfg: FPConfig = field(default_factory=FPConfig)  # fingerprint settings used by every method
    params: dict = field(default_factory=dict)  # free-form extras; not read by this class

    def fit_from_smiles(self, smiles: List[str], method: str = "butina",
                        method_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Fingerprint *smiles* and cluster them with the chosen backend.

        Args:
            smiles: raw SMILES; unparseable entries are dropped.
            method: 'butina' | 'dbscan_threshold' | 'agglomerative' | 'scipy_linkage'.
            method_kwargs: forwarded verbatim to the backend function.

        Returns:
            {'labels': np.ndarray, 'keep_idx': List[int], 'fps': List[FP], 'stats': ClusterStats};
            keep_idx maps fps/labels positions back into *smiles*.

        Raises:
            ValueError: when no SMILES parses or *method* is unknown.
        """
        method_kwargs = method_kwargs or {}
        mols, keep_idx = smiles_to_mols(smiles)
        if len(mols) == 0:
            raise ValueError("没有有效的 SMILES。")
        fps = mols_to_ecfp(mols, self.fp_cfg)

        if method == "butina":
            labels = cluster_butina(fps, **method_kwargs)
        elif method == "dbscan_threshold":
            labels = cluster_dbscan_threshold(fps, **method_kwargs)
        elif method == "agglomerative":
            labels = cluster_agglomerative_precomputed(fps, **method_kwargs)
        elif method == "scipy_linkage":
            labels = cluster_scipy_linkage_cut(fps, **method_kwargs)
        else:
            raise ValueError("未知 method")

        stats = cluster_stats(labels)
        return {"labels": labels, "keep_idx": keep_idx, "fps": fps, "stats": stats}

    # Optional ChemPlot visualisation (requires the 'chemplot' package).
    def chemplot_visualize(self, smiles_valid: List[str], labels: np.ndarray,
                           target: Optional[pd.Series] = None,
                           target_type: Optional[str] = None,
                           random_state: int = 42):
        """UMAP scatter of the molecules coloured by cluster label.

        NOTE(review): this writes into ChemPlot's name-mangled private
        attribute ``_Plotter__df_2_components`` to inject the labels —
        fragile across chemplot versions; verify when upgrading.
        """
        from chemplot import Plotter
        import matplotlib.pyplot as plt  # NOTE(review): appears unused — confirm before removing
        cp = Plotter.from_smiles(smiles_valid, sim_type="structural",
                                 target=target, target_type=target_type)
        cp.umap(random_state=random_state)
        cp._Plotter__df_2_components["clusters"] = labels
        fig = cp.visualize_plot(clusters=True)
        return cp, fig

    # Grid exploration (prefer Butina / thresholded DBSCAN for large sets).
    def explore(self, smiles: List[str],
                butina_cuts=(0.5, 0.6, 0.7, 0.75, 0.8),
                dbscan_params=((0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5))
                ) -> List[Dict[str, Any]]:
        """Run several configurations and return their stats.

        Returns a list of {'method', 'params', 'stats'} dicts so the caller
        can pick e.g. the configuration with the smallest largest-cluster
        ratio.
        """
        results = []
        mols, keep_idx = smiles_to_mols(smiles)
        fps = mols_to_ecfp(mols, self.fp_cfg)

        for cut in butina_cuts:
            labels = cluster_butina(fps, sim_cutoff=cut)
            stats = cluster_stats(labels)
            results.append({"method": "butina", "params": {"sim_cutoff": cut}, "stats": stats})

        for sim_cut, eps, min_samples in dbscan_params:
            labels = cluster_dbscan_threshold(fps, sim_cutoff=sim_cut, eps=eps, min_samples=min_samples)
            stats = cluster_stats(labels)
            results.append({"method": "dbscan_threshold",
                            "params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": min_samples},
                            "stats": stats})
        return results
|
||||
|
||||
|
||||
# ---------- Automatic search wrapper (external API kept consistent) ----------
@dataclass
class SearchConfig:
    """Search space consumed by ``search_best_config``."""
    # fingerprint configurations to try
    radii: List[int] = field(default_factory=lambda: [2, 3])
    n_bits_list: List[int] = field(default_factory=lambda: [1024, 2048])
    # clustering-method search space
    butina_cuts: List[float] = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.75, 0.8])
    dbscan_params: List[Tuple[float, float, int]] = field(default_factory=lambda: [(0.6, 0.5, 5), (0.6, 0.4, 5), (0.7, 0.5, 5)])
    # optional methods, intended for small sample sets
    try_agglomerative: bool = True
    agglo_k_list: List[int] = field(default_factory=lambda: [5, 8, 10])
    try_scipy: bool = False
    scipy_t_list: List[float] = field(default_factory=lambda: [0.6, 0.7])
    linkage: str = "average"
|
||||
|
||||
def _score(stats: ClusterStats) -> float:
    """Score a clustering result (higher is better).

    Rewards balance (small largest-cluster ratio) plus a small bonus for
    producing more than one cluster.
    """
    split_bonus = 0.1 if stats.n_clusters > 1 else 0.0
    return (1.0 - stats.largest_cluster_ratio) + split_bonus
|
||||
|
||||
def search_best_config(smiles: List[str], cfg: Optional[SearchConfig] = None) -> Tuple[TanimotoClusteringAPI, ClusterStats, pd.DataFrame]:
    """Grid-search fingerprint and clustering settings; return the winner.

    Tries every FP radius/bit-width in *cfg* combined with Butina,
    thresholded DBSCAN, and optionally agglomerative / SciPy linkage,
    scoring each trial with ``_score`` (mainly ``1 - largest-cluster ratio``).

    Returns:
        (api_best, stats, hist): a ``TanimotoClusteringAPI`` configured with
        the winning fingerprint settings, the stats of a fresh fit with the
        winning method, and a DataFrame of all trials sorted by score
        (best first).
    """
    cfg = cfg or SearchConfig()
    trials: List[Dict[str, Any]] = []

    for r in cfg.radii:
        for nb in cfg.n_bits_list:
            api = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=r, n_bits=nb))
            # 1) Butina
            for cut in cfg.butina_cuts:
                res = api.fit_from_smiles(smiles, method="butina", method_kwargs={"sim_cutoff": cut})
                sc = _score(res["stats"])
                trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "butina",
                               "params": {"sim_cutoff": cut}, "stats": res["stats"], "score": sc})
            # 2) thresholded DBSCAN
            for (sim_cut, eps, ms) in cfg.dbscan_params:
                res = api.fit_from_smiles(smiles, method="dbscan_threshold",
                                          method_kwargs={"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms})
                sc = _score(res["stats"])
                trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "dbscan_threshold",
                               "params": {"sim_cutoff": sim_cut, "eps": eps, "min_samples": ms},
                               "stats": res["stats"], "score": sc})
            # 3) (optional) agglomerative — full distance matrix, small samples only
            if cfg.try_agglomerative:
                for k in cfg.agglo_k_list:
                    res = api.fit_from_smiles(smiles, method="agglomerative",
                                              method_kwargs={"n_clusters": k, "linkage_method": cfg.linkage})
                    sc = _score(res["stats"])
                    trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "agglomerative",
                                   "params": {"n_clusters": k, "linkage_method": cfg.linkage},
                                   "stats": res["stats"], "score": sc})
            # 4) (optional) SciPy linkage — requires scipy and small samples
            if cfg.try_scipy and linkage is not None:
                for t in cfg.scipy_t_list:
                    res = api.fit_from_smiles(smiles, method="scipy_linkage",
                                              method_kwargs={"method": cfg.linkage, "t": t, "criterion": "distance"})
                    sc = _score(res["stats"])
                    trials.append({"fp_radius": r, "fp_n_bits": nb, "method": "scipy_linkage",
                                   "params": {"t": t, "method": cfg.linkage, "criterion": "distance"},
                                   "stats": res["stats"], "score": sc})

    # Flatten trials into a comparison table, best score first.
    rows = []
    for t in trials:
        s: ClusterStats = t["stats"]
        rows.append({
            "fp_radius": t["fp_radius"],
            "fp_n_bits": t["fp_n_bits"],
            "method": t["method"],
            "params": t["params"],
            "n_samples": s.n_samples,
            "n_clusters": s.n_clusters,
            "largest_cluster_ratio": s.largest_cluster_ratio,
            "sizes": s.sizes,
            "score": t["score"],
        })
    hist = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)

    best = hist.iloc[0]
    # Assemble a pre-configured API with the winning fingerprint settings...
    api_best = TanimotoClusteringAPI(fp_cfg=FPConfig(radius=int(best["fp_radius"]), n_bits=int(best["fp_n_bits"])))
    # ...and refit once so the caller can immediately read labels/keep_idx.
    res = api_best.fit_from_smiles(smiles, method=best["method"], method_kwargs=best["params"])
    return api_best, res["stats"], hist
|
||||
Reference in New Issue
Block a user