diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6440c628cd..0d442435bf 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -616,7 +616,7 @@ def forward(
         rank = get_distributed_rank(cp_group)
         send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
 
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type
@@ -1564,7 +1564,7 @@ def backward(ctx, dout):
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
 
         q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
             restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index cdb75aa1b6..31a464caad 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -43,6 +43,7 @@
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _multi_stream_cublas_workspace = []
+_dummy_wgrads = {}
 _cublas_workspace = None
 _ub_communicators = None
 _NUM_MAX_UB_STREAMS = 3
@@ -78,6 +79,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
     return _multi_stream_cublas_workspace
 
 
+def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
+    """Returns a dummy tensor of given shape."""
+    assert len(shape) == 2
+    global _dummy_wgrads
+    if (shape[0], shape[1], dtype) not in _dummy_wgrads:
+        _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(
+            shape,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=False,
+        )
+    if zero:
+        _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
+    return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
+
+
 def initialize_ub(
     shape: list,
     tp_size: int,
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 5fb986bdc3..f49bad48c3 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -19,6 +19,7 @@
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
+    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
@@ -796,18 +797,15 @@ def backward(
             if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
                 origin_weight.grad_added_to_main_grad = True
                 if getattr(origin_weight, "zero_out_wgrad", False):
-                    wgrad = torch.zeros(
-                        origin_weight.main_grad.shape,
-                        dtype=origin_weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(origin_weight.main_grad.shape),
+                        origin_weight.dtype,
+                        zero=True,
                     )
                 else:
-                    wgrad = torch.empty(
-                        origin_weight.main_grad.shape,
-                        dtype=origin_weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(origin_weight.main_grad.shape),
+                        origin_weight.dtype,
                     )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index b0e60fbe5d..ca9dd29043 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -16,6 +16,7 @@
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
+    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
@@ -688,18 +689,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             ):
                 weight.grad_added_to_main_grad = True
                 if getattr(weight, "zero_out_wgrad", False):
-                    wgrad = torch.zeros(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(weight.main_grad.shape),
+                        weight.dtype,
+                        zero=True,
                     )
                 else:
-                    wgrad = torch.empty(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(weight.main_grad.shape),
+                        weight.dtype,
                     )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
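For review context, here is a minimal standalone sketch of the caching behavior introduced by the `get_dummy_wgrad` helper above. It is re-implemented outside the module purely for illustration (it assumes a CUDA device is available, and the shapes used are arbitrary): repeated backward passes that request a dummy wgrad of the same shape and dtype reuse one buffer instead of allocating a fresh `torch.zeros`/`torch.empty` tensor each time.

```python
import torch

# Illustration only: mirrors the get_dummy_wgrad helper added to module/base.py.
_dummy_wgrads = {}

def get_dummy_wgrad(shape, dtype, zero=False):
    """Return a cached dummy wgrad tensor, zero-filled on request."""
    assert len(shape) == 2
    key = (shape[0], shape[1], dtype)
    if key not in _dummy_wgrads:
        # Allocated once per (shape, dtype); later calls reuse this buffer.
        _dummy_wgrads[key] = torch.empty(
            shape, dtype=dtype, device="cuda", requires_grad=False
        )
    if zero:
        _dummy_wgrads[key].fill_(0)
    return _dummy_wgrads[key].detach()

# Two "microbatches" asking for the same dummy wgrad hit the same allocation.
a = get_dummy_wgrad([4096, 1024], torch.bfloat16)
b = get_dummy_wgrad([4096, 1024], torch.bfloat16, zero=True)
assert a.data_ptr() == b.data_ptr()  # same underlying storage, no new allocation
```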