From e193901a0a949c9b44601b97a9bb1752e26a8522 Mon Sep 17 00:00:00 2001 From: Samyam Date: Fri, 19 Mar 2021 20:15:27 +0000 Subject: [PATCH 1/2] Fix for fragmented linear inputs in ZeRO 3 Linear layers where reshape is needed instead of view. Should not affect performance for cases which only require view since reshape will just do view when possible --- deepspeed/runtime/zero/linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index f29fcda2bb19..23f97d5a542a 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -77,10 +77,10 @@ def backward(ctx, grad_output): #print("Computing grad weight") dim = grad_output.dim() if dim > 2: - grad_weight = grad_output.view(-1, - grad_output.shape[-1]).t().matmul( - input.view(-1, - input.shape[-1])) + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul( + input.reshape(-1, + input.shape[-1])) else: grad_weight = grad_output.t().matmul(input) #print(f"Computed grad weight grad_weight {grad_weight.shape}") From 79b81eda864072452395cb8896f0b309051bf048 Mon Sep 17 00:00:00 2001 From: Samyam Date: Fri, 19 Mar 2021 20:34:48 +0000 Subject: [PATCH 2/2] Enable memory efficient linear when ZeRO 3 model is initialized in Stage 3 initialize instead of using deepspeed.init --- deepspeed/runtime/zero/partition_parameters.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index e6cb9199899a..e546dff70445 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -190,6 +190,9 @@ def _init_subclass(cls, **kwargs): torch.empty = empty_cuda_tensor if self.mem_efficient_linear: + print_rank_0( + f"Your linear layers are being patched with more memory efficient version. 
This will persist unless manually reset.", force=True) self.linear_bk = torch.nn.functional.linear torch.nn.functional.linear = LinearFunctionForZeroStage3.apply @@ -210,8 +213,9 @@ def _disable_class(cls): torch.Tensor.__new__ = torch.Tensor.__old_new__ torch.empty = _orig_torch_empty - if self.mem_efficient_linear: - torch.nn.functional.linear = self.linear_bk + #undoing it here will undo it during training + #if self.mem_efficient_linear: + # torch.nn.functional.linear = self.linear_bk # Now that we cleaned up the metaclass injection, raise the exception. if exc_type is not None: @@ -354,6 +358,13 @@ def get_model(): self._convert_to_deepspeed_param(param) param.partition() + if mem_efficient_linear: + print_rank_0( + f"Your linear layers are being patched with more memory efficient version. This will persist unless manually reset.", force=True) + self.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = LinearFunctionForZeroStage3.apply + def _post_init_method(self, module): #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False) print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)