重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
186
script/data_processing/add_ecfp4_tanimoto.py
Normal file
186
script/data_processing/add_ecfp4_tanimoto.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from rdkit import Chem, DataStructs
|
||||
from rdkit.Chem import rdFingerprintGenerator
|
||||
|
||||
|
||||
@dataclass
class ECFP4Generator:
    """Produce ECFP4 (Morgan, radius 2) fingerprints as RDKit bit vectors.

    The underlying Morgan generator is built once in ``__post_init__`` and
    reused for every molecule.
    """

    n_bits: int = 2048
    radius: int = 2
    include_chirality: bool = True
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        # Construct the (stateless, reusable) RDKit Morgan generator.
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        """Fingerprint a SMILES string; None for empty or unparseable input."""
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        if mol is None:
            return None
        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            # RDKit occasionally raises on exotic molecules; treat as missing.
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        """Render a fingerprint as a '0'/'1' string of length ``n_bits``."""
        bits = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, bits)
        return ''.join('1' if bit else '0' for bit in bits)
|
||||
|
||||
|
||||
def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint."""
    # Rows without a fingerprint keep an empty summary string.
    summaries = [''] * len(fingerprints)

    valid_indices = [i for i, fp in enumerate(fingerprints) if fp is not None]
    if not valid_indices:
        return summaries

    valid_fps = [fingerprints[i] for i in valid_indices]

    for pos, fp in enumerate(valid_fps):
        scores = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        # Exclude self-similarity, then rank by score (stable, so ties keep
        # their original ordering).
        neighbors = sorted(
            ((other, score) for other, score in enumerate(scores) if other != pos),
            key=lambda pair: pair[1],
            reverse=True,
        )
        entries = [
            f"{ids[valid_indices[other]]}:{score:.4f}"
            for other, score in neighbors[:top_k]
        ]
        summaries[valid_indices[pos]] = ';'.join(entries)

    return summaries
|
||||
|
||||
|
||||
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``. Exposed
            as a parameter so the CLI can be exercised in tests without
            touching the process's real argv.

    Returns:
        Namespace with ``input_csv``, ``output``, ``smiles_column``,
        ``id_column``, ``top_k`` and ``format``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    parser.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    parser.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    parser.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: fingerprint every row and write the augmented table."""
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        # Explicit --output wins; otherwise derive <input>_with_ecfp4.<ext>
        # next to the input, with the extension implied by --format.
        if args.output is not None:
            return args.output

        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        # The output file's suffix takes precedence over --format; only an
        # unrecognized suffix falls back to the flag.
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        # Only reachable with --format auto plus an unknown suffix.
        raise ValueError(
            "无法根据输出文件推断格式,请为 --output 指定 .parquet/.csv 后缀或使用 --format",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")

    # NaN SMILES become '' so the generator treats them as "no fingerprint".
    smiles_series = df[args.smiles_column].fillna('')

    # Reuse an existing ID column when present; otherwise synthesize
    # D000001-style IDs and persist them in the output.
    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []

    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        # Unparseable SMILES produce an empty fingerprint string.
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()
|
||||
127
script/data_processing/add_macrocycle_columns.py
Normal file
127
script/data_processing/add_macrocycle_columns.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
|
||||
|
||||
@dataclass
class MacrolactoneRingDetector:
    """Detect macrolactone rings in SMILES strings via SMARTS patterns."""

    # Inclusive ring-size range scanned for macrolactone motifs.
    min_size: int = 12
    max_size: int = 20
    # One pre-compiled SMARTS query per ring size; built in __post_init__.
    patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        # NOTE(review): the SMARTS matches any atom in a ring of the given
        # size that carries an O-C(=O) neighbor chain; it does not force the
        # ester atoms themselves to be ring members — confirm this matches
        # the intended macrolactone definition.
        self.patterns = {
            size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
            for size in range(self.min_size, self.max_size + 1)
        }

    def ring_sizes(self, smiles: str) -> List[int]:
        """Return a sorted list of macrolactone ring sizes present in the SMILES."""
        if not smiles:
            return []

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return []

        ring_atoms = mol.GetRingInfo().AtomRings()
        if not ring_atoms:
            return []

        # Track matched rings by their sorted atom tuples so one ring is not
        # counted once per overlapping SMARTS hit.
        matched_rings: Set[Tuple[int, ...]] = set()
        matched_sizes: Set[int] = set()

        for size, query in self.patterns.items():
            if query is None:
                # Defensive: skip sizes whose SMARTS failed to compile.
                continue

            for match in mol.GetSubstructMatches(query, uniquify=True):
                ring = self._pick_ring(match, ring_atoms, size)
                if ring and ring not in matched_rings:
                    matched_rings.add(ring)
                    matched_sizes.add(size)

        sizes = sorted(matched_sizes)

        # More than one macrolactone size is unusual enough to surface both
        # as a warning and on stdout.
        if len(sizes) > 1:
            warnings.warn(
                "Multiple macrolactone ring sizes detected",
                RuntimeWarning,
                stacklevel=2,
            )
            print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")

        return sizes

    @staticmethod
    def _pick_ring(
        match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
    ) -> Optional[Tuple[int, ...]]:
        # Resolve a SMARTS match to a concrete ring: the first ring of the
        # expected size containing the match's anchor atom.
        # NOTE(review): if the anchor atom sits in several rings of the same
        # size (fused systems), the first ring wins — verify acceptable.
        ring_atom = match[0]
        for ring in rings:
            if len(ring) == expected_size and ring_atom in ring:
                return tuple(sorted(ring))
        return None
|
||||
|
||||
|
||||
def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
    """Return a copy of *df* with generated IDs and macrocycle ring sizes.

    Adds a ``generated_id`` column (D000001, D000002, ...) and a
    ``macrocycle_ring_sizes`` column holding ';'-joined ring sizes from the
    detector (empty string when none are found). The input frame is left
    untouched.
    """
    result = df.copy()
    result["generated_id"] = [f"D{n:06d}" for n in range(1, len(result) + 1)]

    # Missing 'smiles' column degrades gracefully to all-empty strings.
    if "smiles" in result.columns:
        smiles_series = result["smiles"].fillna("")
    else:
        smiles_series = pd.Series([""] * len(result), index=result.index)

    def label_ring_sizes(smiles: str) -> str:
        found = detector.ring_sizes(smiles)
        return ";".join(str(size) for size in found) if found else ""

    result["macrocycle_ring_sizes"] = smiles_series.apply(label_ring_sizes)
    return result
|
||||
|
||||
|
||||
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``. Exposed
            as a parameter so the CLI can be exercised in tests.

    Returns:
        Namespace with ``input_csv`` and ``output``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
    )
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: annotate the input CSV and write the augmented copy."""
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    # Default output: <input stem>_with_macrocycles.csv in the same directory.
    output_path = args.output or args.input_csv.with_name(
        f"{args.input_csv.stem}_with_macrocycles.csv"
    )

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting.")

    df = pd.read_csv(args.input_csv)
    detector = MacrolactoneRingDetector()
    augmented = add_columns(df, detector)
    augmented.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")


if __name__ == "__main__":
    main()
|
||||
70
script/data_processing/merge_splits.py
Normal file
70
script/data_processing/merge_splits.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Merge split CSVs into a single file with split source labels."""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Mapping of split CSV filenames to numeric labels for the source column
SPLIT_LABELS = {
    "split_test.csv": 1,
    "split_train.csv": 2,
    "split_val.csv": 3,
}

