import pandas as pd import pytest from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size def test_stratified_sample(): # Create test data with known ring sizes data = { "smiles": [ "O=C1CCCCCCCCCCCCCCO1", # 16-membered "O=C1CCCCCCCCCCCCO1", # 14-membered "O=C1CCCCCCCCCCCCCCCCO1", # 18-membered ], "id": ["A", "B", "C"], } df = pd.DataFrame(data) sampled = stratified_sample_by_ring_size(df, sample_ratio=0.5, random_state=42) # Should get at least 1 from each ring size (50% of 1 = 1) assert len(sampled) >= 1 assert len(sampled) <= 3