From 05119963d31045b7567551cb108459aabdba9375 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Wed, 24 Mar 2021 00:31:02 +0000 Subject: [PATCH] Avoid key error --- deepspeed/runtime/zero/stage2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index e0ca4f025957..9a079e2594e0 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -883,8 +883,9 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): for p in params: if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + if param_id in self.norm_for_param_grads: + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])