Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 104 additions & 16 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,48 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, //
BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
wb_regions.push_back(BufferRegion(wb_buffers[i], region));
}

// Construct the predicate of the write-back block. It is the conjunction of
// - each predicate clause of the original block which contains spatial loop var, and
// - `t == 0` for each reduction thread dim when the write-back buffer is not local.
PrimExpr wb_predicate = const_true();
for (const ForNode* loop : reduction_loops) {
if (loop->thread_binding.defined()) {
wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
std::unordered_set<const VarNode*> reduction_loop_vars;
reduction_loop_vars.reserve(reduction_loops.size());
for (const ForNode* reduction_loop : reduction_loops) {
reduction_loop_vars.insert(reduction_loop->loop_var.get());
}
PostOrderVisit(realize->predicate, [&wb_predicate, &reduction_loop_vars](const ObjectRef& obj) {
if (const auto* and_node = obj.as<AndNode>()) {
Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
for (PrimExpr sub_expr : sub_exprs) {
if (sub_expr->IsInstance<AndNode>()) {
continue;
}
bool is_reduction = [sub_expr, &reduction_loop_vars]() {
Array<Var> vars = UndefinedVars(sub_expr);
for (Var var : vars) {
if (reduction_loop_vars.find(var.get()) != reduction_loop_vars.end()) {
return true;
}
}
return false;
}();
if (!is_reduction) {
wb_predicate = wb_predicate && sub_expr;
}
}
return true;
}
return false;
});
if (wb_buffers[0].scope() != "local") {
for (const ForNode* loop : reduction_loops) {
if (loop->thread_binding.defined()) {
wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
}
}
}

stmts.push_back(BlockRealize(
/*iter_values=*/std::move(bindings),
/*predicate=*/wb_predicate,
Expand Down Expand Up @@ -498,21 +534,45 @@ class CrossThreadReductionTransformer : public StmtMutator {
}

// Check if the input block needs thread broadcast rewrite.
// One block needs broadcast rewrite when there exists one or more thread
// vars which vars free variables to this block.
// One block needs broadcast rewrite when
// 1. it consumes a buffer produced by cross-thread reduction under
// the same kernel (i.e., same group of blockIdx),
// 2. it writes to non-local memory,
// 3. at least one of the reduction thread vars of the cross-thread reduction
// is free to this block (i.e., not bound to the block).
std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
const BlockRealizeNode* realize) {
std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
thread2range_;
Block block = realize->block;

// If the block writes to local memory, no rewrite is needed.
for (BufferRegion write_region : block->writes) {
if (write_region->buffer.scope() == "local") {
return {};
}
}

// Find out the reduction threads for the read-buffers which are produced by
// cross-thread reduction.
std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range;
for (BufferRegion read_region : block->reads) {
auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
if (buf_it == crt_buf2threads_.end()) {
continue;
}
for (auto [scope, range] : buf_it->second) {
thread2range[scope] = range;
}
}

// Erase those threads which are not free to this block.
for (const ForNode* loop : loop_stack_) {
if (loop->thread_binding.defined()) {
ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
unbound_thread2range.erase(scope);
thread2range.erase(scope);
}
}

std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
for (auto [scope, range] : unbound_thread2range) {
for (auto [scope, range] : thread2range) {
unbound_thread2range_list.emplace_back(scope, range);
}
return unbound_thread2range_list;
Expand Down Expand Up @@ -582,13 +642,28 @@ class CrossThreadReductionTransformer : public StmtMutator {
std::tie(reducer, combiner_lhs, combiner_rhs) =
GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates);

// Condition 4. All reduction buffers should be all local or all non-local.
int is_local_buf = -1;
Array<Buffer> reduction_buffers;
reduction_buffers.reserve(updates.size());
for (const BufferStore& buf_store : updates) {
reduction_buffers.push_back(buf_store->buffer);
if (buf_store->buffer.scope() == "local") {
CHECK_NE(is_local_buf, 0)
<< "ValueError: Cross-thread reduction requires all reduction buffers to be all "
"local or all non-local. However, here some buffer is local while some buffer is "
"shared or global.";
is_local_buf = 1;
} else {
CHECK_NE(is_local_buf, 1)
<< "ValueError: Cross-thread reduction requires all reduction buffers to be all "
"local or all non-local. However, here some buffer is local while some buffer is "
"shared or global.";
is_local_buf = 0;
}
}

// Condition 4. The block should be the last block under the first reduction-related loop.
// Condition 5. The block should be the last block under the first reduction-related loop.
bool visit = false;
PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
if (const auto* realize = obj.as<BlockRealizeNode>()) {
Expand Down Expand Up @@ -631,8 +706,6 @@ class CrossThreadReductionTransformer : public StmtMutator {
if (scope.rank == 1 && scope.dim_index >= 0) {
is_thread_idx = true;
++thread_idx_depth;
thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
thread_loop_var2scope_[loop->loop_var.get()] = scope;
} else if (scope.rank == 0) {
is_block_idx = true;
++block_idx_depth;
Expand All @@ -649,7 +722,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
--block_idx_depth;
}
if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) {
thread2range_.clear();
crt_buf2threads_.clear();
}

// Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
Expand Down Expand Up @@ -716,6 +789,21 @@ class CrossThreadReductionTransformer : public StmtMutator {
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);

// Step 5. Record the reduction thread dims for the write-back buffers.
// The information is used for consumer block broadcasting detection.
std::vector<std::pair<ThreadScope, Range>> reduction_threads;
reduction_threads.reserve(reduction_loops.size());
for (const ForNode* loop : reduction_loops) {
if (loop->thread_binding.defined()) {
reduction_threads.emplace_back(
ThreadScope::Create(loop->thread_binding.value()->thread_tag),
Range::FromMinExtent(loop->min, loop->extent));
}
}
for (const Buffer& reduction_buf : reduction_buffers) {
crt_buf2threads_[reduction_buf.get()] = reduction_threads;
}
}

Stmt MakeCrossThreadBroadcast(
Expand Down Expand Up @@ -792,8 +880,8 @@ class CrossThreadReductionTransformer : public StmtMutator {

int block_idx_depth = 0;
int thread_idx_depth = 0;
std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range_;
std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
std::unordered_map<const BufferNode*, std::vector<std::pair<ThreadScope, Range>>>
crt_buf2threads_;
};

PrimFunc LowerCrossThreadReduction(PrimFunc f) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,64 @@ def lowered_single_reduction_loop_with_block_predicate(
)


# Fixture: a cross-thread reduction whose block predicate is a conjunction of a
# spatial-only clause (on i_0/i_1) and a reduction-only clause (on k_0/k_1).
# The lowering pass must split the conjunction so the write-back block keeps
# only the spatial clause (plus a thread-0 guard for the reduction thread dim).
@T.prim_func
def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")):
    for i_0 in range(1):
        # Spatial thread dim: 16 threads cover only 2 rows, hence the `< 2` clause.
        for i_1 in T.thread_binding(16, thread="threadIdx.y"):
            for k_0 in range(1):
                # Reduction thread dim: 64 threads cover 32 elements, hence the `< 32` clause.
                for k_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("block"):
                        vi = T.axis.spatial(2, i_0 * 16 + i_1)
                        vk = T.axis.reduce(32, k_0 * 64 + k_1)
                        # Predicate mixing a spatial clause and a reduction clause.
                        T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
                        T.reads(A[vi, vk])
                        T.writes(B[vi])
                        with T.init():
                            B[vi] = T.float32(0)
                        B[vi] = B[vi] + A[vi, vk]


# Expected lowering of `spatial_reduction_loop_predicate`.
# NOTE(review): the name swaps the word order ("reduction_spatial" vs the input's
# "spatial_reduction") — presumably unintentional; harmless but worth confirming.
@T.prim_func
def lowered_reduction_spatial_loop_predicate(
    A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")
):
    # Staging buffers for the in-thread partial sum and the allreduce result.
    cross_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
    in_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
    for i_0 in range(1):
        for i_1 in T.thread_binding(16, thread="threadIdx.y"):
            for k_1 in T.thread_binding(64, thread="threadIdx.x"):
                # Step 1: initialize the per-thread accumulator.
                with T.block("block_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_B[0])
                    in_thread_B[0] = T.float32(0)
                # Step 2: in-thread accumulation keeps the full original predicate.
                for k_0 in range(1):
                    with T.block("block_in_thread"):
                        vi = T.axis.spatial(2, i_0 * 16 + i_1)
                        vk = T.axis.reduce(32, k_0 * 64 + k_1)
                        T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
                        T.reads(A[vi, vk])
                        T.writes(in_thread_B[0])
                        in_thread_B[0] = in_thread_B[0] + A[vi, vk]
                # Step 3: cross-thread allreduce over the reduction thread k_1.
                with T.block("block_cross_thread"):
                    T.reads(in_thread_B[0])
                    T.writes(cross_thread_B[0])
                    T.attr(
                        T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret("handle", T.uint64(0)),
                    )
                    T.tvm_thread_allreduce(
                        T.uint32(1), in_thread_B[0], T.bool(True), cross_thread_B[0], k_1
                    )
                # k_0 is no longer a loop var at this point; declared as a free
                # int32 var (presumably emitted by the pass — TODO confirm).
                k_0 = T.int32()
                # Step 4: write-back keeps only the spatial clause of the original
                # predicate, plus `k_1 == 0` since B is a non-local buffer.
                with T.block("block_write_back"):
                    vi = T.axis.spatial(2, i_0 * 16 + i_1)
                    T.where(i_0 * 16 + i_1 < 2 and k_1 == 0)
                    T.reads(cross_thread_B[0])
                    T.writes(B[vi])
                    B[vi] = cross_thread_B[0]


@T.prim_func
def single_reduction_loop_with_tensorize(
input_A: T.Buffer((1, 64, 7, 7, 32), "uint8"),
Expand Down Expand Up @@ -1315,7 +1373,6 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((
)
with T.block("sum_write_back"):
vi = T.axis.spatial(256, i)
T.where(k == 0)
T.reads(cross_thread_temp_local[0])
T.writes(temp_local[vi])
temp_local[vi] = cross_thread_temp_local[0]
Expand Down Expand Up @@ -1428,7 +1485,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6
with T.block("NT_matmul_write_back"):
v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
v1 = T.axis.spatial(n, ax0_ax1_fused % n)
T.where(ax0_fused == T.int64(0))
T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
T.reads(cross_thread_var_NT_matmul_intermediate_local[0])
T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1])
var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0]
Expand All @@ -1442,6 +1499,72 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6
var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
# fmt: on


# Fixture: a cross-thread reduction whose consumer ("add") writes only a LOCAL
# buffer, so no thread-broadcast rewrite should be triggered for it.
@T.prim_func
def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")):
    temp_1_local = T.alloc_buffer((256,), scope="local")
    temp_2_local = T.alloc_buffer((1,), scope="local")
    for i in T.thread_binding(256, thread="blockIdx.x"):
        # Cross-thread reduction over threadIdx.x into a local buffer.
        for k in T.thread_binding(256, thread="threadIdx.x"):
            with T.block("sum"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads(A[vi, vk])
                T.writes(temp_1_local[vi])
                with T.init():
                    temp_1_local[vi] = T.float32(0)
                temp_1_local[vi] = temp_1_local[vi] + A[vi, vk]
        # Consumer of the reduction result; writes local memory only.
        with T.block("add"):
            vi = T.axis.spatial(256, i)
            T.reads(temp_1_local[vi])
            T.writes(temp_2_local[0])
            temp_2_local[0] = temp_1_local[vi] + T.float32(1)
        # Second consumer: threadIdx.x is bound here, so the thread is not free
        # to this block either. (Block name "sum" is reused — TODO confirm intent.)
        for j in T.thread_binding(256, thread="threadIdx.x"):
            with T.block("sum"):
                vi, vj = T.axis.remap("SR", [i, j])
                T.reads(temp_2_local[0])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + temp_2_local[0]


# Expected lowering of `no_thread_broadcast`: the reduction is lowered to an
# allreduce, but no broadcast blocks are inserted for the consumers.
@T.prim_func
def lowered_no_thread_broadcast(
    A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")
):
    temp_1_local = T.alloc_buffer((256,), scope="local")
    temp_2_local = T.alloc_buffer((1,), scope="local")
    cross_thread_temp_1_local = T.alloc_buffer((1,), strides=(1,), scope="local")
    for i in T.thread_binding(256, thread="blockIdx.x"):
        for k in T.thread_binding(256, thread="threadIdx.x"):
            # Cross-thread allreduce over threadIdx.x.
            with T.block("sum_cross_thread"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads(A[vi, vk])
                T.writes(cross_thread_temp_1_local[0])
                T.attr(
                    T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret("handle", T.uint64(0)),
                )
                T.tvm_thread_allreduce(
                    T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_1_local[0], k
                )
            # Write-back targets a LOCAL buffer, so no `k == 0` predicate is added
            # (every thread keeps its own copy of the result).
            with T.block("sum_write_back"):
                vi = T.axis.spatial(256, i)
                T.reads(cross_thread_temp_1_local[0])
                T.writes(temp_1_local[vi])
                temp_1_local[vi] = cross_thread_temp_1_local[0]
        # Consumers are unchanged: no broadcast rewrite expected.
        with T.block("add"):
            vi = T.axis.spatial(256, i)
            T.reads(temp_1_local[vi])
            T.writes(temp_2_local[0])
            temp_2_local[0] = temp_1_local[vi] + T.float32(1)
        for j in T.thread_binding(256, thread="threadIdx.x"):
            with T.block("sum"):
                vi, vj = T.axis.remap("SR", [i, j])
                T.reads(temp_2_local[0])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + temp_2_local[0]


# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg


Expand Down Expand Up @@ -1472,6 +1595,10 @@ def test_single_reduction_loop_with_block_predicate():
)


def test_spatial_reduction_loop_predicate():
    # The write-back block predicate must retain the spatial clause of the
    # original block predicate and add `k_1 == 0` for the reduction thread dim.
    _check(spatial_reduction_loop_predicate, lowered_reduction_spatial_loop_predicate)


def test_single_reduction_loop_with_tensorize():
_check(
single_reduction_loop_with_tensorize,
Expand Down Expand Up @@ -1534,6 +1661,10 @@ def test_thread_broadcast_rewrite_2():
_check(thread_broadcast_2, lowered_thread_broadcast_2)


def test_no_thread_broadcast_rewrite():
    # Consumers either write only local buffers or bind the reduction thread
    # themselves, so no broadcast blocks should be inserted by the pass.
    _check(no_thread_broadcast, lowered_no_thread_broadcast)


def test_lower_te():
a = te.placeholder((32, 2, 2))
k1 = te.reduce_axis((0, 2), "k1")
Expand Down