重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
304
script/data_processing/split_drugbank.py
Normal file
304
script/data_processing/split_drugbank.py
Normal file
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Structure-aware + distribution-aligned DrugBank dataset splitting script
(revised version).

- Reads a CSV (must contain at least one of SELFIES or SMILES; ideally
  also a ``qed`` column).
- Computes SMILES (when only SELFIES is available) and builds RDKit Mols.
- Computes Bemis-Murcko scaffolds, molecular weight and other basic
  properties.
- Performs a stratified train/val/test split with scaffolds as the
  grouping unit.
- Aligns the QED/MW distributions (binning + greedy balancing).
- Key fix: the assignment strategy is "capacity first + distribution
  alignment", which avoids the TRAIN=0 bias.
- Writes three CSVs: split_train.csv / split_val.csv / split_test.csv
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import math
|
||||
import hashlib
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.Scaffolds import MurckoScaffold
|
||||
from rdkit.Chem import Descriptors
|
||||
|
||||
try:
|
||||
import selfies as sf
|
||||
except Exception:
|
||||
sf = None
|
||||
|
||||
def to_smiles_from_selfies(s):
    """Decode a SELFIES string into a canonical SMILES string.

    Returns None when the input is missing/blank or not a string, when
    the optional ``selfies`` package is unavailable, or when decoding or
    RDKit parsing fails for any reason.
    """
    # Guard clauses: missing values and non-strings first, then blank
    # input and an absent decoder.
    if pd.isna(s) or not isinstance(s, str):
        return None
    if not s.strip() or sf is None:
        return None
    try:
        decoded = sf.decoder(s)
        mol = Chem.MolFromSmiles(decoded)
        # Decoding succeeded but RDKit rejected the SMILES.
        return None if mol is None else Chem.MolToSmiles(mol)
    except Exception:
        # Best-effort conversion: any failure maps to None.
        return None
|
||||
|
||||
def canonicalize_smiles(s):
    """Return the RDKit canonical form of SMILES *s*, or None.

    None is returned for missing values, non-strings, blank strings and
    strings RDKit cannot parse.
    """
    is_missing = pd.isna(s) or not isinstance(s, str) or not s.strip()
    if is_missing:
        return None
    mol = Chem.MolFromSmiles(s)
    return None if mol is None else Chem.MolToSmiles(mol)
|
||||
|
||||
def bemis_murcko_scaffold(smiles):
    """Return the Bemis-Murcko scaffold of *smiles* as canonical SMILES.

    Returns None when the input is None, cannot be parsed by RDKit, or
    scaffold extraction fails for any reason.
    """
    if smiles is None:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        scaffold_mol = MurckoScaffold.GetScaffoldForMol(mol)
        if scaffold_mol is not None:
            return Chem.MolToSmiles(scaffold_mol)
        return None
    except Exception:
        # Some exotic molecules make scaffold extraction throw; treat
        # them as having no scaffold.
        return None
|
||||
|
||||
def calc_basic_props(smiles):
    """Compute (MolWt, TPSA, MolLogP, heavy-atom count) for a SMILES.

    A 4-tuple of NaN is returned when *smiles* is falsy or unparsable,
    keeping the tuple shape stable for downstream column assignment.
    """
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return (np.nan, np.nan, np.nan, np.nan)
    return (
        Descriptors.MolWt(mol),
        Descriptors.TPSA(mol),
        Descriptors.MolLogP(mol),
        mol.GetNumHeavyAtoms(),
    )
|
||||
|
||||
def hash_str(x):
    """Map a string to a deterministic integer via its MD5 digest."""
    digest = hashlib.md5(x.encode("utf-8")).hexdigest()
    return int(digest, 16)
|
||||
|
||||
def bin_by_edges(x, edges):
    """Return the 0-based bin index of *x* within the bin *edges*.

    The first upper edge that is >= x decides the bin; values above the
    last edge are clamped into the final bin. NaN inputs map to the
    sentinel bin -1.
    """
    if isinstance(x, float) and math.isnan(x):
        return -1
    # Walk the upper edges (edges[1:]); index in this slice == bin index.
    for bin_idx, upper in enumerate(edges[1:]):
        if x <= upper:
            return bin_idx
    # x exceeds every edge: clamp to the last bin.
    return len(edges) - 2
|
||||
|
||||
def hist_distance_L2(h1, h2):
    """Euclidean (L2) distance between two sparse histograms (dicts).

    Keys missing from either histogram are treated as count 0.
    """
    all_keys = set(h1) | set(h2)
    sq_sum = 0.0
    for k in all_keys:
        delta = h1.get(k, 0) - h2.get(k, 0)
        sq_sum += delta * delta
    return math.sqrt(sq_sum)
