# --- Scraped web-page chrome (not part of the script); preserved as comments ---
# Files
# embedding_atlas/script/split_drugbank.py
# 2025-09-22 20:06:39 +08:00
#
# 305 lines
# 10 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Structure-aware + distribution-aligned DrugBank dataset split script (revised).

- Reads a CSV containing at least one of SELFIES or SMILES (a qed column is
  recommended)
- Builds RDKit Mols, decoding SMILES from SELFIES when only SELFIES is present
- Computes Bemis-Murcko scaffolds, molecular weight and other basic properties
- Performs a stratified train/val/test split with scaffolds as the grouping unit
- Aligns the QED/MW distributions (binning + greedy balancing)
- Writes three CSVs: split_train.csv / split_val.csv / split_test.csv
- Key fix: the assignment strategy is now "capacity first + distribution
  alignment", avoiding the TRAIN=0 bias of the previous version
"""
import argparse
from pathlib import Path
import math
import hashlib
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import Descriptors
try:
import selfies as sf
except Exception:
sf = None
def to_smiles_from_selfies(s):
    """Decode a SELFIES string into canonical SMILES.

    Returns None for missing/blank input, when the optional `selfies`
    package is not installed (module-level `sf` is None), or when decoding
    or RDKit parsing fails.
    """
    if not isinstance(s, str) or pd.isna(s) or not s.strip():
        return None
    if sf is None:
        # `selfies` could not be imported; decoding is impossible.
        return None
    try:
        mol = Chem.MolFromSmiles(sf.decoder(s))
        return None if mol is None else Chem.MolToSmiles(mol)
    except Exception:
        # Malformed SELFIES or decoder failure — treat as unusable input.
        return None
def canonicalize_smiles(s):
    """Return the RDKit-canonical form of a SMILES string.

    None is returned when the input is missing, blank, or fails to parse.
    """
    if not isinstance(s, str) or pd.isna(s) or not s.strip():
        return None
    mol = Chem.MolFromSmiles(s)
    return None if mol is None else Chem.MolToSmiles(mol)
def bemis_murcko_scaffold(smiles):
    """Return the canonical SMILES of the Bemis-Murcko scaffold of *smiles*.

    Returns None for missing/unparsable input or if scaffold extraction fails.
    """
    if smiles is None:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return None if scaffold is None else Chem.MolToSmiles(scaffold)
    except Exception:
        # Scaffold extraction can fail on exotic structures; treat as "no scaffold".
        return None
def calc_basic_props(smiles):
    """Compute basic molecular properties from a SMILES string.

    Returns a 4-tuple (MolWt, TPSA, LogP, heavy-atom count); all NaN when the
    input is falsy or cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return (np.nan, np.nan, np.nan, np.nan)
    return (
        Descriptors.MolWt(mol),
        Descriptors.TPSA(mol),
        Descriptors.MolLogP(mol),
        mol.GetNumHeavyAtoms(),
    )
def hash_str(x):
    """Deterministically map a string to a large integer via its MD5 digest.

    Non-cryptographic use only (stable hashing across runs/processes,
    unlike the salted builtin `hash`).
    """
    digest_hex = hashlib.md5(x.encode("utf-8")).hexdigest()
    return int(digest_hex, 16)
def bin_by_edges(x, edges):
    """Map value *x* to a bin index given sorted bin edges.

    Bin i covers (edges[i], edges[i+1]] with the lowest bin absorbing any
    value <= edges[1]; values above the last edge land in the last bin.
    NaN maps to the sentinel -1.
    """
    if isinstance(x, float) and math.isnan(x):
        return -1
    # Walk the upper edges; the first one that x does not exceed wins.
    for bin_idx, upper in enumerate(edges[1:]):
        if x <= upper:
            return bin_idx
    # x exceeds every edge: clamp into the last bin.
    return len(edges) - 2
def hist_distance_L2(h1, h2):
    """Euclidean (L2) distance between two histograms given as mappings.

    Keys missing from either histogram are treated as count 0.
    """
    all_keys = set(h1) | set(h2)
    squared_gap = sum((h1.get(k, 0) - h2.get(k, 0)) ** 2 for k in all_keys)
    return math.sqrt(squared_gap)
