diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 79dbfb2a022d..413894264ea6 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -426,12 +426,48 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices)); wb_regions.push_back(BufferRegion(wb_buffers[i], region)); } + + // Construct the predicate of the write-back block. It is the conjunction of + // - each predicate clause of the original block which contains spatial loop var, and + // - `t == 0` for each reduction thread dim when the write-back buffer is not local. PrimExpr wb_predicate = const_true(); - for (const ForNode* loop : reduction_loops) { - if (loop->thread_binding.defined()) { - wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0)); + std::unordered_set<const VarNode*> reduction_loop_vars; + reduction_loop_vars.reserve(reduction_loops.size()); + for (const ForNode* reduction_loop : reduction_loops) { + reduction_loop_vars.insert(reduction_loop->loop_var.get()); + } + PostOrderVisit(realize->predicate, [&wb_predicate, &reduction_loop_vars](const ObjectRef& obj) { + if (const auto* and_node = obj.as<AndNode>()) { + Array<PrimExpr> sub_exprs = {and_node->a, and_node->b}; + for (PrimExpr sub_expr : sub_exprs) { + if (sub_expr->IsInstance<AndNode>()) { + continue; + } + bool is_reduction = [sub_expr, &reduction_loop_vars]() { + Array<Var> vars = UndefinedVars(sub_expr); + for (Var var : vars) { + if (reduction_loop_vars.find(var.get()) != reduction_loop_vars.end()) { + return true; + } + } + return false; + }(); + if (!is_reduction) { + wb_predicate = wb_predicate && sub_expr; + } + } + return true; + } + return false; + }); + if (wb_buffers[0].scope() != "local") { + for (const ForNode* loop : reduction_loops) { + if (loop->thread_binding.defined()) { + wb_predicate = wb_predicate && 
(loop->loop_var == IntImm(loop->loop_var->dtype, 0)); + } } } + stmts.push_back(BlockRealize( /*iter_values=*/std::move(bindings), /*predicate=*/wb_predicate, @@ -498,21 +534,45 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Check if the input block needs thread broadcast rewrite. - // One block needs broadcast rewrite when there exists one or more thread - // vars which vars free variables to this block. + // One block needs broadcast rewrite when + // 1. it consumes a buffer produced by cross-thread reduction under + // the same kernel (i.e., same group of blockIdx), + // 2. it writes to non-local memory, + // 3. at least one of the reduction thread vars of the cross-thread reduction + // is free to this block (i.e., not bound to the block). std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast( const BlockRealizeNode* realize) { - std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range = thread2range_; + Block block = realize->block; + + // If the block writes to local memory, no rewrite is needed. + for (BufferRegion write_region : block->writes) { + if (write_region->buffer.scope() == "local") { + return {}; + } + } + + // Find out the reduction threads for the read-buffers which are produced by + // cross-thread reduction. + std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range; + for (BufferRegion read_region : block->reads) { + auto buf_it = crt_buf2threads_.find(read_region->buffer.get()); + if (buf_it == crt_buf2threads_.end()) { + continue; + } + for (auto [scope, range] : buf_it->second) { + thread2range[scope] = range; + } + } + + // Erase those threads which are not free to this block. 
for (const ForNode* loop : loop_stack_) { if (loop->thread_binding.defined()) { ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag); - unbound_thread2range.erase(scope); + thread2range.erase(scope); } } - std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list; - for (auto [scope, range] : unbound_thread2range) { + for (auto [scope, range] : thread2range) { unbound_thread2range_list.emplace_back(scope, range); } return unbound_thread2range_list; @@ -582,13 +642,28 @@ class CrossThreadReductionTransformer : public StmtMutator { std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates); + // Condition 4. All reduction buffers should be all local or all non-local. + int is_local_buf = -1; Array<Buffer> reduction_buffers; reduction_buffers.reserve(updates.size()); for (const BufferStore& buf_store : updates) { reduction_buffers.push_back(buf_store->buffer); + if (buf_store->buffer.scope() == "local") { + CHECK_NE(is_local_buf, 0) + << "ValueError: Cross-thread reduction requires all reduction buffers to be all " "local or all non-local. However, here some buffer is local while some buffer is " "shared or global."; + is_local_buf = 1; + } else { + CHECK_NE(is_local_buf, 1) + << "ValueError: Cross-thread reduction requires all reduction buffers to be all " "local or all non-local. However, here some buffer is local while some buffer is " "shared or global."; + is_local_buf = 0; + } } - // Condition 4. The block should be the last block under the first reduction-related loop. + // Condition 5. The block should be the last block under the first reduction-related loop. 
bool visit = false; PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as<BlockRealizeNode>()) { @@ -631,8 +706,6 @@ class CrossThreadReductionTransformer : public StmtMutator { if (scope.rank == 1 && scope.dim_index >= 0) { is_thread_idx = true; ++thread_idx_depth; - thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent); - thread_loop_var2scope_[loop->loop_var.get()] = scope; } else if (scope.rank == 0) { is_block_idx = true; ++block_idx_depth; @@ -649,7 +722,7 @@ class CrossThreadReductionTransformer : public StmtMutator { --block_idx_depth; } if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) { - thread2range_.clear(); + crt_buf2threads_.clear(); } // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`. @@ -716,6 +789,21 @@ class CrossThreadReductionTransformer : public StmtMutator { loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices, reducer, combiner_rhs, reduction_loops); + + // Step 5. Record the reduction thread dims for the write-back buffers. + // The information is used for consumer block broadcasting detection. 
+ std::vector<std::pair<ThreadScope, Range>> reduction_threads; + reduction_threads.reserve(reduction_loops.size()); + for (const ForNode* loop : reduction_loops) { + if (loop->thread_binding.defined()) { + reduction_threads.emplace_back( + ThreadScope::Create(loop->thread_binding.value()->thread_tag), + Range::FromMinExtent(loop->min, loop->extent)); + } + } + for (const Buffer& reduction_buf : reduction_buffers) { + crt_buf2threads_[reduction_buf.get()] = reduction_threads; + } } Stmt MakeCrossThreadBroadcast( @@ -792,8 +880,8 @@ class CrossThreadReductionTransformer : public StmtMutator { int block_idx_depth = 0; int thread_idx_depth = 0; - std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range_; - std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_; + std::unordered_map<const BufferNode*, std::vector<std::pair<ThreadScope, Range>>> + crt_buf2threads_; }; PrimFunc LowerCrossThreadReduction(PrimFunc f) { diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 6162233b6583..f42f8ca85f5c 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -496,6 +496,64 @@ def lowered_single_reduction_loop_with_block_predicate( ) +@T.prim_func +def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")): + for i_0 in range(1): + for i_1 in T.thread_binding(16, thread="threadIdx.y"): + for k_0 in range(1): + for k_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("block"): + vi = T.axis.spatial(2, i_0 * 16 + i_1) + vk = T.axis.reduce(32, k_0 * 64 + k_1) + T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32) + T.reads(A[vi, vk]) + T.writes(B[vi]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_reduction_spatial_loop_predicate( + A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32") +): + cross_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local") + in_thread_B = 
T.alloc_buffer((1,), strides=(1,), scope="local") + for i_0 in range(1): + for i_1 in T.thread_binding(16, thread="threadIdx.y"): + for k_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("block_in_thread_init"): + T.reads() + T.writes(in_thread_B[0]) + in_thread_B[0] = T.float32(0) + for k_0 in range(1): + with T.block("block_in_thread"): + vi = T.axis.spatial(2, i_0 * 16 + i_1) + vk = T.axis.reduce(32, k_0 * 64 + k_1) + T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32) + T.reads(A[vi, vk]) + T.writes(in_thread_B[0]) + in_thread_B[0] = in_thread_B[0] + A[vi, vk] + with T.block("block_cross_thread"): + T.reads(in_thread_B[0]) + T.writes(cross_thread_B[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce( + T.uint32(1), in_thread_B[0], T.bool(True), cross_thread_B[0], k_1 + ) + k_0 = T.int32() + with T.block("block_write_back"): + vi = T.axis.spatial(2, i_0 * 16 + i_1) + T.where(i_0 * 16 + i_1 < 2 and k_1 == 0) + T.reads(cross_thread_B[0]) + T.writes(B[vi]) + B[vi] = cross_thread_B[0] + + @T.prim_func def single_reduction_loop_with_tensorize( input_A: T.Buffer((1, 64, 7, 7, 32), "uint8"), @@ -1315,7 +1373,6 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer(( ) with T.block("sum_write_back"): vi = T.axis.spatial(256, i) - T.where(k == 0) T.reads(cross_thread_temp_local[0]) T.writes(temp_local[vi]) temp_local[vi] = cross_thread_temp_local[0] @@ -1428,7 +1485,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 with T.block("NT_matmul_write_back"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) - T.where(ax0_fused == T.int64(0)) + T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) T.reads(cross_thread_var_NT_matmul_intermediate_local[0]) 
T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1]) var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0] @@ -1442,6 +1499,72 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1])) # fmt: on + +@T.prim_func +def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): + temp_1_local = T.alloc_buffer((256,), scope="local") + temp_2_local = T.alloc_buffer((1,), scope="local") + for i in T.thread_binding(256, thread="blockIdx.x"): + for k in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads(A[vi, vk]) + T.writes(temp_1_local[vi]) + with T.init(): + temp_1_local[vi] = T.float32(0) + temp_1_local[vi] = temp_1_local[vi] + A[vi, vk] + with T.block("add"): + vi = T.axis.spatial(256, i) + T.reads(temp_1_local[vi]) + T.writes(temp_2_local[0]) + temp_2_local[0] = temp_1_local[vi] + T.float32(1) + for j in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum"): + vi, vj = T.axis.remap("SR", [i, j]) + T.reads(temp_2_local[0]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] + temp_2_local[0] + + +@T.prim_func +def lowered_no_thread_broadcast( + A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32") +): + temp_1_local = T.alloc_buffer((256,), scope="local") + temp_2_local = T.alloc_buffer((1,), scope="local") + cross_thread_temp_1_local = T.alloc_buffer((1,), strides=(1,), scope="local") + for i in T.thread_binding(256, thread="blockIdx.x"): + for k in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum_cross_thread"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads(A[vi, vk]) 
+ T.writes(cross_thread_temp_1_local[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_1_local[0], k + ) + with T.block("sum_write_back"): + vi = T.axis.spatial(256, i) + T.reads(cross_thread_temp_1_local[0]) + T.writes(temp_1_local[vi]) + temp_1_local[vi] = cross_thread_temp_1_local[0] + with T.block("add"): + vi = T.axis.spatial(256, i) + T.reads(temp_1_local[vi]) + T.writes(temp_2_local[0]) + temp_2_local[0] = temp_1_local[vi] + T.float32(1) + for j in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum"): + vi, vj = T.axis.remap("SR", [i, j]) + T.reads(temp_2_local[0]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] + temp_2_local[0] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1472,6 +1595,10 @@ def test_single_reduction_loop_with_block_predicate(): ) +def test_spatial_reduction_loop_predicate(): + _check(spatial_reduction_loop_predicate, lowered_reduction_spatial_loop_predicate) + + def test_single_reduction_loop_with_tensorize(): _check( single_reduction_loop_with_tensorize, @@ -1534,6 +1661,10 @@ def test_thread_broadcast_rewrite_2(): _check(thread_broadcast_2, lowered_thread_broadcast_2) +def test_no_thread_broadcast_rewrite(): + _check(no_thread_broadcast, lowered_no_thread_broadcast) + + def test_lower_te(): a = te.placeholder((32, 2, 2)) k1 = te.reduce_axis((0, 2), "k1")