# Default name of the column recording which split each row came from.
DEFAULT_COLUMN_NAME = "split_source"
|
||||
|
||||
|
||||
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``. Exposed
            as a parameter so defaults can be inspected in tests.

    Returns:
        Namespace with ``input_dir``, ``output`` and ``column_name``.
    """
    # Bug fix: this script lives in script/data_processing/, so the repo root
    # is three levels up. The previous `.parent.parent` resolved to script/.
    repo_root = Path(__file__).resolve().parent.parent.parent
    parser = argparse.ArgumentParser(
        description=
        "Combine split_*.csv files from splits_v2 and label their origin with integers."
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        default=repo_root / "splits_v2",
        help="Directory containing split_*.csv files (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=repo_root / "data" / "merged_splits.csv",
        help="Destination CSV path (default: %(default)s)",
    )
    parser.add_argument(
        "--column-name",
        default=DEFAULT_COLUMN_NAME,
        help="Name for the source column (default: %(default)s)",
    )
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main() -> None:
    """Load every expected split CSV, tag its origin, and write the merge."""
    args = parse_args()

    if not args.input_dir.is_dir():
        raise SystemExit(f"Input directory not found: {args.input_dir}")

    labelled_frames = []
    for filename, label in SPLIT_LABELS.items():
        csv_path = args.input_dir / filename
        if not csv_path.is_file():
            # All three splits are mandatory; a missing one aborts the run.
            raise SystemExit(f"Missing expected split file: {csv_path}")
        frame = pd.read_csv(csv_path)
        frame[args.column_name] = label
        labelled_frames.append(frame)

    if not labelled_frames:
        raise SystemExit("No split CSV files were loaded.")

    merged = pd.concat(labelled_frames, ignore_index=True)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    merged.to_csv(args.output, index=False)
    print(f"Merged {len(labelled_frames)} files with {len(merged)} rows into {args.output}")


if __name__ == "__main__":
    main()
|
||||
304
script/data_processing/split_drugbank.py
Normal file
304
script/data_processing/split_drugbank.py
Normal file
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Structure-aware, distribution-aligned DrugBank dataset splitter.

- Reads a CSV containing at least one of SELFIES or SMILES (ideally qed too)
- Derives SMILES (when only SELFIES is present) and builds RDKit Mols
- Computes Bemis-Murcko scaffolds, molecular weight and other basic properties
- Performs stratified train/val/test splitting with scaffolds as group units
- Aligns QED/MW distributions (binning + greedy balancing)
- Writes three CSVs: split_train.csv / split_val.csv / split_test.csv

Key fix over the original version: assignment is "capacity first +
distribution alignment", avoiding the TRAIN=0 bias.
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import math
|
||||
import hashlib
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.Scaffolds import MurckoScaffold
|
||||
from rdkit.Chem import Descriptors
|
||||
|
||||
# `selfies` is an optional dependency: when it is unavailable, SELFIES
# inputs simply decode to None instead of the import crashing the script.
try:
    import selfies as sf
except Exception:
    sf = None
|
||||
|
||||
def to_smiles_from_selfies(s):
    """Decode a SELFIES string to canonical SMILES, or None on any failure."""
    if pd.isna(s) or not isinstance(s, str) or not s.strip():
        return None
    if sf is None:
        # selfies package unavailable; cannot decode.
        return None
    try:
        mol = Chem.MolFromSmiles(sf.decoder(s))
        return Chem.MolToSmiles(mol) if mol is not None else None
    except Exception:
        return None
|
||||
|
||||
def canonicalize_smiles(s):
    """Return the RDKit-canonical form of SMILES *s*, or None if invalid."""
    if pd.isna(s) or not isinstance(s, str) or not s.strip():
        return None
    mol = Chem.MolFromSmiles(s)
    return Chem.MolToSmiles(mol) if mol is not None else None
|
||||
|
||||
def bemis_murcko_scaffold(smiles):
    """Return the Bemis-Murcko scaffold SMILES, or None when unavailable."""
    if smiles is None:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold) if scaffold is not None else None
    except Exception:
        # Scaffold extraction can fail on unusual structures; treat as none.
        return None
|
||||
|
||||
def calc_basic_props(smiles):
    """Return (MW, TPSA, LogP, heavy-atom count); all NaN when invalid."""
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return (np.nan, np.nan, np.nan, np.nan)
    return (
        Descriptors.MolWt(mol),
        Descriptors.TPSA(mol),
        Descriptors.MolLogP(mol),
        mol.GetNumHeavyAtoms(),
    )
|
||||
|
||||
def hash_str(x):
    """Deterministic 128-bit integer digest of *x* (MD5; not for security)."""
    digest = hashlib.md5(x.encode('utf-8')).digest()
    return int.from_bytes(digest, 'big')
|
||||
|
||||
def bin_by_edges(x, edges):
    """Map *x* to a bin index given sorted *edges*.

    NaN maps to the sentinel -1; values above the last edge fall into the
    final bin (index ``len(edges) - 2``).
    """
    if isinstance(x, float) and math.isnan(x):
        return -1
    return next(
        (i - 1 for i in range(1, len(edges)) if x <= edges[i]),
        len(edges) - 2,
    )
|
||||
|
||||
def hist_distance_L2(h1, h2):
    """Euclidean distance between two sparse histograms (missing keys = 0)."""
    keys = set(h1) | set(h2)
    total = sum((h1.get(k, 0) - h2.get(k, 0)) ** 2 for k in keys)
    return math.sqrt(total)
|
||||
|
||||
def main():
    """Split a DrugBank CSV into scaffold-grouped, distribution-aligned sets.

    Reads the input CSV, canonicalizes SMILES (decoding SELFIES if needed),
    computes basic properties and Bemis-Murcko scaffolds, then greedily
    assigns whole scaffold groups to train/val/test so that split sizes hit
    their target ratios while QED/MW strata histograms stay aligned.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-csv", required=True)
    ap.add_argument("--out-dir", default=".")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--train-ratio", type=float, default=0.8)
    ap.add_argument("--val-ratio", type=float, default=0.1)
    ap.add_argument("--test-ratio", type=float, default=0.1)
    ap.add_argument("--n_qed_bins", type=int, default=5)
    ap.add_argument("--n_mw_bins", type=int, default=5)
    # Assign scaffold groups largest-first (recommended).
    # NOTE(review): with action="store_true" and default=True this flag can
    # never be turned off from the CLI — confirm whether a
    # --no-largest-first option was intended.
    ap.add_argument("--largest-first", action="store_true", default=True)
    args = ap.parse_args()

    np.random.seed(args.seed)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(args.in_csv)

    def col_like(name_candidates):
        # Case-insensitive lookup of the first matching column name.
        for nc in name_candidates:
            idx = [i for i, c in enumerate(df.columns) if c.lower() == nc.lower()]
            if idx:
                return df.columns[idx[0]]
        return None

    col_smiles = col_like(["smiles", "canonical_smiles"])
    col_selfies = col_like(["selfies"])
    col_qed = col_like(["qed"])

    if col_smiles is None:
        if col_selfies is None:
            raise SystemExit("需要 SMILES 或 SELFIES 至少一个列。")
        df["SMILES"] = df[col_selfies].map(to_smiles_from_selfies)
        col_smiles = "SMILES"
    else:
        df["SMILES"] = df[col_smiles]

    df["SMILES"] = df["SMILES"].map(canonicalize_smiles)
    df = df[~df["SMILES"].isna()].copy()
    df = df.drop_duplicates(subset=["SMILES"])
    # Bug fix: realign labels with positions. group2idx below stores
    # positional indices from enumerate(), but `df.loc[idxs, ...]` indexes by
    # label; after the row filter and dedup the two diverge, which
    # mis-assigned (or KeyError'd) the strata histograms.
    df = df.reset_index(drop=True)

    props = df["SMILES"].map(calc_basic_props)
    df["MW"] = [p[0] for p in props]
    df["TPSA"] = [p[1] for p in props]
    df["LogP"] = [p[2] for p in props]
    df["HeavyAtoms"] = [p[3] for p in props]

    # Molecules without a scaffold (e.g. acyclic) group by their own SMILES.
    df["Scaffold"] = df["SMILES"].map(bemis_murcko_scaffold)
    df["GroupKey"] = np.where(df["Scaffold"].isna(), df["SMILES"], df["Scaffold"])

    if col_qed is None:
        df["qed"] = np.nan
    else:
        df["qed"] = pd.to_numeric(df[col_qed], errors="coerce")

    def compute_edges(series, n_bins):
        # Quantile bin edges; degenerate inputs fall back to a linear ramp.
        ser = series.dropna()
        if len(ser) < n_bins:
            return np.linspace(ser.min() if len(ser)>0 else 0,
                               ser.max() if len(ser)>0 else 1,
                               n_bins+1)
        qs = np.linspace(0,1,n_bins+1)
        return np.quantile(ser, qs)

    qed_edges = compute_edges(df["qed"], args.n_qed_bins)
    mw_edges = compute_edges(df["MW"], args.n_mw_bins)

    df["QED_bin"] = df["qed"].map(lambda x: bin_by_edges(x, qed_edges))
    df["MW_bin"] = df["MW"].map(lambda x: bin_by_edges(x, mw_edges))
    df["Strata"] = list(zip(df["QED_bin"], df["MW_bin"]))

    # Group rows (positional indices) by scaffold key.
    group2idx = defaultdict(list)
    for i, g in enumerate(df["GroupKey"]):
        group2idx[g].append(i)

    group_hist = {g: Counter(df.loc[idxs, "Strata"].tolist())
                  for g, idxs in group2idx.items()}
    global_hist = Counter(df["Strata"].tolist())

    total = len(df)
    target_counts = {
        "train": int(round(total * args.train_ratio)),
        "val": int(round(total * args.val_ratio)),
        "test": int(round(total * args.test_ratio)),
    }
    # Rounding drift goes to train so the counts sum exactly to `total`.
    diff = total - sum(target_counts.values())
    if diff != 0:
        target_counts["train"] += diff

    global_ratio = {
        "train": args.train_ratio,
        "val": args.val_ratio,
        "test": args.test_ratio,
    }
    # Per-split target strata histogram: the global histogram scaled by ratio.
    target_hist = {split: {k: v * r for k, v in global_hist.items()}
                   for split, r in global_ratio.items()}

    splits = {"train": [], "val": [], "test": []}
    split_counts = {"train": 0, "val": 0, "test": 0}
    split_hist = {"train": Counter(), "val": Counter(), "test": Counter()}

    scaffolds = list(group2idx.keys())
    rng = np.random.RandomState(args.seed)

    # Key step: place the largest scaffold groups first.
    if args.largest_first:
        scaffolds.sort(key=lambda g: len(group2idx[g]), reverse=True)
    else:
        rng.shuffle(scaffolds)

    def remaining_capacity(split):
        return target_counts[split] - split_counts[split]

    for g in scaffolds:
        idxs = group2idx[g]
        g_size = len(idxs)
        g_hist = group_hist[g]

        # Remaining capacity of each split.
        caps = {s: remaining_capacity(s) for s in ["train","val","test"]}

        # Prefer splits that still have positive capacity.
        positive_splits = [s for s,c in caps.items() if c > 0]

        # If every capacity is <= 0 the quotas are full; degrade to
        # "smallest overflow first".
        candidate_splits = positive_splits if positive_splits else ["train","val","test"]

        # Among candidates combine capacity-first with distribution
        # alignment (smaller L2 to the target histogram is better).
        best = None
        for s in candidate_splits:
            cap = caps[s]
            proj_hist = split_hist[s] + g_hist
            l2 = hist_distance_L2(proj_hist, target_hist[s])

            # Sort key:
            # 1) -max(cap, 0): larger positive capacity first; 0 if all negative
            # 2) abs(min(cap, 0)): smaller overflow first
            # 3) l2: smaller distribution distance first
            key = (-max(cap, 0), abs(min(cap, 0)), l2)

            if (best is None) or (key < best[0]):
                best = (key, s)

        chosen = best[1]

        splits[chosen].append(g)
        split_counts[chosen] += g_size
        split_hist[chosen] += g_hist

    # Collect row indices per split.
    idx_train, idx_val, idx_test = [], [], []
    for g in splits["train"]:
        idx_train.extend(group2idx[g])
    for g in splits["val"]:
        idx_val.extend(group2idx[g])
    for g in splits["test"]:
        idx_test.extend(group2idx[g])

    df_train = df.iloc[sorted(idx_train)].copy()
    df_val = df.iloc[sorted(idx_val)].copy()
    df_test = df.iloc[sorted(idx_test)].copy()

    def report(name, sub):
        # Print a short per-split summary: size, top strata, QED/MW moments.
        print(f"\n== {name} ==")
        print(f"size = {len(sub)}")
        h = Counter(sub["Strata"].tolist())
        for k, v in h.most_common(8):
            print(f" {k}: {v}")
        print(" QED mean/std:", np.nanmean(sub["qed"]), np.nanstd(sub["qed"]))
        print(" MW mean/std:", np.nanmean(sub["MW"]), np.nanstd(sub["MW"]))

    print(f"Total: {len(df)} (targets: {target_counts})")
    report("TRAIN", df_train)
    report("VAL", df_val)
    report("TEST", df_test)

    out_train = out_dir / "split_train.csv"
    out_val = out_dir / "split_val.csv"
    out_test = out_dir / "split_test.csv"
    df_train.to_csv(out_train, index=False)
    df_val.to_csv(out_val, index=False)
    df_test.to_csv(out_test, index=False)

    # Persist the bin edges so downstream users can reproduce the strata
    # (columns padded with NaN to equal length).
    edges = pd.DataFrame({
        "qed_edges": list(qed_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(qed_edges)),
        "mw_edges": list(mw_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(mw_edges)),
    })
    edges.to_csv(out_dir / "bin_edges_qed_mw.csv", index=False)

    # Report capacity deviation per split.
    print("\nFinal counts vs targets:")
    for s in ["train","val","test"]:
        print(f" {s}: {split_counts[s]} / {target_counts[s]} (delta={split_counts[s]-target_counts[s]})")

    print(f"\nSaved:\n {out_train}\n {out_val}\n {out_test}\n {out_dir/'bin_edges_qed_mw.csv'}")
    print("\n验证/测试阶段,仅使用 split_val / split_test 作为 feed-chemical。")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user