def main():
    """Entry point: read the input CSV, compute scaffolds/properties, then
    assign scaffold groups to train/val/test with a "capacity first +
    distribution alignment" greedy strategy, and write the split CSVs plus
    the bin-edge table.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-csv", required=True)
    ap.add_argument("--out-dir", default=".")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--train-ratio", type=float, default=0.8)
    ap.add_argument("--val-ratio", type=float, default=0.1)
    ap.add_argument("--test-ratio", type=float, default=0.1)
    ap.add_argument("--n_qed_bins", type=int, default=5)
    ap.add_argument("--n_mw_bins", type=int, default=5)
    # Assign scaffold groups largest-first (recommended).
    # BUGFIX: the original `action="store_true", default=True` made the flag
    # impossible to disable, so the random-shuffle branch below was
    # unreachable. The paired --no-largest-first switch now turns it off
    # while keeping --largest-first working exactly as before.
    ap.add_argument("--largest-first", dest="largest_first",
                    action="store_true", default=True)
    ap.add_argument("--no-largest-first", dest="largest_first",
                    action="store_false")
    args = ap.parse_args()
    np.random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_csv(args.in_csv)

    def col_like(name_candidates):
        # Case-insensitive column lookup; returns the actual column name or None.
        for nc in name_candidates:
            idx = [i for i, c in enumerate(df.columns) if c.lower() == nc.lower()]
            if idx:
                return df.columns[idx[0]]
        return None

    col_smiles = col_like(["smiles", "canonical_smiles"])
    col_selfies = col_like(["selfies"])
    col_id = col_like(["id", "molid", "drug_id"])  # looked up but currently unused
    col_qed = col_like(["qed"])
    if col_smiles is None:
        if col_selfies is None:
            raise SystemExit("需要 SMILES 或 SELFIES 至少一个列。")
        df["SMILES"] = df[col_selfies].map(to_smiles_from_selfies)
        col_smiles = "SMILES"
    else:
        df["SMILES"] = df[col_smiles]
    # Canonicalize, then drop unparsable molecules and exact duplicates.
    df["SMILES"] = df["SMILES"].map(canonicalize_smiles)
    df = df[~df["SMILES"].isna()].copy()
    # BUGFIX: reset the index so labels equal positions. The grouping loop
    # below collects *positional* indices via enumerate but reads rows with
    # df.loc (label-based); after filtering/dedup the original index had
    # gaps, so df.loc could raise KeyError or silently pick wrong rows.
    df = df.drop_duplicates(subset=["SMILES"]).reset_index(drop=True)
    props = df["SMILES"].map(calc_basic_props)
    df["MW"] = [p[0] for p in props]
    df["TPSA"] = [p[1] for p in props]
    df["LogP"] = [p[2] for p in props]
    df["HeavyAtoms"] = [p[3] for p in props]
    df["Scaffold"] = df["SMILES"].map(bemis_murcko_scaffold)
    # Molecules without a scaffold form singleton groups keyed by their SMILES.
    df["GroupKey"] = np.where(df["Scaffold"].isna(), df["SMILES"], df["Scaffold"])
    if col_qed is None:
        df["qed"] = np.nan
    else:
        df["qed"] = pd.to_numeric(df[col_qed], errors="coerce")

    def compute_edges(series, n_bins):
        # Quantile-based bin edges; falls back to a linspace when the sample
        # is too small for n_bins quantiles.
        ser = series.dropna()
        if len(ser) < n_bins:
            return np.linspace(ser.min() if len(ser) > 0 else 0,
                               ser.max() if len(ser) > 0 else 1,
                               n_bins + 1)
        qs = np.linspace(0, 1, n_bins + 1)
        return np.quantile(ser, qs)

    qed_edges = compute_edges(df["qed"], args.n_qed_bins)
    mw_edges = compute_edges(df["MW"], args.n_mw_bins)
    df["QED_bin"] = df["qed"].map(lambda x: bin_by_edges(x, qed_edges))
    df["MW_bin"] = df["MW"].map(lambda x: bin_by_edges(x, mw_edges))
    # A stratum is the (QED bin, MW bin) pair; splits are balanced over these.
    df["Strata"] = list(zip(df["QED_bin"], df["MW_bin"]))
    # Group row positions by scaffold key.
    group2idx = defaultdict(list)
    for i, g in enumerate(df["GroupKey"]):
        group2idx[g].append(i)
    group_hist = {g: Counter(df.loc[idxs, "Strata"].tolist())
                  for g, idxs in group2idx.items()}
    global_hist = Counter(df["Strata"].tolist())
    total = len(df)
    # Integer size targets per split; the rounding remainder goes to train.
    target_counts = {
        "train": int(round(total * args.train_ratio)),
        "val": int(round(total * args.val_ratio)),
        "test": int(round(total * args.test_ratio)),
    }
    diff = total - sum(target_counts.values())
    if diff != 0:
        target_counts["train"] += diff
    global_ratio = {
        "train": args.train_ratio,
        "val": args.val_ratio,
        "test": args.test_ratio,
    }
    # Per-split target histogram = global strata histogram scaled by the ratio.
    target_hist = {split: {k: v * r for k, v in global_hist.items()}
                   for split, r in global_ratio.items()}
    splits = {"train": [], "val": [], "test": []}
    split_counts = {"train": 0, "val": 0, "test": 0}
    split_hist = {"train": Counter(), "val": Counter(), "test": Counter()}
    scaffolds = list(group2idx.keys())
    rng = np.random.RandomState(args.seed)
    # Assign the largest groups first (recommended) or in random order.
    if args.largest_first:
        scaffolds.sort(key=lambda g: len(group2idx[g]), reverse=True)
    else:
        rng.shuffle(scaffolds)

    def remaining_capacity(split):
        # How many more rows the split may take before hitting its target.
        return target_counts[split] - split_counts[split]

    for g in scaffolds:
        idxs = group2idx[g]
        g_size = len(idxs)
        g_hist = group_hist[g]
        # Remaining capacity of each split.
        caps = {s: remaining_capacity(s) for s in ["train", "val", "test"]}
        # Prefer splits that still have positive capacity.
        positive_splits = [s for s, c in caps.items() if c > 0]
        # If every capacity is <= 0 the quotas are full; degrade to
        # "smallest overflow first".
        candidate_splits = positive_splits if positive_splits else ["train", "val", "test"]
        # Among candidates, combine "capacity first" with "distribution
        # alignment" (smaller projected L2 gap is better).
        best = None
        for s in candidate_splits:
            cap = caps[s]
            proj_hist = split_hist[s] + g_hist
            l2 = hist_distance_L2(proj_hist, target_hist[s])
            # Sort key:
            # 1) -max(cap, 0): larger positive capacity wins; 0 when all <= 0
            # 2) abs(min(cap, 0)): smaller overflow wins
            # 3) l2: smaller distribution gap wins
            key = (-max(cap, 0), abs(min(cap, 0)), l2)
            if (best is None) or (key < best[0]):
                best = (key, s)
        chosen = best[1]
        splits[chosen].append(g)
        split_counts[chosen] += g_size
        split_hist[chosen] += g_hist
    # Collect row positions per split.
    idx_train, idx_val, idx_test = [], [], []
    for g in splits["train"]:
        idx_train.extend(group2idx[g])
    for g in splits["val"]:
        idx_val.extend(group2idx[g])
    for g in splits["test"]:
        idx_test.extend(group2idx[g])
    df_train = df.iloc[sorted(idx_train)].copy()
    df_val = df.iloc[sorted(idx_val)].copy()
    df_test = df.iloc[sorted(idx_test)].copy()

    def report(name, sub):
        # Print size, the 8 most common strata, and QED/MW moments for a split.
        print(f"\n== {name} ==")
        print(f"size = {len(sub)}")
        h = Counter(sub["Strata"].tolist())
        for k, v in h.most_common(8):
            print(f" {k}: {v}")
        print(" QED mean/std:", np.nanmean(sub["qed"]), np.nanstd(sub["qed"]))
        print(" MW mean/std:", np.nanmean(sub["MW"]), np.nanstd(sub["MW"]))

    print(f"Total: {len(df)} (targets: {target_counts})")
    report("TRAIN", df_train)
    report("VAL", df_val)
    report("TEST", df_test)
    out_train = out_dir / "split_train.csv"
    out_val = out_dir / "split_val.csv"
    out_test = out_dir / "split_test.csv"
    df_train.to_csv(out_train, index=False)
    df_val.to_csv(out_val, index=False)
    df_test.to_csv(out_test, index=False)
    # Persist the bin edges side by side, padding the shorter column with NaN.
    edges = pd.DataFrame({
        "qed_edges": list(qed_edges) + [np.nan]*(max(len(mw_edges), len(qed_edges)) - len(qed_edges)),
        "mw_edges": list(mw_edges) + [np.nan]*(max(len(mw_edges), len(qed_edges)) - len(mw_edges)),
    })
    edges.to_csv(out_dir / "bin_edges_qed_mw.csv", index=False)
    # Report the capacity deviation per split.
    print("\nFinal counts vs targets:")
    for s in ["train", "val", "test"]:
        print(f" {s}: {split_counts[s]} / {target_counts[s]} (delta={split_counts[s]-target_counts[s]})")
    print(f"\nSaved:\n {out_train}\n {out_val}\n {out_test}\n {out_dir/'bin_edges_qed_mw.csv'}")
    print("\n验证/测试阶段,仅使用 split_val / split_test 作为 feed-chemical。")


if __name__ == "__main__":
    main()