
[Checkpointing] Persist DataLoaderShard epoch counter across save/load_state#4012

Open
s-zx wants to merge 1 commit into huggingface:main from s-zx:fix-3996-dataloader-shuffle-iteration

Conversation


@s-zx s-zx commented Apr 17, 2026

What does this PR do?

Fixes #3996.

accelerator.save_state / accelerator.load_state previously round-tripped
samplers and (when use_stateful_dataloader=True) the dataloader state_dict,
but did not persist the per-epoch iteration counter on
DataLoaderShard / DataLoaderDispatcher.

That counter is what seeds SeedableRandomSampler on the next epoch (via
set_epoch() → sampler.set_epoch(epoch) → deterministic shuffle seed), so a
resumed run replayed the epoch-0 shuffle order instead of the correct epoch's
permutation. This reproduces exactly the "shuffle sequence replays from
iteration 0" behavior reported in the issue with use_seedable_sampler=True.
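The failure mode can be sketched in a few lines. This is illustrative, not accelerate's code: the assumption is only that the seedable sampler's permutation is a pure function of a base seed plus the epoch passed to `set_epoch`, so a resume that never restores the epoch counter re-seeds with epoch 0 and replays the epoch-0 order.

```python
import random

# Hypothetical helper mimicking how an epoch-seeded sampler derives its
# shuffle: set_epoch(epoch) effectively re-seeds the RNG with base + epoch.
def epoch_permutation(num_samples: int, base_seed: int, epoch: int) -> list:
    rng = random.Random(base_seed + epoch)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order

# Resume at epoch 3 with the counter restored -> correct permutation.
resumed_correct = epoch_permutation(16, base_seed=42, epoch=3)
# Resume without restoring the counter -> falls back to epoch 0's order.
resumed_buggy = epoch_permutation(16, base_seed=42, epoch=0)

assert resumed_buggy == epoch_permutation(16, base_seed=42, epoch=0)  # deterministic
assert resumed_correct != resumed_buggy  # the buggy resume replays epoch 0
```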

Fix

Extend save_accelerator_state / load_accelerator_state to write a small
dl_iteration.bin sibling file per dataloader when the dataloader exposes an
iteration attribute, and restore it on load via set_epoch so samplers and
IterableDataset.set_epoch hooks stay in sync.

  • The save path is gated only on hasattr(dataloader, "iteration"), not on
    use_stateful_dataloader, so it also covers the non-stateful path
    (use_stateful_dataloader=False).
  • Load path checks input_iteration_file.exists(), keeping backward
    compatibility with checkpoints produced by older accelerate versions.
  • Restoration uses set_epoch(iteration) rather than direct assignment so
    downstream samplers / datasets receive the epoch change.
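A minimal sketch of the save/load extension described above. The file name `dl_iteration.bin` comes from this PR; the helper names, the per-dataloader index suffix, and the use of `pickle` in place of torch's serialization are illustrative assumptions, not the actual patch.

```python
import os
import pickle

def save_dataloader_iterations(dataloaders, output_dir):
    """Write one dl_iteration_<i>.bin per dataloader exposing `iteration`."""
    for i, dl in enumerate(dataloaders):
        # Gated only on the attribute, so the non-stateful path
        # (use_stateful_dataloader=False) is covered too.
        if hasattr(dl, "iteration"):
            with open(os.path.join(output_dir, f"dl_iteration_{i}.bin"), "wb") as f:
                pickle.dump(dl.iteration, f)

def load_dataloader_iterations(dataloaders, input_dir):
    """Restore counters via set_epoch; tolerate checkpoints lacking the file."""
    for i, dl in enumerate(dataloaders):
        path = os.path.join(input_dir, f"dl_iteration_{i}.bin")
        # Missing file => checkpoint from an older accelerate version; skip.
        if not os.path.exists(path):
            continue
        with open(path, "rb") as f:
            iteration = pickle.load(f)
        # set_epoch rather than direct assignment, so samplers and
        # IterableDataset.set_epoch hooks observe the epoch change.
        dl.set_epoch(iteration)
```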

Tests

Added unit tests to tests/test_state_checkpointing.py:

  • test_dataloader_iteration_counter_is_persisted — round-trips iteration through save/load.
  • test_load_state_tolerates_missing_iteration_file — backward compatibility.
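A self-contained sketch of the shape of these two tests. The stand-in class and the direct pickle round-trip are illustrative; the real tests in tests/test_state_checkpointing.py exercise the actual Accelerator save_state/load_state API.

```python
import os
import pickle
import tempfile

class FakeDataLoaderShard:
    """Stand-in for DataLoaderShard: tracks `iteration`, set via set_epoch."""
    def __init__(self):
        self.iteration = 0
    def set_epoch(self, epoch):
        self.iteration = epoch

def test_dataloader_iteration_counter_is_persisted():
    dl = FakeDataLoaderShard()
    dl.set_epoch(3)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "dl_iteration.bin")
        with open(path, "wb") as f:           # save_state writes the counter
            pickle.dump(dl.iteration, f)
        restored = FakeDataLoaderShard()
        with open(path, "rb") as f:           # load_state restores via set_epoch
            restored.set_epoch(pickle.load(f))
    assert restored.iteration == 3

def test_load_state_tolerates_missing_iteration_file():
    dl = FakeDataLoaderShard()
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "dl_iteration.bin")
        if os.path.exists(path):              # older checkpoints lack this file
            with open(path, "rb") as f:
                dl.set_epoch(pickle.load(f))
    assert dl.iteration == 0                  # left untouched: backward compatible

test_dataloader_iteration_counter_is_persisted()
test_load_state_tolerates_missing_iteration_file()
```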

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @BenjaminBossan

…d_state

Fixes huggingface#3996.

accelerator.save_state / load_state previously round-tripped samplers and
(optionally) stateful-dataloader state_dicts, but did not preserve the
per-epoch `iteration` counter on DataLoaderShard / DataLoaderDispatcher.
That counter is what seeds SeedableRandomSampler for the next epoch
(via set_epoch() -> sampler.set_epoch(epoch) -> deterministic shuffle
seed), so a resumed run replayed the epoch-0 shuffle order instead of
continuing with the correct epoch's permutation.

Extend save_accelerator_state / load_accelerator_state to write a small
`dl_iteration.bin` sibling file per dataloader when the dataloader
exposes an `iteration` attribute, and restore it on load via set_epoch
so samplers and dataset hooks that listen for epoch changes stay in sync.
The load path checks file existence, keeping backward compatibility with
checkpoints produced by older versions of accelerate.

Also adds unit tests covering the round-trip, missing-file compatibility,
and multi-dataloader filename handling.


Development

Successfully merging this pull request may close these issues.

DataLoader shuffle sequence replays from epoch 0 after resuming from a checkpoint

2 participants