From c7cb4acbf4fd0730fc531680dc400f302db52f51 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 14 Jul 2023 19:34:01 -0400 Subject: [PATCH] [Fix][TIR] LowerThreadAllreduce with correct thread mask This PR fixes a bug in the LowerThreadAllreduce pass. Prior to this PR, in multi-group settings, the thread mask is not correctly set: when the reduction extent is 32, the thread mask will always be 0. This bug was not spotted because even when the mask is 0, the CUDA program still gives correct result. But in any way, having the zero mask is dangerous and should be fixed. --- src/tir/transforms/lower_thread_allreduce.cc | 10 +-- ...t_tir_transform_lower_thread_all_reduce.py | 65 +++++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index c1566936c531..97a34a6ede1f 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -333,8 +333,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { { PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (group_extent > 1) { - mask = mask & - (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index))); + mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1) + << (reduce_extent * cast(mask_dtype, group_index))); } seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices)); // Push the buffer description. Later this will have an @@ -392,7 +392,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // During the sub-warp reduction, values from inactive threads could be read, // which is an undefined behavior according to the cuda document. // - // In practise, the return value are usually 0, which does no harm to sum reduction. + // In practice, the return value are usually 0, which does no harm to sum reduction. // However, the result can be incorrect in max or prod reduction. // Therefore an additional range check has to be performed to ensure the correctness. if (offset * 2 > reduce_extent) { @@ -405,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Broadcast the reduction result from lane 0 to all other lanes. // This avoids to emit predicated stores, as all threads are - // uniformly writting the same result. + // uniformly writing the same result. // for (size_t i = 0; i < size; ++i) { Buffer buf = shared_bufs[i]; @@ -669,7 +669,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return false; } - // whether reduce_extent and group_extent are vaild for warp reduction. + // whether reduce_extent and group_extent are valid for warp reduction. if (target_->kind->name == "rocm") { return reduce_extent == warp_size_; } else { // target_->kind->name == "cuda" diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index f20d11ffb401..c9e6136ca8d7 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -235,5 +235,70 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): B[i] = reduce[0] +class TestMultiGroupMask(BaseCompare): + @T.prim_func + def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 32) + cross_thread_B = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 32) + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((1024,), data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), + A_1[threadIdx_y * 32 + threadIdx_x], + T.bool(True), + cross_thread_B_1[0], + threadIdx_x, + ) + if threadIdx_x == 0: + B_1 = T.Buffer((32,), data=B.data) + B_1[threadIdx_y] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 32) + red_buf0 = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 32) + red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + A_1 = T.Buffer((1024,), data=A.data) + red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x] + + mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_1[0] = T.bitwise_and( + T.tvm_warp_activemask(), + T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)), + ) + + t0_1 = T.Buffer((1,), data=t0, scope="local") + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((32,), data=B.data) + B_1[threadIdx_y] = red_buf0_1[0] + + if __name__ == "__main__": tvm.testing.main()