From 1e36e52112d145aa3e08b0e65af00496e7e08549 Mon Sep 17 00:00:00 2001 From: lingyuzeng Date: Thu, 19 Mar 2026 10:28:38 +0800 Subject: [PATCH] feat(validation): add stratified sampling by ring size --- .../validation/sampling.py | 55 +++++++++++++++++++ tests/validation/test_sampling.py | 23 ++++++++ 2 files changed, 78 insertions(+) create mode 100644 src/macro_lactone_toolkit/validation/sampling.py create mode 100644 tests/validation/test_sampling.py diff --git a/src/macro_lactone_toolkit/validation/sampling.py b/src/macro_lactone_toolkit/validation/sampling.py new file mode 100644 index 0000000..e932a04 --- /dev/null +++ b/src/macro_lactone_toolkit/validation/sampling.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import pandas as pd + +from macro_lactone_toolkit import MacroLactoneAnalyzer + + +def stratified_sample_by_ring_size( + df: pd.DataFrame, + sample_ratio: float, + smiles_col: str = "smiles", + random_state: int = 42, +) -> pd.DataFrame: + """ + Perform stratified sampling by ring size. + + First classifies all molecules, then samples 10% from each ring size layer. + """ + analyzer = MacroLactoneAnalyzer() + + # Classify all molecules + classifications = [] + ring_sizes = [] + + for smiles in df[smiles_col]: + result = analyzer.classify_macrocycle(smiles) + classifications.append(result.classification) + ring_sizes.append(result.ring_size) + + df = df.copy() + df["_classification"] = classifications + df["_ring_size"] = ring_sizes + + # Group by ring size and sample from each group + sampled_groups = [] + + for ring_size in range(12, 21): + group = df[df["_ring_size"] == ring_size] + if len(group) > 0: + n_samples = max(1, int(len(group) * sample_ratio)) + sampled = group.sample(n=min(n_samples, len(group)), random_state=random_state) + sampled_groups.append(sampled) + + # Also sample from unknown ring size (None) + unknown_group = df[df["_ring_size"].isna()] + if len(unknown_group) > 0: + n_samples = max(1, int(len(unknown_group) * sample_ratio)) + sampled = unknown_group.sample(n=min(n_samples, len(unknown_group)), random_state=random_state) + sampled_groups.append(sampled) + + if not sampled_groups: + return pd.DataFrame() + + result = pd.concat(sampled_groups, ignore_index=True) + return result diff --git a/tests/validation/test_sampling.py b/tests/validation/test_sampling.py new file mode 100644 index 0000000..35133b6 --- /dev/null +++ b/tests/validation/test_sampling.py @@ -0,0 +1,23 @@ +import pandas as pd +import pytest + +from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size + + +def test_stratified_sample(): + # Create test data with known ring sizes + data = { + "smiles": [ + "O=C1CCCCCCCCCCCCCCO1", # 16-membered + "O=C1CCCCCCCCCCCCO1", # 14-membered + "O=C1CCCCCCCCCCCCCCCCO1", # 18-membered + ], + "id": ["A", "B", "C"], + } + df = pd.DataFrame(data) + + sampled = stratified_sample_by_ring_size(df, sample_ratio=0.5, random_state=42) + + # Should get at least 1 from each ring size (50% of 1 = 1) + assert len(sampled) >= 1 + assert len(sampled) <= 3