src/tir/transforms/lower_thread_allreduce.cc (10 changes: 5 additions & 5 deletions)
@@ -333,8 +333,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (group_extent > 1) {
-        mask = mask &
-               (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
+        mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
+                       << (reduce_extent * cast(mask_dtype, group_index)));
}
seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
// Push the buffer description. Later this will have an
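The hunk above is the functional fix: for reduce_extent == 32 the old code builds the group mask from a 32-bit `(1 << reduce_extent)`, whose shift count is out of range, while the new code forms the constant as 64-bit `(1ll << reduce_extent) - 1` and materializes it in the mask dtype. A minimal Python sketch of the difference (not TVM code; the mod-32 shift below is only an assumption about how typical 32-bit hardware treats the ill-defined count):

```python
def old_mask(reduce_extent, group_index):
    # (1 << reduce_extent) - 1 evaluated in int32: for reduce_extent == 32 the
    # shift is ill-defined; hardware commonly reduces the count modulo 32.
    low = ((1 << (reduce_extent % 32)) - 1) & 0xFFFFFFFF
    return (low << (reduce_extent * group_index)) & 0xFFFFFFFF

def new_mask(reduce_extent, group_index):
    # (1ll << reduce_extent) - 1 evaluated in 64 bits, then narrowed to uint32.
    low = ((1 << reduce_extent) - 1) & 0xFFFFFFFF
    return (low << (reduce_extent * group_index)) & 0xFFFFFFFF

print(hex(old_mask(32, 0)))  # 0x0        -- every lane masked out
print(hex(new_mask(32, 0)))  # 0xffffffff -- full-warp mask, as in the new test below
```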
@@ -392,7 +392,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// During the sub-warp reduction, values from inactive threads could be read,
// which is an undefined behavior according to the cuda document.
//
-    // In practise, the return value are usually 0, which does no harm to sum reduction.
+    // In practice, the return value are usually 0, which does no harm to sum reduction.
// However, the result can be incorrect in max or prod reduction.
// Therefore an additional range check has to be performed to ensure the correctness.
if (offset * 2 > reduce_extent) {
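The behaviour this comment describes is easy to reproduce with a small host-side simulation (plain Python, not TVM; treating a read from an inactive lane as 0.0 is an assumption about the commonly observed value, which CUDA leaves undefined):

```python
def warp_reduce(vals, extent, op, with_guard):
    """Simulate a shuffle-down reduction over `extent` active lanes.

    A read past the last active lane yields 0.0 here, mimicking the value
    usually (but not reliably) returned for inactive CUDA lanes.
    """
    lanes = list(vals)
    align = 1
    while align < extent:  # round the extent up to a power of two
        align *= 2
    offset = align // 2
    while offset > 0:
        shuffled = [lanes[i + offset] if i + offset < extent else 0.0
                    for i in range(extent)]
        for i in range(extent):
            if with_guard and offset * 2 > extent and i + offset >= extent:
                continue  # the extra range check: keep this lane's own value
            lanes[i] = op(lanes[i], shuffled[i])
        offset //= 2
    return lanes[0]


data = [-3.0, -1.0, -4.0, -1.0, -5.0, -9.0]  # 6 active lanes, all negative
print(warp_reduce(data, 6, max, with_guard=False))                 # 0.0  (stray zero wins)
print(warp_reduce(data, 6, max, with_guard=True))                  # -1.0 (correct maximum)
print(warp_reduce(data, 6, lambda a, b: a + b, with_guard=False))  # -23.0 (sum unaffected)
```

With all-negative inputs the stray zero wins an unguarded max, while the sum is unchanged, which is exactly the distinction drawn in the comment above.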
@@ -405,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids to emit predicated stores, as all threads are
-    // uniformly writting the same result.
+    // uniformly writing the same result.
//
for (size_t i = 0; i < size; ++i) {
Buffer buf = shared_bufs[i];
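As a sketch of why the broadcast removes the predicate (illustrative Python, hypothetical names, not the pass's output):

```python
# Predicated store (what the pass avoids emitting): only lane 0 writes.
#     if lane_id == 0:
#         out[group] = lanes[0]
# Uniform store after the broadcast: a warp shuffle from the group's first
# lane gives every lane the same value, so all lanes write unconditionally.
lanes = [-1.0] + [0.0] * 31    # after the reduction, lane 0 holds the result
lanes = [lanes[0]] * 32        # broadcast, i.e. shuffle from src lane 0
assert len(set(lanes)) == 1    # every lane now stores an identical value
```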
@@ -669,7 +669,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
return false;
}

-    // whether reduce_extent and group_extent are vaild for warp reduction.
+    // whether reduce_extent and group_extent are valid for warp reduction.
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
} else { // target_->kind->name == "cuda"
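The cuda branch of this check is truncated by the diff; the sketch below only illustrates the shape of such a predicate. The rocm condition is taken from the visible line, while the cuda condition is an assumption, not the pass's actual logic:

```python
def warp_reduction_applicable(target_kind, reduce_extent, group_extent, warp_size=32):
    # group_extent is kept only to mirror the check being discussed; this
    # simplified sketch does not constrain it.
    if target_kind == "rocm":
        # From the visible diff line: rocm only warp-reduces a full warp.
        return reduce_extent == warp_size
    # Assumed cuda-style condition (illustrative): the reduction must fit in a
    # warp, e.g. the warp splits evenly into groups of reduce_extent lanes.
    return reduce_extent > 1 and warp_size % reduce_extent == 0


print(warp_reduction_applicable("cuda", 32, 32))  # True for the TestMultiGroupMask shape
```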
@@ -235,5 +235,70 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
B[i] = reduce[0]


class TestMultiGroupMask(BaseCompare):
    @T.prim_func
    def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        threadIdx_y = T.launch_thread("threadIdx.y", 32)
        cross_thread_B = T.allocate([1], "float32", "local")
        threadIdx_x = T.launch_thread("threadIdx.x", 32)
        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
        with T.attr(
            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
            "reduce_scope",
            T.reinterpret("handle", T.uint64(0)),
        ):
            A_1 = T.Buffer((1024,), data=A.data)
            T.tvm_thread_allreduce(
                T.uint32(1),
                A_1[threadIdx_y * 32 + threadIdx_x],
                T.bool(True),
                cross_thread_B_1[0],
                threadIdx_x,
            )
        if threadIdx_x == 0:
            B_1 = T.Buffer((32,), data=B.data)
            B_1[threadIdx_y] = cross_thread_B_1[0]

    @T.prim_func
    def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        threadIdx_y = T.launch_thread("threadIdx.y", 32)
        red_buf0 = T.allocate([1], "float32", "local")
        threadIdx_x = T.launch_thread("threadIdx.x", 32)
        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
        with T.attr(
            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
            "reduce_scope",
            T.reinterpret("handle", T.uint64(0)),
        ):
            mask = T.allocate([1], "uint32", "local")
            t0 = T.allocate([1], "float32", "local")
            A_1 = T.Buffer((1024,), data=A.data)
            red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]

            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
            mask_1[0] = T.bitwise_and(
                T.tvm_warp_activemask(),
                T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)),
            )

            t0_1 = T.Buffer((1,), data=t0, scope="local")
            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
        if threadIdx_x == 0:
            B_1 = T.Buffer((32,), data=B.data)
            B_1[threadIdx_y] = red_buf0_1[0]


if __name__ == "__main__":
tvm.testing.main()
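For reference, the T.uint32(4294967295) in the expected mask above matches the new 64-bit constant from the C++ hunk, folded for reduce_extent == 32 (plain Python check, independent of TVM):

```python
reduce_extent = 32
mask_low = ((1 << reduce_extent) - 1) & 0xFFFFFFFF  # (1ll << 32) - 1, narrowed to uint32
assert mask_low == 4294967295                       # the constant in the expected mask
```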