Files
macrolactone-toolkit/tests/validation/test_sampling.py

24 lines
668 B
Python

import pandas as pd
import pytest
from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size
def test_stratified_sample():
    """Smoke-test ``stratified_sample_by_ring_size`` on a three-row frame.

    Builds a DataFrame with ``smiles`` and ``id`` columns holding three
    lactone SMILES annotated as 14-, 16- and 18-membered rings, samples it
    with ``sample_ratio=0.5`` and a fixed ``random_state``, and checks only
    that the result contains between 1 and 3 rows.
    """
    # (id, SMILES) pairs; the trailing comment is the ring size each entry
    # is meant to represent.
    molecules = [
        ("A", "O=C1CCCCCCCCCCCCCCO1"),  # 16-membered
        ("B", "O=C1CCCCCCCCCCCCO1"),  # 14-membered
        ("C", "O=C1CCCCCCCCCCCCCCCCO1"),  # 18-membered
    ]
    frame = pd.DataFrame(
        {
            "smiles": [smi for _, smi in molecules],
            "id": [mol_id for mol_id, _ in molecules],
        }
    )

    sampled = stratified_sample_by_ring_size(
        frame, sample_ratio=0.5, random_state=42
    )

    # Stated intent: at least one row survives per ring size. Only the total
    # row count is bounded here; per-stratum counts are not checked.
    # NOTE(review): confirm the sampler's rounding rule for single-row strata.
    assert 1 <= len(sampled) <= 3