karpathy/nanochat

Task slicing: stop not validated → len() misreports; iteration raises IndexError

Summary

  • Context: SmolTalk is a Task implementation that loads conversational training data from Hugging Face and supports dataset slicing through the base Task class (start, stop, step parameters).

  • Bug: SmolTalk (like every other Task subclass) does not validate the stop parameter against the actual dataset size, so users can create Task instances with out-of-bounds stop values.

  • Actual vs. expected: When stop exceeds the dataset size, len() is computed from the invalid stop (with the default start=0 and step=1 it reports stop itself), but accessing any index beyond the actual dataset raises an IndexError. The expected behavior is to either validate stop during initialization or clamp it to the dataset size.

  • Impact: Training or evaluation jobs crash with IndexError when iterating over Task instances created with invalid stop values. This wastes compute and produces confusing failures that surface during data loading rather than at initialization.

Code with bug

tasks/smoltalk.py:

class SmolTalk(Task):
    """
    smol-smoltalk dataset.
    Train: ~460K rows
    Test:  ~24K rows
    """

    def __init__(self, split, **kwargs):
        super().__init__(**kwargs)
        # <-- BUG 🔴 No validation that stop <= dataset size

        assert split in ["train", "test"], "SmolTalk split must be train|test"

        self.ds = load_dataset(
            "HuggingFaceTB/smol-smoltalk",
            split=split,
        ).shuffle(seed=42)

        self.length = len(self.ds)

    def num_examples(self):
        return self.length  # <-- Returns actual dataset size

    def get_example(self, index):
        row = self.ds[index]  # <-- Will raise IndexError if index >= self.length
        # ... rest of method

tasks/common.py (base Task class):

class Task:
    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, (
            f"Stop should be greater than or equal to start, "
            f"got {stop} and {start}"
        )
        # <-- BUG 🔴 No validation that stop is within dataset bounds
        assert step >= 1, f"Step must be strictly positive, got {step}"

        self.start = start
        self.stop = stop  # could be None here
        self.step = step

    def __len__(self):
        start = self.start
        stop = self.num_examples() if self.stop is None else self.stop  # <-- Uses stop directly if provided
        step = self.step

        span = stop - start
        num = (span + step - 1) // step  # ceil_div(span, step)

        assert num >= 0, f"Negative number of examples???: {num}"  # prevent footguns
        return num

    def __getitem__(self, index: int):
        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"

        physical_index = self.start + index * self.step  # <-- Computes physical index
        conversation = self.get_example(physical_index)  # <-- Passes to subclass

        return conversation
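To see why an out-of-range stop only surfaces at access time, the arithmetic in Task.__len__ and Task.__getitem__ above can be traced in isolation. The helper names below (ceil_div, logical_len, last_physical_index) are hypothetical and simply mirror those expressions. Since n = ceil((stop - start) / step), the last physical index start + (n - 1) * step is always strictly less than stop, which suggests that requiring stop <= num_examples() is enough to keep every access in bounds.

```python
def ceil_div(a: int, b: int) -> int:
    # Same expression as Task.__len__: ceiling of a / b for positive b
    return (a + b - 1) // b


def logical_len(start: int, stop: int, step: int) -> int:
    # Number of examples a Task view reports through len()
    return ceil_div(stop - start, step)


def last_physical_index(start: int, stop: int, step: int):
    # Physical index Task.__getitem__ computes for the last logical index,
    # or None when the view is empty
    n = logical_len(start, stop, step)
    if n == 0:
        return None
    return start + (n - 1) * step


# stop=30_000 with step=1: len() reports 30000 and the last access targets
# physical index 29999, no matter how many rows the dataset really has
print(logical_len(0, 30_000, 1))          # 30000
print(last_physical_index(0, 30_000, 1))  # 29999

# With a larger step, the last physical index still stays below stop
print(logical_len(0, 30_000, 7))          # 4286
print(last_physical_index(0, 30_000, 7))  # 29995
```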

Example

Repro with an out-of-bounds stop on the test split (~24K rows):

from tasks.smoltalk import SmolTalk

# Test split has ~24,000 rows
ds = SmolTalk(split="test", stop=30_000)

print(len(ds))  # Output: 30000

try:
    # First index past the real data; with start=0 and step=1, physical_index == index
    _ = ds[ds.num_examples()]
except IndexError as e:
    print(f"IndexError: {e}")

Output: len() prints 30000, and the item access raises IndexError from the underlying dataset.

This shows that len() reflects the invalid stop, while item access beyond the actual dataset crashes.
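The same failure can be reproduced offline, without downloading smol-smoltalk, by pairing a trimmed copy of the base Task class with a hypothetical in-memory subclass (ListTask is illustrative, not part of nanochat); five rows stand in for the real split. One subtlety, assuming Task defines no __iter__ (none appears in the excerpt): a plain for loop or list(task) falls back to Python's legacy sequence protocol, which treats IndexError as the end of iteration and stops silently at the real row count. The crash shows up in index-based loops over range(len(task)), which is how a loader that shards by index would consume the task.

```python
class Task:
    # Trimmed copy of the base class from tasks/common.py (stop is not validated)
    def __init__(self, start=0, stop=None, step=1):
        assert start >= 0
        assert stop is None or stop >= start
        assert step >= 1
        self.start, self.stop, self.step = start, stop, step

    def __len__(self):
        stop = self.num_examples() if self.stop is None else self.stop
        return (stop - self.start + self.step - 1) // self.step

    def __getitem__(self, index: int):
        assert isinstance(index, int)
        return self.get_example(self.start + index * self.step)


class ListTask(Task):
    # Hypothetical in-memory task used only for this repro
    def __init__(self, rows, **kwargs):
        super().__init__(**kwargs)
        self.rows = rows

    def num_examples(self):
        return len(self.rows)

    def get_example(self, index):
        return self.rows[index]  # list indexing raises IndexError past the end


task = ListTask(["a", "b", "c", "d", "e"], stop=10)
print(len(task))  # 10, although only 5 rows exist

try:
    for i in range(len(task)):  # index-based loop, as a sharding loader would do
        task[i]
except IndexError as e:
    print(f"IndexError at index {i}: {e}")  # IndexError at index 5: list index out of range

# The legacy sequence protocol ends iteration at the first IndexError instead
print(list(task))  # ['a', 'b', 'c', 'd', 'e']
```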

Recommended fix

Add a validation method to the base class that can run once a subclass has loaded its dataset, then invoke it from each subclass:

class Task:
    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, (
            f"Stop should be greater than or equal to start, "
            f"got {stop} and {start}"
        )
        assert step >= 1, f"Step must be strictly positive, got {step}"

        self.start = start
        self.stop = stop  # could be None here
        self.step = step

        # Note: stop <= num_examples() cannot be validated here
        # because the subclass hasn't loaded its dataset yet

    def _validate_slicing(self):  # <-- FIX 🟢 New method to call after dataset loads
        """Validate that slicing parameters are within dataset bounds."""
        num_ex = self.num_examples()

        if self.stop is not None and self.stop > num_ex:
            raise ValueError(
                f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). "
                f"Please use stop <= {num_ex} or remove the stop parameter to use the full dataset."
            )

    # __len__ and __getitem__ are unchanged
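Before wiring the check into SmolTalk, its fail-fast behavior can be exercised offline. The sketch below, under the assumption that a trimmed base class is representative, keeps only the constructor and the proposed _validate_slicing(), and uses a hypothetical in-memory ListTask that follows the same "load data, then validate" order as the SmolTalk change. An oversized stop now raises ValueError at construction time instead of IndexError during data loading.

```python
class Task:
    # Trimmed base class with the proposed _validate_slicing() added
    def __init__(self, start=0, stop=None, step=1):
        assert start >= 0 and step >= 1
        assert stop is None or stop >= start
        self.start, self.stop, self.step = start, stop, step

    def _validate_slicing(self):
        num_ex = self.num_examples()
        if self.stop is not None and self.stop > num_ex:
            raise ValueError(
                f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex})."
            )


class ListTask(Task):
    # Hypothetical in-memory task: load rows first, then validate
    def __init__(self, rows, **kwargs):
        super().__init__(**kwargs)
        self.rows = rows
        self._validate_slicing()

    def num_examples(self):
        return len(self.rows)


ListTask(list(range(5)), stop=5)  # accepted: stop equals the dataset size

try:
    ListTask(list(range(5)), stop=10)
except ValueError as e:
    print(e)  # Stop parameter (10) exceeds dataset size (5).
```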

And in tasks/smoltalk.py:

from datasets import load_dataset
from tasks.common import Task


class SmolTalk(Task):
    def __init__(self, split, **kwargs):
        super().__init__(**kwargs)

        assert split in ["train", "test"], "SmolTalk split must be train|test"

        self.ds = load_dataset(
            "HuggingFaceTB/smol-smoltalk",
            split=split,
        ).shuffle(seed=42)

        self.length = len(self.ds)

        self._validate_slicing()  # <-- FIX 🟢 Validate after dataset is loaded
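The summary also names clamping as an acceptable alternative to raising. A minimal sketch of that variant, assuming silent truncation is acceptable to callers, computes an effective stop as min(stop, num_examples()) inside __len__, so no subclass needs to remember an extra call. Because the original __getitem__ does no bounds checking, the sketch also rejects logical indices outside range(len(self)) with an explicit IndexError, and it returns an empty view (rather than tripping the negative-length assert) when start is past the end. The helper _effective_stop and the ListTask subclass are hypothetical.

```python
class Task:
    # Trimmed base class: stop is clamped to the dataset size instead of validated
    def __init__(self, start=0, stop=None, step=1):
        assert start >= 0 and step >= 1
        assert stop is None or stop >= start
        self.start, self.stop, self.step = start, stop, step

    def _effective_stop(self):
        n = self.num_examples()
        return n if self.stop is None else min(self.stop, n)

    def __len__(self):
        # max(..., 0) yields an empty view when start lies past the end
        span = max(self._effective_stop() - self.start, 0)
        return (span + self.step - 1) // self.step

    def __getitem__(self, index: int):
        if not 0 <= index < len(self):
            raise IndexError(f"Index {index} out of range for task of length {len(self)}")
        return self.get_example(self.start + index * self.step)


class ListTask(Task):
    # Hypothetical in-memory task used only for illustration
    def __init__(self, rows, **kwargs):
        super().__init__(**kwargs)
        self.rows = rows

    def num_examples(self):
        return len(self.rows)

    def get_example(self, index):
        return self.rows[index]


task = ListTask(["a", "b", "c", "d", "e"], stop=10)
print(len(task))   # 5: stop is clamped to the dataset size
print(list(task))  # ['a', 'b', 'c', 'd', 'e']
```

The trade-off is that clamping hides a likely configuration mistake, whereas the validation fix surfaces it immediately; which is preferable depends on whether an oversized stop should mean "use everything available".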