
Fix ZeRO-1/2 CPU-offloaded gradient loss with multiple backward() per step#7981

Merged
delock merged 2 commits into deepspeedai:master from roycho96:fix/zero2-offload-ga1-multi-backward
Apr 22, 2026

Conversation

@roycho96
Contributor

Summary

With ZeRO-1/2 + offload_optimizer + gradient_accumulation_steps=1, issuing multiple engine.backward() calls per optimizer step (via set_gradient_accumulation_boundary(), formalized in #7665) silently drops all but the last backward's gradients.

copy_grads_in_partition called async_accumulate_grad_in_cpu_via_gpu only under if gradient_accumulation_steps > 1, so with ga_steps=1 the reduced gradients of intermediate backwards were never stored. At the boundary, async_inplace_copy_grad_to_fp32_buffer_from_gpu then overwrote (rather than added to) the fp32 buffer with only the last chunk.

ZeRO-3 + offload and non-offload ZeRO-1/2 are unaffected.
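
For context, a minimal sketch of the usage pattern that hits this path (not code from this PR; model, loader, and ds_config are placeholders, and ds_config is assumed to enable ZeRO-1/2 with offload_optimizer and gradient_accumulation_steps=1):

import deepspeed

# Placeholder model/config; assumes ds_config enables ZeRO-1/2 with
# offload_optimizer and gradient_accumulation_steps=1.
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

for micro_batches in loader:                  # several micro-batches per optimizer step
    for i, batch in enumerate(micro_batches):
        # Only the last backward() of the step is marked as an accumulation boundary.
        engine.set_gradient_accumulation_boundary(i == len(micro_batches) - 1)
        loss = engine(batch)                  # assume the model returns a scalar loss
        engine.backward(loss)                 # before this fix, only the last call's grads survived
    engine.step()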

Fix

Replace the ga > 1 gate with one that fires exactly when a CPU accumulator is needed:

if self.micro_step_id > 0 or not self.is_gradient_accumulation_boundary:
    self.async_accumulate_grad_in_cpu_via_gpu(param)
  • ga_steps=1 + single backward() → skipped. No CPU buffer, no extra copy. Fast path preserved.
  • ga_steps=1 + multi-backward → accumulates correctly across calls.
  • ga_steps>1 → identical to prior behaviour.
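
A toy illustration (pure Python, not DeepSpeed code) of why the overwrite lost gradients with ga_steps=1 and multiple backwards, and what the accumulate path yields instead:

# Hypothetical per-backward reduced gradient chunks for one partition (N=4 backwards).
reduced_grads = [0.1, 0.2, 0.3, 0.4]

# Old behaviour: intermediate chunks were never stored and the boundary copy
# overwrote the fp32 buffer, so only the last chunk survived.
fp32_buffer = 0.0
for g in reduced_grads:
    fp32_buffer = g
assert fp32_buffer == 0.4                 # 0.1, 0.2, 0.3 were dropped

# Fixed behaviour: intermediate backwards accumulate into the CPU buffer,
# so the optimizer sees the full sum.
fp32_buffer = 0.0
for g in reduced_grads:
    fp32_buffer += g
assert abs(fp32_buffer - 1.0) < 1e-12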

Measurement

2x H100, 3-layer MLP, Adam, lr=1e-3, N=4 backwards/step, ga_steps=1
Max param diff vs no-offload reference:

         fp32                              bf16
Before   2.00e-03 (wrong, around 2 x lr)
After    7.45e-09 (noise)                  0.00e+00
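
For reference, a sketch of how such a metric can be computed (assumed methodology, not the exact script behind these numbers): run the same updates with and without offload, then take the maximum absolute element-wise difference over all parameters.

import torch

def max_param_diff(offload_model, reference_model):
    # Largest absolute element-wise difference across all parameter tensors.
    diffs = [(p.detach().float() - q.detach().float()).abs().max()
             for p, q in zip(offload_model.parameters(), reference_model.parameters())]
    return torch.stack(diffs).max().item()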

Tests

New tests/unit/v1/zero/test_zero2_offload_multi_backward.py, parametrized over ZeRO-1/2:
  • multi-backward offload matches no-offload
  • single-backward unchanged
  • multi-step state-leak guard
  • single-backward allocates no CPU buffer (perf guard)
  • ga_steps>1 + offload unchanged (#7967 regression guard)
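
A minimal sketch of the core check (illustrative names and config values, not the actual contents of the new test file; the real tests run under DeepSpeed's distributed unit-test harness):

import pytest
import torch
import deepspeed

@pytest.mark.parametrize("zero_stage", [1, 2])
def test_multi_backward_offload_matches_no_offload(zero_stage):
    def run(offload):
        torch.manual_seed(0)
        model = torch.nn.Linear(8, 8)
        config = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "bf16": {"enabled": True},
            "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
            "zero_optimization": {"stage": zero_stage},
        }
        if offload:
            config["zero_optimization"]["offload_optimizer"] = {"device": "cpu"}
        engine, _, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=config)
        torch.manual_seed(1)                   # identical inputs for both runs
        for i in range(4):                     # N = 4 backwards in one optimizer step
            engine.set_gradient_accumulation_boundary(i == 3)
            x = torch.randn(1, 8, dtype=torch.bfloat16, device=engine.device)
            engine.backward(engine(x).sum())
        engine.step()
        return [p.detach().float().cpu().clone() for p in engine.module.parameters()]

    for p_off, p_ref in zip(run(offload=True), run(offload=False)):
        assert torch.allclose(p_off, p_ref, atol=1e-6)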

… step

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@delock
Collaborator

delock commented Apr 21, 2026

Hi @roycho96 can you fix formatting? Thanks!

@roycho96
Contributor Author

> Hi @roycho96 can you fix formatting? Thanks!

Done! Thank you for the review!

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the fix/zero2-offload-ga1-multi-backward branch from 95d73e2 to efd10ee on April 21, 2026 09:59
@delock delock merged commit aeb10bb into deepspeedai:master Apr 22, 2026
9 checks passed
@roycho96 roycho96 deleted the fix/zero2-offload-ga1-multi-backward branch April 22, 2026 06:21