#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Structure-aware + distribution-aligned DrugBank dataset split script (fixed).

- Reads a CSV (must contain at least one of SELFIES or SMILES; a ``qed``
  column is recommended)
- Derives SMILES (when only SELFIES is present) and builds RDKit Mols
- Computes Bemis-Murcko scaffolds and basic molecular properties
- Performs a stratified train/val/test split with scaffolds as the
  grouping unit (no scaffold is shared across splits)
- Aligns QED/MW distributions (binning + greedy balancing)
- Writes three CSVs: split_train.csv / split_val.csv / split_test.csv

Key fix: the assignment strategy is "capacity first + distribution
alignment", which avoids the TRAIN=0 bias of a purely greedy policy.
"""

import argparse
import hashlib
import math
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold

try:
    import selfies as sf
except Exception:
    # selfies is optional: without it, SELFIES-only inputs are rejected.
    sf = None


def to_smiles_from_selfies(s):
    """Decode a SELFIES string to canonical SMILES; return None on any failure."""
    if pd.isna(s) or not isinstance(s, str) or not s.strip():
        return None
    if sf is None:
        return None
    try:
        smi = sf.decoder(s)
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol)
    except Exception:
        return None


def canonicalize_smiles(s):
    """Return the RDKit-canonical form of a SMILES string, or None if invalid."""
    if pd.isna(s) or not isinstance(s, str) or not s.strip():
        return None
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)


def bemis_murcko_scaffold(smiles):
    """Return the Bemis-Murcko scaffold SMILES for a molecule, or None."""
    if smiles is None:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        scaf = MurckoScaffold.GetScaffoldForMol(mol)
        if scaf is None:
            return None
        return Chem.MolToSmiles(scaf)
    except Exception:
        return None


def calc_basic_props(smiles):
    """Return (MW, TPSA, LogP, heavy-atom count); NaNs when parsing fails."""
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return (np.nan, np.nan, np.nan, np.nan)
    mw = Descriptors.MolWt(mol)
    tpsa = Descriptors.TPSA(mol)
    logp = Descriptors.MolLogP(mol)
    heavy = mol.GetNumHeavyAtoms()
    return (mw, tpsa, logp, heavy)


def hash_str(x):
    """Deterministic integer hash of a string (seed-independent, unlike hash())."""
    return int(hashlib.md5(x.encode('utf-8')).hexdigest(), 16)


def bin_by_edges(x, edges):
    """Map a value to a bin index given sorted bin edges.

    Returns -1 for NaN; values above the last edge fall into the last bin,
    values below the first edge into bin 0.
    """
    if isinstance(x, float) and math.isnan(x):
        return -1
    for i in range(1, len(edges)):
        if x <= edges[i]:
            return i - 1
    return len(edges) - 2


def hist_distance_L2(h1, h2):
    """Euclidean (L2) distance between two histograms given as mappings."""
    keys = set(h1.keys()) | set(h2.keys())
    return math.sqrt(sum((h1.get(k, 0) - h2.get(k, 0)) ** 2 for k in keys))


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-csv", required=True)
    ap.add_argument("--out-dir", default=".")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--train-ratio", type=float, default=0.8)
    ap.add_argument("--val-ratio", type=float, default=0.1)
    ap.add_argument("--test-ratio", type=float, default=0.1)
    ap.add_argument("--n_qed_bins", type=int, default=5)
    ap.add_argument("--n_mw_bins", type=int, default=5)
    # Assign largest scaffold groups first (recommended).  Fixed: the original
    # used action="store_true" with default=True, which made the flag
    # impossible to disable; BooleanOptionalAction adds --no-largest-first.
    ap.add_argument("--largest-first", action=argparse.BooleanOptionalAction,
                    default=True)
    args = ap.parse_args()

    np.random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(args.in_csv)

    def col_like(name_candidates):
        # Case-insensitive lookup of the first matching column name.
        for nc in name_candidates:
            idx = [i for i, c in enumerate(df.columns) if c.lower() == nc.lower()]
            if idx:
                return df.columns[idx[0]]
        return None

    col_smiles = col_like(["smiles", "canonical_smiles"])
    col_selfies = col_like(["selfies"])
    col_id = col_like(["id", "molid", "drug_id"])
    col_qed = col_like(["qed"])

    if col_smiles is None:
        if col_selfies is None:
            raise SystemExit("需要 SMILES 或 SELFIES 至少一个列。")
        df["SMILES"] = df[col_selfies].map(to_smiles_from_selfies)
        col_smiles = "SMILES"
    else:
        df["SMILES"] = df[col_smiles]

    df["SMILES"] = df["SMILES"].map(canonicalize_smiles)
    df = df[~df["SMILES"].isna()].copy()
    df = df.drop_duplicates(subset=["SMILES"])
    # FIX: group indices below are collected positionally (enumerate) but were
    # consumed with label-based .loc; after filtering/dedup the labels no
    # longer match positions.  Resetting the index makes labels == positions
    # so both .loc (group histograms) and .iloc (final selection) agree.
    df = df.reset_index(drop=True)

    props = df["SMILES"].map(calc_basic_props)
    df["MW"] = [p[0] for p in props]
    df["TPSA"] = [p[1] for p in props]
    df["LogP"] = [p[2] for p in props]
    df["HeavyAtoms"] = [p[3] for p in props]

    df["Scaffold"] = df["SMILES"].map(bemis_murcko_scaffold)
    # Acyclic molecules have no Murcko scaffold; fall back to the SMILES
    # itself so each such molecule forms its own group.
    df["GroupKey"] = np.where(df["Scaffold"].isna(), df["SMILES"], df["Scaffold"])

    if col_qed is None:
        df["qed"] = np.nan
    else:
        df["qed"] = pd.to_numeric(df[col_qed], errors="coerce")

    def compute_edges(series, n_bins):
        # Quantile edges; degrade to a linear grid when too few values exist.
        ser = series.dropna()
        if len(ser) < n_bins:
            return np.linspace(ser.min() if len(ser) > 0 else 0,
                               ser.max() if len(ser) > 0 else 1,
                               n_bins + 1)
        qs = np.linspace(0, 1, n_bins + 1)
        return np.quantile(ser, qs)

    qed_edges = compute_edges(df["qed"], args.n_qed_bins)
    mw_edges = compute_edges(df["MW"], args.n_mw_bins)

    df["QED_bin"] = df["qed"].map(lambda x: bin_by_edges(x, qed_edges))
    df["MW_bin"] = df["MW"].map(lambda x: bin_by_edges(x, mw_edges))
    df["Strata"] = list(zip(df["QED_bin"], df["MW_bin"]))

    # Group rows by scaffold; keep positional indices for later selection.
    group2idx = defaultdict(list)
    for i, g in enumerate(df["GroupKey"]):
        group2idx[g].append(i)

    group_hist = {g: Counter(df.loc[idxs, "Strata"].tolist())
                  for g, idxs in group2idx.items()}
    global_hist = Counter(df["Strata"].tolist())
    total = len(df)

    target_counts = {
        "train": int(round(total * args.train_ratio)),
        "val": int(round(total * args.val_ratio)),
        "test": int(round(total * args.test_ratio)),
    }
    # Rounding may not sum to total; give the remainder to train.
    diff = total - sum(target_counts.values())
    if diff != 0:
        target_counts["train"] += diff

    global_ratio = {
        "train": args.train_ratio,
        "val": args.val_ratio,
        "test": args.test_ratio,
    }
    # Per-split target histogram = global histogram scaled by the split ratio.
    target_hist = {split: {k: v * r for k, v in global_hist.items()}
                   for split, r in global_ratio.items()}

    splits = {"train": [], "val": [], "test": []}
    split_counts = {"train": 0, "val": 0, "test": 0}
    split_hist = {"train": Counter(), "val": Counter(), "test": Counter()}

    scaffolds = list(group2idx.keys())
    rng = np.random.RandomState(args.seed)
    # Key step: place the biggest groups first (or shuffle for a random order).
    if args.largest_first:
        scaffolds.sort(key=lambda g: len(group2idx[g]), reverse=True)
    else:
        rng.shuffle(scaffolds)

    def remaining_capacity(split):
        return target_counts[split] - split_counts[split]

    for g in scaffolds:
        idxs = group2idx[g]
        g_size = len(idxs)
        g_hist = group_hist[g]
        # Remaining capacity of each split.
        caps = {s: remaining_capacity(s) for s in ["train", "val", "test"]}
        # Prefer splits that still have positive capacity.
        positive_splits = [s for s, c in caps.items() if c > 0]
        # If all quotas are full, fall back to "least overflow first".
        candidate_splits = positive_splits if positive_splits else ["train", "val", "test"]

        # Among candidates, combine "capacity first" with "distribution
        # alignment" (smaller projected L2 distance is better).
        best = None
        for s in candidate_splits:
            cap = caps[s]
            proj_hist = split_hist[s] + g_hist
            l2 = hist_distance_L2(proj_hist, target_hist[s])
            # Sort key:
            #   1) -max(cap, 0): larger positive capacity wins; 0 when all negative
            #   2) abs(min(cap, 0)): smaller overflow wins
            #   3) l2: smaller distribution gap wins
            key = (-max(cap, 0), abs(min(cap, 0)), l2)
            if (best is None) or (key < best[0]):
                best = (key, s)

        chosen = best[1]
        splits[chosen].append(g)
        split_counts[chosen] += g_size
        split_hist[chosen] += g_hist

    # Collect row indices per split.
    idx_train, idx_val, idx_test = [], [], []
    for g in splits["train"]:
        idx_train.extend(group2idx[g])
    for g in splits["val"]:
        idx_val.extend(group2idx[g])
    for g in splits["test"]:
        idx_test.extend(group2idx[g])

    df_train = df.iloc[sorted(idx_train)].copy()
    df_val = df.iloc[sorted(idx_val)].copy()
    df_test = df.iloc[sorted(idx_test)].copy()

    def report(name, sub):
        print(f"\n== {name} ==")
        print(f"size = {len(sub)}")
        h = Counter(sub["Strata"].tolist())
        for k, v in h.most_common(8):
            print(f" {k}: {v}")
        print(" QED mean/std:", np.nanmean(sub["qed"]), np.nanstd(sub["qed"]))
        print(" MW mean/std:", np.nanmean(sub["MW"]), np.nanstd(sub["MW"]))

    print(f"Total: {len(df)} (targets: {target_counts})")
    report("TRAIN", df_train)
    report("VAL", df_val)
    report("TEST", df_test)

    out_train = out_dir / "split_train.csv"
    out_val = out_dir / "split_val.csv"
    out_test = out_dir / "split_test.csv"
    df_train.to_csv(out_train, index=False)
    df_val.to_csv(out_val, index=False)
    df_test.to_csv(out_test, index=False)

    # Persist the bin edges (columns padded with NaN to equal length).
    edges = pd.DataFrame({
        "qed_edges": list(qed_edges) + [np.nan] * (max(len(mw_edges), len(qed_edges)) - len(qed_edges)),
        "mw_edges": list(mw_edges) + [np.nan] * (max(len(mw_edges), len(qed_edges)) - len(mw_edges)),
    })
    edges.to_csv(out_dir / "bin_edges_qed_mw.csv", index=False)

    # Report how far each split ended up from its target size.
    print("\nFinal counts vs targets:")
    for s in ["train", "val", "test"]:
        print(f" {s}: {split_counts[s]} / {target_counts[s]} (delta={split_counts[s]-target_counts[s]})")

    print(f"\nSaved:\n {out_train}\n {out_val}\n {out_test}\n {out_dir/'bin_edges_qed_mw.csv'}")
    print("\n验证/测试阶段,仅使用 split_val / split_test 作为 feed-chemical。")


if __name__ == "__main__":
    main()