diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index c1566936c531..97a34a6ede1f 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -333,8 +333,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       {
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
         if (group_extent > 1) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
+          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
+                         << (reduce_extent * cast(mask_dtype, group_index)));
         }
         seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
         // Push the buffer description. Later this will have an
@@ -392,7 +392,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // During the sub-warp reduction, values from inactive threads could be read,
       // which is an undefined behavior according to the cuda document.
       //
-      // In practise, the return value are usually 0, which does no harm to sum reduction.
+      // In practice, the return values are usually 0, which does no harm to sum reduction.
       // However, the result can be incorrect in max or prod reduction.
       // Therefore an additional range check has to be performed to ensure the correctness.
       if (offset * 2 > reduce_extent) {
@@ -405,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     // Broadcast the reduction result from lane 0 to all other lanes.
    // This avoids to emit predicated stores, as all threads are
-    // uniformly writting the same result.
+    // uniformly writing the same result.
     //
     for (size_t i = 0; i < size; ++i) {
       Buffer buf = shared_bufs[i];
@@ -669,7 +669,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return false;
     }
 
-    // whether reduce_extent and group_extent are vaild for warp reduction.
+    // whether reduce_extent and group_extent are valid for warp reduction.
     if (target_->kind->name == "rocm") {
       return reduce_extent == warp_size_;
     } else {  // target_->kind->name == "cuda"
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index f20d11ffb401..c9e6136ca8d7 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -235,5 +235,70 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
         B[i] = reduce[0]
 
 
+class TestMultiGroupMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((1024,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 32 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            A_1 = T.Buffer((1024,), data=A.data)
+            red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]
+
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.bitwise_and(
+                T.tvm_warp_activemask(),
+                T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)),
+            )
+
+            t0_1 = T.Buffer((1,), data=t0, scope="local")
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_1[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
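
Note on the mask fix above: with a plain int literal, (1 << reduce_extent) - 1 is evaluated in 32-bit arithmetic, so reduce_extent == 32 (a full warp) makes the shift undefined behavior in C++, and the constant also defaults to int32 rather than the mask dtype. Writing make_const(mask_dtype, (1ll << reduce_extent) - 1) performs the shift in 64 bits and pins the constant to the mask dtype. Below is a minimal standalone sketch of the arithmetic (plain C++, not TVM code; the 64-bit mask type and the two-group loop are illustrative assumptions only):

#include <cstdint>
#include <cstdio>

int main() {
  int reduce_extent = 32;  // one full warp per reduction group
  // uint64_t bad = (1 << reduce_extent) - 1;      // UB: shift count == bit width of int
  uint64_t all_ones = (1ll << reduce_extent) - 1;  // 0xFFFFFFFF, as intended
  for (int group_index = 0; group_index < 2; ++group_index) {
    // Per-group lane mask, mirroring the patched expression.
    uint64_t mask = all_ones << (reduce_extent * group_index);
    printf("group %d mask = 0x%016llx\n", group_index, (unsigned long long)mask);
  }
  return 0;
}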
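
For readers tracing the expected TIR in TestMultiGroupMask: the five tvm_warp_shuffle_down calls with offsets 16, 8, 4, 2, 1 form a binary reduction tree across the 32 lanes of a warp, after which the group's first lane holds the group sum, and the final tvm_warp_shuffle broadcasts that sum back to every lane so the store needs no predication. Below is a host-side simulation of the tree (a hypothetical plain-C++ harness, not TVM code), showing lane 0 converging to the full sum:

#include <cstdio>

int main() {
  float val[32];
  for (int i = 0; i < 32; ++i) val[i] = static_cast<float>(i);  // lane i contributes i; total = 496
  // Offsets 16, 8, 4, 2, 1 mirror the shuffle_down sequence in the expected TIR.
  for (int offset = 16; offset >= 1; offset /= 2) {
    float shifted[32];
    for (int i = 0; i < 32; ++i)
      shifted[i] = (i + offset < 32) ? val[i + offset] : 0.0f;  // out-of-range lanes contribute 0
    for (int i = 0; i < 32; ++i) val[i] += shifted[i];          // each lane adds its partner's partial sum
  }
  printf("lane 0 holds %.0f\n", val[0]);  // 496: the sum over all 32 lanes
  return 0;
}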