Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());

const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS");
if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') {
if (env_p[0] == '1')
printf("!! Using reducescatter2_userbuff_strided_atomic\n");
else if (env_p[0] == '2')
printf("!! Using reducescatter2_userbuff_strided_multiatomic\n");
else
printf("!! Using reducescatter2_userbuff_strided\n");
}

at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
Expand Down
25 changes: 14 additions & 11 deletions transformer_engine/pytorch/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ def __init__(
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
ub_bulk_wgrad: bool = True,
ub_bulk_dgrad: bool = True,
ub_split_ag: bool = True,
ub_split_rs: bool = True,
ub_atomic_gemm_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
bias: bool = True,
activation: str = 'gelu',
normalization: str = "LayerNorm",
Expand All @@ -274,21 +280,18 @@ def __init__(
self.window_size = window_size
self.window_size = check_set_window_size(self_attn_mask_type, self.window_size)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
ub_atomic_gemm_rs = (ub_tp_comm_overlap
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0"))))
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
ub_split_ag = ub_tp_comm_overlap and ub_split_ag
ub_split_rs = ub_tp_comm_overlap and ub_split_rs
ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs
assert (
not (ub_split_rs and ub_atomic_gemm_rs)
), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled."
ub_atomic_gemm_ag = (ub_tp_comm_overlap
and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0"))))
), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled."
ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag
assert (
not (ub_split_ag and ub_atomic_gemm_ag)
), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled."
), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled."

if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
Expand Down