From 2b071b86813ac0a80855ffe9c2a4ac24d3671780 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 20 Feb 2025 12:49:41 +0000 Subject: [PATCH 1/5] Use dummy wgrads for lower memory consumption Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Vasudevan Rengasamy --- transformer_engine/pytorch/module/base.py | 12 ++++++++++++ .../pytorch/module/layernorm_linear.py | 18 ++++++++---------- transformer_engine/pytorch/module/linear.py | 18 ++++++++---------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cdb75aa1b6..0020922ce3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -43,6 +43,7 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _multi_stream_cublas_workspace = [] +_dummy_wgrads = {} _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 @@ -78,6 +79,17 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: return _multi_stream_cublas_workspace +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(shape, dtype=dtype, device="cuda") + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)] + + def initialize_ub( shape: list, tp_size: int, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5fb986bdc3..6c86a0c05c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -19,6 +19,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -796,18 +797,15 @@ def backward( if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - origin_weight.main_grad.shape, - dtype=origin_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0e60fbe5d..ca9dd29043 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -16,6 +16,7 @@ get_workspace, get_ub, TransformerEngineBaseModule, + get_dummy_wgrad, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, @@ -688,18 +689,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None From c014f89f961cdcbf87166f974f8aa7f5bee784ac Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Tue, 25 Feb 2025 13:50:00 -0800 Subject: [PATCH 2/5] Bug fix to avoid sharing gradients. Signed-off-by: Vasudevan Rengasamy --- transformer_engine/pytorch/module/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0020922ce3..035449eb9a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -84,10 +84,10 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor assert len(shape) == 2 global _dummy_wgrads if (shape[0], shape[1], dtype) not in _dummy_wgrads: - _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(shape, dtype=dtype, device="cuda") + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(shape, dtype=dtype, device="cuda", requires_grad=False,) if zero: _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) - return _dummy_wgrads[(shape[0], shape[1], dtype)] + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() def initialize_ub( From eb8433d30df753a3e05d2ad30dbb1d4bf9124a12 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Fri, 14 Mar 2025 10:44:50 -0700 Subject: [PATCH 3/5] Disable automatic use of batch_p2p_comm for CP2 Signed-off-by: Vasudevan Rengasamy --- transformer_engine/pytorch/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6440c628cd..0d442435bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -616,7 +616,7 @@ def forward( rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -1564,7 +1564,7 @@ def backward(ctx, dout): rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] - batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) From 77b0955a2716827d89445135a9642327eb6300fa Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Tue, 1 Apr 2025 19:54:29 -0700 Subject: [PATCH 4/5] Change weight to origin_weight for LN_LINEAR Signed-off-by: Vasudevan Rengasamy --- transformer_engine/pytorch/module/layernorm_linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6c86a0c05c..f49bad48c3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -798,14 +798,14 @@ def backward( origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(origin_weight.main_grad.shape), + origin_weight.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(origin_weight.main_grad.shape), + origin_weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None From e5cacd71f8955585c7aebcb4776e1dfb6504b2a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 17:10:13 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy --- transformer_engine/pytorch/module/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 035449eb9a..31a464caad 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -84,7 +84,12 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor assert len(shape) == 2 global _dummy_wgrads if (shape[0], shape[1], dtype) not in _dummy_wgrads: - _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(shape, dtype=dtype, device="cuda", requires_grad=False,) + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) if zero: _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()