splits

Deterministic train/test split helper used by dataset providers.

Functions

apply_split

apply_split(items: List[T], *, split: SplitName, seed: int, train_frac: float) -> List[T]

Return a deterministic slice of items according to split.

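The function copies items and shuffles the copy with random.Random(seed).shuffle(items_copy), so the input list is not mutated and, for a given input order, the same seed yields the same permutation. train is the first int(len(items) * train_frac) entries of the shuffle, test is the remainder, and all is the whole shuffle. Because int() truncates the cut index, train can be empty for short lists. Raises ValueError if split is not one of train, test, or all, or if train_frac is not strictly between 0 and 1.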
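As a rough illustration of the cut described above, the sizes of the train and test slices can be worked out without touching the data. The helper name split_sizes is hypothetical and not part of the module; it only mirrors the int(len(items) * train_frac) arithmetic.

```python
from typing import Tuple


def split_sizes(n_items: int, train_frac: float) -> Tuple[int, int]:
    """Hypothetical helper: sizes of (train, test) for a list of n_items."""
    cut = int(n_items * train_frac)  # int() truncates toward zero
    return cut, n_items - cut


print(split_sizes(10, 0.8))  # (8, 2)
print(split_sizes(3, 0.5))   # (1, 2)
print(split_sizes(1, 0.9))   # (0, 1)
```

Since train_frac must be below 1, a single-item list always lands entirely in test.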

Source code in src/openjarvis/evals/core/splits.py
def apply_split(
    items: List[T],
    *,
    split: SplitName,
    seed: int,
    train_frac: float,
) -> List[T]:
    """Return a deterministic slice of ``items`` according to ``split``.

    The underlying permutation is ``random.Random(seed).shuffle(items_copy)``.
    ``train`` is the first ``int(len(items) * train_frac)`` entries of the
    shuffle, ``test`` is the remainder, ``all`` is the whole shuffle.
    """
    if split not in ("train", "test", "all"):
        raise ValueError(f"split must be one of train/test/all, got {split!r}")
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac}")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    if split == "all":
        return shuffled
    cut = int(len(shuffled) * train_frac)
    if split == "train":
        return shuffled[:cut]
    return shuffled[cut:]
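
A minimal usage sketch follows, based only on the listing above. So that it runs without the package installed, it defines a condensed local stand-in for apply_split with the same logic. The plain str annotation for split is an assumption standing in for SplitName, whose definition is not shown here, and the sample items are made up.

```python
import random
from typing import List, TypeVar

T = TypeVar("T")


def apply_split(items: List[T], *, split: str, seed: int, train_frac: float) -> List[T]:
    # Condensed local stand-in for the documented function (assumption:
    # `split` is typed as str here because SplitName is not shown).
    if split not in ("train", "test", "all"):
        raise ValueError(f"split must be one of train/test/all, got {split!r}")
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac}")
    shuffled = list(items)  # copy, so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    if split == "all":
        return shuffled
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut] if split == "train" else shuffled[cut:]


# Made-up sample data for illustration.
items = [f"example-{i}" for i in range(10)]
train = apply_split(items, split="train", seed=0, train_frac=0.8)
test = apply_split(items, split="test", seed=0, train_frac=0.8)
everything = apply_split(items, split="all", seed=0, train_frac=0.8)

print(len(train), len(test))                         # 8 2
print(train + test == everything)                    # True
print(items == [f"example-{i}" for i in range(10)])  # True
```

With the same seed and train_frac, train and test partition the same permutation, so calling the function once per split cannot leak items between the two slices.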