diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index e0ca4f025957..9a079e2594e0 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -883,8 +883,9 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): for p in params: if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + if param_id in self.norm_for_param_grads: + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])