From 2471b8fd05be9044a49dd502a26654ae2be10dc9 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Wed, 31 Jan 2024 11:11:12 -0800 Subject: [PATCH 1/4] Pass knobs for TP comm overlap instead of env vars Signed-off-by: Jaemin Choi --- transformer_engine/pytorch/transformer.py | 37 ++++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index dd21be36c8..f220f68bb1 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -193,6 +193,18 @@ class TransformerLayer(torch.nn.Module): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. + ub_tp_comm_overlap : bool, default = `False` + if set to `True`, enables overlap of TP communication with computation. + ub_bulk_wgrad : bool, default = `True` + ub_bulk_dgrad : bool, default = `True` + ub_split_ag : bool, default = `True` + enables split-pipelined overlap of allgather with computation. + ub_split_rs : bool, default = `True` + enables split-pipelined overlap of reduce-scatter with computation. + ub_atomic_gemm_ag: bool, default = `False` + if set to `True`, enables atomic overlap of allgather with computation. + ub_atomic_gemm_rs: bool, default = `False` + if set to `True`, enables atomic overlap of reduce-scatter with computation. Optimization parameters ----------------------- @@ -257,6 +269,12 @@ def __init__( zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, + ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = True, + ub_split_ag: bool = True, + ub_split_rs: bool = True, + ub_atomic_gemm_ag: bool = False, + ub_atomic_gemm_rs: bool = False, bias: bool = True, activation: str = 'gelu', normalization: str = "LayerNorm", @@ -274,21 +292,18 @@ def __init__( self.window_size = window_size self.window_size = check_set_window_size(self_attn_mask_type, self.window_size) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) - ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) - ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1"))) - ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1"))) - ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1"))) - ub_atomic_gemm_rs = (ub_tp_comm_overlap - and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0")))) + ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad + ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad + ub_split_ag = ub_tp_comm_overlap and ub_split_ag + ub_split_rs = ub_tp_comm_overlap and ub_split_rs + ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs assert ( not (ub_split_rs and ub_atomic_gemm_rs) - ), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled." - ub_atomic_gemm_ag = (ub_tp_comm_overlap - and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0")))) + ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled." + ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag assert ( not (ub_split_ag and ub_atomic_gemm_ag) - ), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled." + ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled." if ub_atomic_gemm_rs or ub_atomic_gemm_ag: warnings.warn( From 9f5b9e87363618cd8476a3ad68419ffabc9d46ce Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Wed, 31 Jan 2024 12:36:03 -0800 Subject: [PATCH 2/4] Comment out debugging print Signed-off-by: Jaemin Choi --- transformer_engine/pytorch/csrc/comm_gemm_overlap.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 8d30bf4a8c..2d6fe54b7b 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -95,6 +95,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + // FIXME: We no longer rely on NVTE_UB_ATOMIC_GEMM_RS, so we cannot determine + // if atomic overlap is the chosen algorithm here. + /* const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS"); if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') { @@ -105,6 +108,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { else printf("!! Using reducescatter2_userbuff_strided\n"); } + */ at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { From 51fcc8460b2af03b8bc1fe85fa093152cda0261a Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Wed, 31 Jan 2024 14:22:13 -0800 Subject: [PATCH 3/4] Remove docstring Signed-off-by: Jaemin Choi --- transformer_engine/pytorch/transformer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index f220f68bb1..e0d45a094e 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -193,18 +193,6 @@ class TransformerLayer(torch.nn.Module): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - ub_tp_comm_overlap : bool, default = `False` - if set to `True`, enables overlap of TP communication with computation. - ub_bulk_wgrad : bool, default = `True` - ub_bulk_dgrad : bool, default = `True` - ub_split_ag : bool, default = `True` - enables split-pipelined overlap of allgather with computation. - ub_split_rs : bool, default = `True` - enables split-pipelined overlap of reduce-scatter with computation. - ub_atomic_gemm_ag: bool, default = `False` - if set to `True`, enables atomic overlap of allgather with computation. - ub_atomic_gemm_rs: bool, default = `False` - if set to `True`, enables atomic overlap of reduce-scatter with computation. Optimization parameters ----------------------- From 7c989aed8cbc0c59801c0f911faf4506ea169d46 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 9 Feb 2024 16:55:10 -0800 Subject: [PATCH 4/4] Remove debugging output Signed-off-by: Jaemin Choi --- .../pytorch/csrc/comm_gemm_overlap.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 2d6fe54b7b..827dec5010 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -95,21 +95,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - // FIXME: We no longer rely on NVTE_UB_ATOMIC_GEMM_RS, so we cannot determine - // if atomic overlap is the chosen algorithm here. - /* - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS"); - if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') { - if (env_p[0] == '1') - printf("!! Using reducescatter2_userbuff_strided_atomic\n"); - else if (env_p[0] == '2') - printf("!! Using reducescatter2_userbuff_strided_multiatomic\n"); - else - printf("!! Using reducescatter2_userbuff_strided\n"); - } - */ - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream;