diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 8d30bf4a8c..827dec5010 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -95,17 +95,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS"); - if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') { - if (env_p[0] == '1') - printf("!! Using reducescatter2_userbuff_strided_atomic\n"); - else if (env_p[0] == '2') - printf("!! Using reducescatter2_userbuff_strided_multiatomic\n"); - else - printf("!! Using reducescatter2_userbuff_strided\n"); - } - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index dd21be36c8..e0d45a094e 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -257,6 +257,12 @@ def __init__( zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, + ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = True, + ub_split_ag: bool = True, + ub_split_rs: bool = True, + ub_atomic_gemm_ag: bool = False, + ub_atomic_gemm_rs: bool = False, bias: bool = True, activation: str = 'gelu', normalization: str = "LayerNorm", @@ -274,21 +280,18 @@ def __init__( self.window_size = window_size self.window_size = check_set_window_size(self_attn_mask_type, self.window_size) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) - ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) - ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1"))) - ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1"))) - ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1"))) - ub_atomic_gemm_rs = (ub_tp_comm_overlap - and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0")))) + ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad + ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad + ub_split_ag = ub_tp_comm_overlap and ub_split_ag + ub_split_rs = ub_tp_comm_overlap and ub_split_rs + ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs assert ( not (ub_split_rs and ub_atomic_gemm_rs) - ), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled." - ub_atomic_gemm_ag = (ub_tp_comm_overlap - and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0")))) + ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled." + ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag assert ( not (ub_split_ag and ub_atomic_gemm_ag) - ), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled." + ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled." if ub_atomic_gemm_rs or ub_atomic_gemm_ag: warnings.warn(