From 2ffaecbbf32773c907a27b718141bee4257a1ac7 Mon Sep 17 00:00:00 2001 From: Samyam Rajbhandari Date: Wed, 10 Mar 2021 14:53:14 -0800 Subject: [PATCH 1/9] Fix mis-aligned-grad When a parameter is not divisible by world size, the partitioned gradients are mis-aligned due to incorrect padding handling. This PR should fix that. --- deepspeed/runtime/zero/partition_parameters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 05825fc90688..4216e5606a76 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -807,8 +807,9 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): if start < param.ds_numel: elements = min(param.ds_numel - start, partition_size) - dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements) + dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size) + dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) src_tensor = param.grad.view(-1).narrow(0, start, elements) # just copy the grad partition to the buffer @@ -841,7 +842,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): # elements)) #print("after partition gradients") - param.grad.data = dest_tensor.data + param.grad.data = dest_tensor_full_buffer.data see_memory_usage("After partitioning gradients", force=False) From ac9266eef286235eb593ed0110fb7231dbd2c936 Mon Sep 17 00:00:00 2001 From: Samyam Date: Wed, 10 Mar 2021 23:11:49 +0000 Subject: [PATCH 2/9] Formatting fix --- deepspeed/runtime/zero/partition_parameters.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 4216e5606a76..e6cb9199899a 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ 
b/deepspeed/runtime/zero/partition_parameters.py @@ -807,7 +807,10 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): if start < param.ds_numel: elements = min(param.ds_numel - start, partition_size) - dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size) + dest_tensor_full_buffer = partition_buffer.view(-1).narrow( + 0, + 0, + partition_size) dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) src_tensor = param.grad.view(-1).narrow(0, start, elements) From bf53561cad23f47ffed58887bb17d0aed8acc0b4 Mon Sep 17 00:00:00 2001 From: Samyam Date: Wed, 10 Mar 2021 23:26:46 +0000 Subject: [PATCH 3/9] Adding static_scale test back for Z3, and also changing hidden size to be not divisible by world_size --- tests/unit/test_fp16.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 5012614f97b0..038ccacc471f 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip("skip for now") - config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -372,7 +369,8 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): @distributed_test(world_size=2) def _test_zero_static_scale(args, zero_stage): - hidden_dim = 10 + #making hidden size not divisible by DP for covering this scenario + hidden_dim = 9 model = SimpleModel(hidden_dim) model, optim, _, _ = deepspeed.initialize(args=args, From 9cd813d252e963bd44915b0e5e98fa5a83973dda Mon Sep 17 00:00:00 2001 From: Samyam Date: Thu, 11 Mar 2021 02:08:58 +0000 Subject: [PATCH 4/9] also removing alignment from flat fp16 buffers --- deepspeed/runtime/zero/stage3.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d2c197fa93c8..10208c7b5b49 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -961,10 +961,9 @@ def _create_fp16_partitions_with_defragmentation(self): #create flat buffer in CPU and move to GPU self.fp16_partitioned_groups_flat.append( - flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) + flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i], + 1).cuda( + torch.cuda.current_device())) see_memory_usage( f"After flattening and moving param group {i} to GPU", force=False) From 5692c62b60e81e7212c5f7d848e7a0fd9d744eea Mon Sep 17 00:00:00 2001 From: Samyam Date: Thu, 11 Mar 2021 02:21:52 +0000 Subject: [PATCH 5/9] Testing for hidden dim alignment --- tests/unit/test_fp16.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 038ccacc471f..dbd40c322be9 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -368,9 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): args = args_from_dict(tmpdir, config_dict) @distributed_test(world_size=2) - def _test_zero_static_scale(args, zero_stage): + def _test_zero_static_scale(args, zero_stage, hidden_dim): #making hidden size not divisible by DP for covering this scenario - hidden_dim = 9 + hidden_dim = hidden_dim model = SimpleModel(hidden_dim) model, optim, _, _ = deepspeed.initialize(args=args, @@ -391,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage): model.backward(loss) model.step() - _test_zero_static_scale(args=args, zero_stage=zero_stage) + #test when hidden_dim is not aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9) + #test when hidden_dim is aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, 
hidden_dim=10) def test_zero_static_scale_deprecated_format(tmpdir): From f8881306449b1cdbf6036d0746ba757417f488ad Mon Sep 17 00:00:00 2001 From: Samyam Date: Thu, 11 Mar 2021 19:35:15 +0000 Subject: [PATCH 6/9] inference hook fix --- deepspeed/runtime/zero/stage3.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 10208c7b5b49..93c243bfda51 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1035,6 +1035,17 @@ def setup_zero_stage3_hooks(self): self.hierarchy = 0 self._register_hooks_recursively(self.module) + #reset step if in inference mode + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + print_rank_0(f"In inference mode", force=True) + else: + print_rank_0(f"Not in inference mode", force=True) + + self.module.register_forward_hook(_end_of_forward_hook) + def persistent_parameters(self): persistent_params = [] total_persistent_parameters = 0 From 06a75680a9b289c24d19c69715fc95ab7905a4b3 Mon Sep 17 00:00:00 2001 From: Samyam Rajbhandari Date: Thu, 11 Mar 2021 16:37:26 -0800 Subject: [PATCH 7/9] Update stage3.py --- deepspeed/runtime/zero/stage3.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 93c243bfda51..4ba0567a5a7c 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1040,10 +1040,7 @@ def _end_of_forward_hook(module, *args): if not torch._C.is_grad_enabled(): self.param_coordinator.reset_step() - print_rank_0(f"In inference mode", force=True) - else: - print_rank_0(f"Not in inference mode", force=True) - + self.module.register_forward_hook(_end_of_forward_hook) def persistent_parameters(self): From e1122046d1780fbf4f517f94488bfb6e8896ed3b Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 12 Mar 2021 09:19:34 -0800 Subject: 
[PATCH 8/9] formatting --- deepspeed/runtime/zero/stage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4ba0567a5a7c..1861bdadb738 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1040,7 +1040,7 @@ def _end_of_forward_hook(module, *args): if not torch._C.is_grad_enabled(): self.param_coordinator.reset_step() - + self.module.register_forward_hook(_end_of_forward_hook) def persistent_parameters(self): From d4ef4ff0b8af19257bb484fc3a23645950980d85 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 12 Mar 2021 17:53:04 -0800 Subject: [PATCH 9/9] [bug-fix] move params to gpu if offload params is turned off --- deepspeed/runtime/zero/stage3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index be07e00e249d..99b4916aef3c 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -975,10 +975,12 @@ def _create_fp16_partitions_with_defragmentation(self): flat_offset, total_elements) self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - self._move_to_flat_buffer(self.fp16_partitioned_groups[i], - self.fp16_partitioned_groups_flat[i]) flat_offset += total_elements + # move param to flat buffer for both param offload on/off + self._move_to_flat_buffer(self.fp16_partitioned_groups[i], + self.fp16_partitioned_groups_flat[i]) + see_memory_usage(f"After Flattening param group {i}", force=False) def _create_fp32_partitions(self):