From b04fcf14fc915edebca080c8bbc7b883a009a1f6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 27 Aug 2024 23:05:44 -0400 Subject: [PATCH] [Fix][TIR] LowerThreadAllreduce warp reduction mask The warp reduction implemented by the "shuffle down" primitive takes a mask denoting the active threads within the warp that participate in this shuffle. Previously we computed this mask explicitly; in practice, we found that this results in a "CUDA illegal instruction" error on the NVIDIA H100 GPU when the mask is set, and the issue is gone if we do not update the mask. Therefore, this PR updates the allreduce lowering to remove the mask update. Confirmed the correctness on the following devices: * NVIDIA H100, * NVIDIA RTX 4090, * AMD Radeon 7900 XTX, * Apple M2 Ultra. --- src/tir/transforms/lower_thread_allreduce.cc | 7 ------- .../test_tir_transform_lower_thread_all_reduce.py | 15 ++++----------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 37d8f67580fe..dde33fa2678d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (reduce_extent <= warp_size_) { - if (group_extent > 1 && reduce_extent < warp_size_) { - mask = mask & - (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index))); - } std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq); @@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{group_index * n_warps + reduce_index}); } - if (n_warps < warp_size_) { - mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps)); - } 
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, n_warps, group_index, mask, /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq); diff --git a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py index d8c9568da90e..18d6339349ff 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py @@ -342,10 +342,7 @@ def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")): t0 = T.decl_buffer([1], "float32", scope="local") A_1 = T.Buffer((256,), data=A.data) red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), - T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)), - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32) red_buf0_1[0] = red_buf0_1[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32) @@ -421,7 +418,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_x] - mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15)) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -573,9 +570,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], 
red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -657,9 +652,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 16: red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)