From d8f1dcd3bb3761aae0dbef7d4a5c81f7d3a51e1c Mon Sep 17 00:00:00 2001
From: hamlet
Date: Mon, 15 Mar 2021 19:22:43 +0800
Subject: [PATCH 1/3] Fix ZeRO stage 2 cpu_offload when some trainable model
 parameters are skipped during training, as in
 https://github.com/microsoft/DeepSpeed/issues/707

When some trainable model parameters are skipped during training, their
backward hooks registered in self.create_reduce_and_remove_grad_hooks() never
run, so those parameters have no entry in norm_for_param_grads. Guard the
lookup with a membership check so such parameters are simply skipped when
accumulating the gradient norm.
---
 deepspeed/runtime/zero/stage2.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index bdd1de4cbdda..2abe6e4fcc1d 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -878,8 +878,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         for p in params:
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
-                param_norm = self.norm_for_param_grads[param_id]
-                total_norm += param_norm.item()**2
+                # Some models have trainable parameters that are skipped during training;
+                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
+                # so they have no entry in norm_for_param_grads.
+                if param_id in self.norm_for_param_grads:
+                    param_norm = self.norm_for_param_grads[param_id]
+                    total_norm += param_norm.item()**2
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])

From da595c2a16fb1e89b44c261ea3eb6b90beb5100e Mon Sep 17 00:00:00 2001
From: hamlet
Date: Mon, 15 Mar 2021 19:36:53 +0800
Subject: [PATCH 2/3] Trim trailing whitespace

---
 deepspeed/runtime/zero/stage2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 2abe6e4fcc1d..70cd2902b61b 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -881,7 +881,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 # Some models have trainable parameters that are skipped during training;
                 # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
                 # so they have no entry in norm_for_param_grads.
-                if param_id in self.norm_for_param_grads: 
+                if param_id in self.norm_for_param_grads:
                     param_norm = self.norm_for_param_grads[param_id]
                     total_norm += param_norm.item()**2
 

From e6a46c3717a816543655930efa8fff5297f25ea1 Mon Sep 17 00:00:00 2001
From: hamlet
Date: Mon, 15 Mar 2021 19:38:47 +0800
Subject: [PATCH 3/3] Trim trailing whitespace

---
 deepspeed/runtime/zero/stage2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 70cd2902b61b..2624b1122c75 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -879,7 +879,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
                 # Some models have trainable parameters that are skipped during training;
-                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, 
+                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
                 # so they have no entry in norm_for_param_grads.
                 if param_id in self.norm_for_param_grads:
                     param_norm = self.norm_for_param_grads[param_id]
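
The failure mode this series guards against can be seen in isolation: a parameter that
requires grad but never participates in the forward pass receives no gradient, so a hook
attached to it never fires, and any per-parameter dict keyed by param id is left without
an entry for it. The sketch below is a minimal standalone illustration, not DeepSpeed
code; ToyModel, the stand-in norm_for_param_grads dict, and the plain Tensor.register_hook
wiring are simplifying assumptions in place of the grad-accumulation hooks that
create_reduce_and_remove_grad_hooks() actually registers.

    # Standalone illustration (not DeepSpeed code): a trainable parameter that is
    # never used in forward() never has its gradient hook fire, so a per-parameter
    # bookkeeping dict ends up without an entry for it.
    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.used = nn.Linear(4, 4)
            self.skipped = nn.Linear(4, 4)  # trainable, but not used in forward()

        def forward(self, x):
            return self.used(x)  # self.skipped never participates


    model = ToyModel()
    norm_for_param_grads = {}  # stand-in for the ZeRO stage-2 bookkeeping dict


    def register_norm_hook(param_id, param):
        # Simplified stand-in for create_reduce_and_remove_grad_hooks(): record the
        # gradient norm when, and only when, a gradient for this parameter arrives.
        def hook(grad):
            norm_for_param_grads[param_id] = grad.detach().norm(2)
        param.register_hook(hook)


    for i, p in enumerate(model.parameters()):
        if p.requires_grad:
            register_norm_hook(i, p)

    model(torch.randn(2, 4)).sum().backward()

    # The hooks for self.skipped's weight and bias never fired, so an unguarded
    # norm_for_param_grads[param_id] lookup would raise KeyError for them.
    total_norm = 0.0
    for i, _ in enumerate(model.parameters()):
        if i in norm_for_param_grads:  # the membership check added by the patch
            total_norm += norm_for_param_grads[i].item() ** 2
    print("total grad norm:", total_norm ** 0.5)

Skipping the missing entries, as the patch's membership check does, treats a parameter
that produced no gradient in the step as contributing zero to the total norm, which is
the correct contribution for a gradient that was never computed.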