Fix ZeRO-1/2 CPU-offloaded gradient loss with multiple backward() per step#7981
Merged
delock merged 2 commits intodeepspeedai:masterfrom Apr 22, 2026
Merged
Conversation
… step Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
delock
approved these changes
Apr 21, 2026
Collaborator
|
Hi @roycho96 can you fix formatting? Thanks! |
Contributor
Author
Done! Thank you for the review! |
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
95d73e2 to
efd10ee
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
ZeRO-1/2 +
offload_optimizer+gradient_accumulation_steps=1with multipleengine.backward()calls per optimizer step (viaset_gradient_accumulation_boundary(), formalized in #7665) silently drops all but the last backward's gradient.copy_grads_in_partitiononly calledasync_accumulate_grad_in_cpu_via_gpuunderif gradient_accumulation_steps > 1, so withga_steps=1intermediate backwards' reduced grads were never stored. The boundaryasync_inplace_copy_grad_to_fp32_buffer_from_gputhen overwrote (not added) the fp32 buffer with the last chunk only.ZeRO-3 + offload and non-offload ZeRO-1/2 are unaffected.
Fix
Replace the
ga > 1gate with one that fires exactly when a CPU accumulator is needed:ga_steps=1+ singlebackward()→ skipped. No CPU buffer, no extra copy. Fast path preserved.ga_steps=1+ multi-backward → accumulates correctly across calls.ga_steps>1→ identical to prior behaviour.Measurement
2x H100, 3-layer MLP, Adam, lr=1e-3, N=4 backwards/step, ga_steps=1
Max param diff vs no-offload reference:
Tests
New
tests/unit/v1/zero/test_zero2_offload_multi_backward.py, parametrized over ZeRO-1/2:multi-backward offload matches no-offload / single-backward unchanged / multi-step state-leak guard / single-backward allocates no CPU buffer (perf guard) /
ga_steps>1+ offload unchanged (#7967 regression guard).