|
||||
|
||||
def main():
    """Command-line entry point: split a DrugBank CSV into train/val/test.

    Reads the input CSV, canonicalizes SMILES (decoding SELFIES when
    needed), computes basic molecular properties and Bemis-Murcko
    scaffolds, then assigns whole scaffold groups to splits with a
    "capacity first + distribution alignment" greedy strategy, and writes
    split_train.csv / split_val.csv / split_test.csv plus the bin edges.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-csv", required=True)
    ap.add_argument("--out-dir", default=".")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--train-ratio", type=float, default=0.8)
    ap.add_argument("--val-ratio", type=float, default=0.1)
    ap.add_argument("--test-ratio", type=float, default=0.1)
    ap.add_argument("--n_qed_bins", type=int, default=5)
    ap.add_argument("--n_mw_bins", type=int, default=5)
    # Whether to assign scaffold groups largest-first (recommended).
    # FIX: the original flag was `store_true` with default=True, so it
    # could never be turned off and the shuffle branch below was
    # unreachable; `--no-largest-first` makes it reachable while keeping
    # the old flag backward compatible.
    ap.add_argument("--largest-first", action="store_true", default=True)
    ap.add_argument("--no-largest-first", action="store_false", dest="largest_first")
    args = ap.parse_args()

    np.random.seed(args.seed)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(args.in_csv)

    def col_like(name_candidates):
        # Case-insensitive column lookup; returns the actual column name
        # for the first candidate found, or None.
        for nc in name_candidates:
            idx = [i for i,c in enumerate(df.columns) if c.lower()==nc.lower()]
            if idx:
                return df.columns[idx[0]]
        return None

    col_smiles = col_like(["smiles", "canonical_smiles"])
    col_selfies = col_like(["selfies"])
    col_id = col_like(["id","molid","drug_id"])  # currently unused
    col_qed = col_like(["qed"])

    # Derive a SMILES column: prefer an existing SMILES column, otherwise
    # decode from SELFIES; at least one of the two must be present.
    if col_smiles is None:
        if col_selfies is None:
            raise SystemExit("需要 SMILES 或 SELFIES 至少一个列。")
        df["SMILES"] = df[col_selfies].map(to_smiles_from_selfies)
        col_smiles = "SMILES"
    else:
        df["SMILES"] = df[col_smiles]

    df["SMILES"] = df["SMILES"].map(canonicalize_smiles)
    df = df[~df["SMILES"].isna()].copy()
    df = df.drop_duplicates(subset=["SMILES"])
    # FIX: restore a 0..n-1 RangeIndex. The positional indices collected
    # via enumerate() below are used with label-based `df.loc`; after the
    # filtering/dedup above the old labels no longer match positions, so
    # the lookup raised KeyError (or selected wrong rows) whenever any
    # row was dropped.
    df = df.reset_index(drop=True)

    props = df["SMILES"].map(calc_basic_props)
    df["MW"] = [p[0] for p in props]
    df["TPSA"] = [p[1] for p in props]
    df["LogP"] = [p[2] for p in props]
    df["HeavyAtoms"] = [p[3] for p in props]

    # Group key: the Bemis-Murcko scaffold, falling back to the molecule's
    # own SMILES when no scaffold could be computed.
    df["Scaffold"] = df["SMILES"].map(bemis_murcko_scaffold)
    df["GroupKey"] = np.where(df["Scaffold"].isna(), df["SMILES"], df["Scaffold"])

    if col_qed is None:
        df["qed"] = np.nan
    else:
        df["qed"] = pd.to_numeric(df[col_qed], errors="coerce")

    def compute_edges(series, n_bins):
        # Quantile bin edges; degrade to a linear grid when there are too
        # few non-NaN values to support the requested number of bins.
        ser = series.dropna()
        if len(ser) < n_bins:
            return np.linspace(ser.min() if len(ser)>0 else 0,
                               ser.max() if len(ser)>0 else 1,
                               n_bins+1)
        qs = np.linspace(0,1,n_bins+1)
        return np.quantile(ser, qs)

    qed_edges = compute_edges(df["qed"], args.n_qed_bins)
    mw_edges = compute_edges(df["MW"], args.n_mw_bins)

    # Strata = (QED bin, MW bin); -1 marks NaN in either property.
    df["QED_bin"] = df["qed"].map(lambda x: bin_by_edges(x, qed_edges))
    df["MW_bin"] = df["MW"].map(lambda x: bin_by_edges(x, mw_edges))
    df["Strata"] = list(zip(df["QED_bin"], df["MW_bin"]))

    # Map each group key to the positional row indices of its members.
    group2idx = defaultdict(list)
    for i, g in enumerate(df["GroupKey"]):
        group2idx[g].append(i)

    # Per-group and global strata histograms.
    group_hist = {g: Counter(df.loc[idxs, "Strata"].tolist())
                  for g, idxs in group2idx.items()}
    global_hist = Counter(df["Strata"].tolist())

    total = len(df)
    target_counts = {
        "train": int(round(total * args.train_ratio)),
        "val": int(round(total * args.val_ratio)),
        "test": int(round(total * args.test_ratio)),
    }
    # Rounding may not sum to the total; absorb the remainder into train.
    diff = total - sum(target_counts.values())
    if diff != 0:
        target_counts["train"] += diff

    global_ratio = {
        "train": args.train_ratio,
        "val": args.val_ratio,
        "test": args.test_ratio,
    }
    # Each split's target histogram is the global one scaled by its ratio.
    target_hist = {split: {k: v * r for k, v in global_hist.items()}
                   for split, r in global_ratio.items()}

    splits = {"train": [], "val": [], "test": []}
    split_counts = {"train": 0, "val": 0, "test": 0}
    split_hist = {"train": Counter(), "val": Counter(), "test": Counter()}

    scaffolds = list(group2idx.keys())
    rng = np.random.RandomState(args.seed)

    # Key point: place the largest groups first (descending group size).
    if args.largest_first:
        scaffolds.sort(key=lambda g: len(group2idx[g]), reverse=True)
    else:
        rng.shuffle(scaffolds)

    def remaining_capacity(split):
        return target_counts[split] - split_counts[split]

    for g in scaffolds:
        idxs = group2idx[g]
        g_size = len(idxs)
        g_hist = group_hist[g]

        # Remaining capacity of each split.
        caps = {s: remaining_capacity(s) for s in ["train","val","test"]}

        # Prefer splits that still have positive capacity.
        positive_splits = [s for s,c in caps.items() if c > 0]

        # If every capacity is <= 0 the quotas are full; fall back to
        # "smallest overflow first" among all splits.
        candidate_splits = positive_splits if positive_splits else ["train","val","test"]

        # Among candidates, combine "capacity first" with "distribution
        # alignment" (smaller L2 distance to the target histogram wins).
        best = None
        for s in candidate_splits:
            cap = caps[s]
            proj_hist = split_hist[s] + g_hist
            l2 = hist_distance_L2(proj_hist, target_hist[s])

            # Sort key:
            # 1) -max(cap, 0): larger positive capacity is better; 0 when
            #    all capacities are negative
            # 2) abs(min(cap, 0)): smaller overflow is better
            # 3) l2: smaller distribution mismatch is better
            key = (-max(cap, 0), abs(min(cap, 0)), l2)

            if (best is None) or (key < best[0]):
                best = (key, s)

        chosen = best[1]

        splits[chosen].append(g)
        split_counts[chosen] += g_size
        split_hist[chosen] += g_hist

    # Collect row indices per split.
    idx_train, idx_val, idx_test = [], [], []
    for g in splits["train"]:
        idx_train.extend(group2idx[g])
    for g in splits["val"]:
        idx_val.extend(group2idx[g])
    for g in splits["test"]:
        idx_test.extend(group2idx[g])

    df_train = df.iloc[sorted(idx_train)].copy()
    df_val = df.iloc[sorted(idx_val)].copy()
    df_test = df.iloc[sorted(idx_test)].copy()

    def report(name, sub):
        # Print a brief summary: size, top strata, QED/MW moments.
        print(f"\n== {name} ==")
        print(f"size = {len(sub)}")
        h = Counter(sub["Strata"].tolist())
        for k, v in h.most_common(8):
            print(f" {k}: {v}")
        print(" QED mean/std:", np.nanmean(sub["qed"]), np.nanstd(sub["qed"]))
        print(" MW mean/std:", np.nanmean(sub["MW"]), np.nanstd(sub["MW"]))

    print(f"Total: {len(df)} (targets: {target_counts})")
    report("TRAIN", df_train)
    report("VAL", df_val)
    report("TEST", df_test)

    out_train = out_dir / "split_train.csv"
    out_val = out_dir / "split_val.csv"
    out_test = out_dir / "split_test.csv"
    df_train.to_csv(out_train, index=False)
    df_val.to_csv(out_val, index=False)
    df_test.to_csv(out_test, index=False)

    # Persist the bin edges (columns padded with NaN to equal length).
    edges = pd.DataFrame({
        "qed_edges": list(qed_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(qed_edges)),
        "mw_edges": list(mw_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(mw_edges)),
    })
    edges.to_csv(out_dir / "bin_edges_qed_mw.csv", index=False)

    # Extra hint: report the deviation from the capacity targets.
    print("\nFinal counts vs targets:")
    for s in ["train","val","test"]:
        print(f" {s}: {split_counts[s]} / {target_counts[s]} (delta={split_counts[s]-target_counts[s]})")

    print(f"\nSaved:\n {out_train}\n {out_val}\n {out_test}\n {out_dir/'bin_edges_qed_mw.csv'}")
    print("\n验证/测试阶段,仅使用 split_val / split_test 作为 feed-chemical。")
|
||||
# Run only when executed as a script, so the module can be imported
# without side effects.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user