From ce1465dae49ec696448633ecd99a4698fe30f52a Mon Sep 17 00:00:00 2001 From: Samyam Rajbhandari Date: Wed, 17 Mar 2021 18:20:55 -0700 Subject: [PATCH 1/4] Stage 3 Register only modules with parameters --- deepspeed/runtime/zero/stage3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ea4653578616..4c3c150997b7 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -98,7 +98,7 @@ def move_to_cpu(tensor_list): tensor.data = tensor.data.cpu() -def get_all_parameters(sub_module): +def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=False), sub_module.ds_external_parameters()) @@ -1066,6 +1066,8 @@ def _register_hooks_recursively(self, module, count=[0]): #print(f"{module.__class__} : {module.id}") for child in module.children(): + if len(self._get_all_parameters(child, recurse=True)) > 0: + continue count[0] = count[0] + 1 self._register_hooks_recursively(child, count=count) From dc05f87f0504491758a73c40bcad2b95375c7cd5 Mon Sep 17 00:00:00 2001 From: Samyam Date: Thu, 18 Mar 2021 01:57:10 +0000 Subject: [PATCH 2/4] Runs as before but it should prevent registering modules that do not have any parameters --- deepspeed/runtime/zero/stage3.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4c3c150997b7..5fc3d6532449 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -99,7 +99,7 @@ def move_to_cpu(tensor_list): def get_all_parameters(sub_module, recurse=False): - return itertools.chain(sub_module.named_parameters(recurse=False), + return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) @@ -1066,10 +1066,9 @@ def _register_hooks_recursively(self, module, count=[0]): #print(f"{module.__class__} : {module.id}") for child in module.children(): - if len(self._get_all_parameters(child, recurse=True)) > 0: - continue - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) + if len(list(get_all_parameters(child, recurse=True))) > 0: + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module) From 0eb6809a04d0376a599ec6e690e7901c6f1c755a Mon Sep 17 00:00:00 2001 From: Samyam Date: Thu, 18 Mar 2021 02:17:09 +0000 Subject: [PATCH 3/4] Resetting step at the beginning of the forward --- deepspeed/runtime/zero/stage3.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 5fc3d6532449..860892e47956 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1037,13 +1037,19 @@ def setup_zero_stage3_hooks(self): self.hierarchy = 0 self._register_hooks_recursively(self.module) + #reset step at the beginning of forward + def _pre_forward_hook(module, *args): + self.param_coordinator.reset_step() + #reset step if in inference mode def _end_of_forward_hook(module, *args): if not torch._C.is_grad_enabled(): self.param_coordinator.reset_step() + #likely one of them should be enough but just to be safe self.module.register_forward_hook(_end_of_forward_hook) + self.module.register_forward_pre_hook(_pre_forward_hook) def persistent_parameters(self): persistent_params = [] From b7d04c93a7e433ee7f649a5656cf12af9d845323 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 7 Apr 2021 13:14:07 -0700 Subject: [PATCH 4/4] revert change --- deepspeed/runtime/zero/stage3.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 860892e47956..01be13a8d482 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1072,9 +1072,8 @@ def _register_hooks_recursively(self, module, count=[0]): #print(f"{module.__class__} : {module.id}") for child in module.children(): - if len(list(get_all_parameters(child, recurse=True))) > 0: - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module)