feat(validation): add stratified sampling by ring size
This commit is contained in:
55
src/macro_lactone_toolkit/validation/sampling.py
Normal file
55
src/macro_lactone_toolkit/validation/sampling.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from macro_lactone_toolkit import MacroLactoneAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_sample_by_ring_size(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
sample_ratio: float,
|
||||||
|
smiles_col: str = "smiles",
|
||||||
|
random_state: int = 42,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Perform stratified sampling by ring size.
|
||||||
|
|
||||||
|
First classifies all molecules, then samples 10% from each ring size layer.
|
||||||
|
"""
|
||||||
|
analyzer = MacroLactoneAnalyzer()
|
||||||
|
|
||||||
|
# Classify all molecules
|
||||||
|
classifications = []
|
||||||
|
ring_sizes = []
|
||||||
|
|
||||||
|
for smiles in df[smiles_col]:
|
||||||
|
result = analyzer.classify_macrocycle(smiles)
|
||||||
|
classifications.append(result.classification)
|
||||||
|
ring_sizes.append(result.ring_size)
|
||||||
|
|
||||||
|
df = df.copy()
|
||||||
|
df["_classification"] = classifications
|
||||||
|
df["_ring_size"] = ring_sizes
|
||||||
|
|
||||||
|
# Group by ring size and sample from each group
|
||||||
|
sampled_groups = []
|
||||||
|
|
||||||
|
for ring_size in range(12, 21):
|
||||||
|
group = df[df["_ring_size"] == ring_size]
|
||||||
|
if len(group) > 0:
|
||||||
|
n_samples = max(1, int(len(group) * sample_ratio))
|
||||||
|
sampled = group.sample(n=min(n_samples, len(group)), random_state=random_state)
|
||||||
|
sampled_groups.append(sampled)
|
||||||
|
|
||||||
|
# Also sample from unknown ring size (None)
|
||||||
|
unknown_group = df[df["_ring_size"].isna()]
|
||||||
|
if len(unknown_group) > 0:
|
||||||
|
n_samples = max(1, int(len(unknown_group) * sample_ratio))
|
||||||
|
sampled = unknown_group.sample(n=min(n_samples, len(unknown_group)), random_state=random_state)
|
||||||
|
sampled_groups.append(sampled)
|
||||||
|
|
||||||
|
if not sampled_groups:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
result = pd.concat(sampled_groups, ignore_index=True)
|
||||||
|
return result
|
||||||
23
tests/validation/test_sampling.py
Normal file
23
tests/validation/test_sampling.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_stratified_sample():
|
||||||
|
# Create test data with known ring sizes
|
||||||
|
data = {
|
||||||
|
"smiles": [
|
||||||
|
"O=C1CCCCCCCCCCCCCCO1", # 16-membered
|
||||||
|
"O=C1CCCCCCCCCCCCO1", # 14-membered
|
||||||
|
"O=C1CCCCCCCCCCCCCCCCO1", # 18-membered
|
||||||
|
],
|
||||||
|
"id": ["A", "B", "C"],
|
||||||
|
}
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
|
||||||
|
sampled = stratified_sample_by_ring_size(df, sample_ratio=0.5, random_state=42)
|
||||||
|
|
||||||
|
# Should get at least 1 from each ring size (50% of 1 = 1)
|
||||||
|
assert len(sampled) >= 1
|
||||||
|
assert len(sampled) <= 3
|
||||||
Reference in New Issue
Block a user