diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 7b98216c1cba..d076035604e3 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -517,7 +517,7 @@ def param_groups(self):
 
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update
-        self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)
+        self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)
 
     def create_grad_acc_hooks(self):
         self.grad_accs = []