diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index cc402017e6bc..79dbfb2a022d 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -25,12 +25,31 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "../../runtime/thread_storage_scope.h"
+#include "../../support/utils.h"
 #include "../schedule/analysis.h"
 #include "./ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::ThreadScope;
+using support::StartsWith;
+
+// Implement a hash and an equality function for ThreadScope so that
+// ThreadScope can serve as a map key.
+struct ThreadScopeHash {
+  size_t operator()(const ThreadScope& scope) const {
+    return static_cast<size_t>(scope.rank * 30 + scope.dim_index);
+  }
+};
+
+struct ThreadScopeEqual {
+  bool operator()(const ThreadScope& a, const ThreadScope& b) const {
+    return a.rank == b.rank && a.dim_index == b.dim_index;
+  }
+};
+
 /*!
  * \brief Checks if a loop is bound to threadIdx.x/y/z
  * \brief loop The loop to be checked
@@ -478,6 +497,27 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
+  // Check if the input block needs the thread broadcast rewrite.
+  // A block needs the broadcast rewrite when there exist one or more thread
+  // vars that are free variables with respect to this block.
+  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+      const BlockRealizeNode* realize) {
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
+        thread2range_;
+    for (const ForNode* loop : loop_stack_) {
+      if (loop->thread_binding.defined()) {
+        ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+        unbound_thread2range.erase(scope);
+      }
+    }
+
+    std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
+    for (auto [scope, range] : unbound_thread2range) {
+      unbound_thread2range_list.emplace_back(scope, range);
+    }
+    return unbound_thread2range_list;
+  }
+
   /*!
    * \brief Given that the input block needs cross-thread reduction, check if cross-thread reduction
    * can be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
@@ -578,9 +618,39 @@ class CrossThreadReductionTransformer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* loop) final {
     loop_stack_.push_back(loop);
     loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+
+    // Collect loop-thread information:
+    // - when encountering a threadIdx loop, we take note of its domain and of
+    //   the "loop var -> thread scope" relation, in order to collect all the
+    //   existing threads within a thread block.
+    // - we are careful about thread-block boundaries for safety.
+    bool is_block_idx = false;
+    bool is_thread_idx = false;
+    if (loop->kind == ForKind::kThreadBinding) {
+      ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+      if (scope.rank == 1 && scope.dim_index >= 0) {
+        is_thread_idx = true;
+        ++thread_idx_depth;
+        thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
+        thread_loop_var2scope_[loop->loop_var.get()] = scope;
+      } else if (scope.rank == 0) {
+        is_block_idx = true;
+        ++block_idx_depth;
+      }
+    }
+
     Stmt result = StmtMutator::VisitStmt_(loop);
     loop_stack_.pop_back();
     loop_range_map_.erase(loop->loop_var);
+    if (is_thread_idx) {
+      --thread_idx_depth;
+    }
+    if (is_block_idx) {
+      --block_idx_depth;
+    }
+    if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) {
+      thread2range_.clear();
+    }
 
     // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
     auto it = loop2new_stmt_.find(loop);
@@ -613,14 +683,11 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return std::move(new_block);
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+  void MakeCrossThreadReduction(const BlockRealizeNode* realize,
+                                const std::vector<const ForNode*> reduction_loops) {
     const BlockNode* block = realize->block.get();
-    // Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
-    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
-    if (reduction_loops.empty()) {
-      return StmtMutator::VisitStmt_(realize);
-    }
-    // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on
+
+    // Step 1. Check whether cross-thread reduction can be applied. If no, throw an exception on
     // which condition the block violates.
     int n_bound_reduction_loops = 0;
     CommReducer reducer{nullptr};
@@ -629,13 +696,13 @@ class CrossThreadReductionTransformer : public StmtMutator {
     Array<PrimExpr> wb_indices{nullptr};
     std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) =
         CheckCanApplyCrossThreadReduction(block, reduction_loops);
-    // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
+    // Step 2. Before doing the cross-thread reduction, in-thread reduction is needed when
     //  - not all the reduction-related loops are bound to thread axes, or
     //  - the block-realize has a non-constant-true predicate.
     bool need_in_thread_reduction =
         n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
        !is_one(realize->predicate);
-    // Step 4. Create intermediate buffers, storing them in `ct_buffers` and
+    // Step 3. Create intermediate buffers, storing them in `ct_buffers` and
     // `it_buffers`. Let the scope block allocate these new buffers.
     Array<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
     Array<Buffer> ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true);
@@ -645,16 +712,76 @@ class CrossThreadReductionTransformer : public StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate a loop var for each unbound thread.
+    // Update the block predicate with clauses of `thread_var == min`.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+      Var loop_var("t" + dim_index, range->min->dtype);
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;
+    }
+
+    // Step 2. Update the BlockRealize with the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new loops.
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //
+          /*min=*/unbound_thread2range[i].second->min,        //
+          /*extent=*/unbound_thread2range[i].second->extent,  //
+          /*kind=*/ForKind::kThreadBinding,                   //
+          /*body=*/body,                                      //
+          /*thread_binding=*/
+          IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+                  "threadIdx." + dim_index));
+    }
+    return body;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    // Part 1. Check if the block needs the cross-thread reduction rewrite.
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (!reduction_loops.empty()) {
+      // Return an empty statement, because the transformation result will
+      // be inserted when returning to the first reduction-related loop.
+      has_cross_thread_reduction_ = true;
+      MakeCrossThreadReduction(realize, reduction_loops);
+      return Stmt{nullptr};
+    }
+
+    if (!has_cross_thread_reduction_) {
+      return StmtMutator::VisitStmt_(realize);
+    }
+
+    // Part 2. Check if the block needs the all-thread broadcasting rewrite.
+    // We only check this when cross-thread reduction was detected earlier.
+    std::vector<std::pair<ThreadScope, Range>> unbound_thread2range =
+        NeedCrossThreadBroadcast(realize);
+    if (!unbound_thread2range.empty()) {
+      return MakeCrossThreadBroadcast(realize, unbound_thread2range);
+    }
+
+    return StmtMutator::VisitStmt_(realize);
   }
 
  private:
+  bool has_cross_thread_reduction_ = false;
   std::vector<const StmtNode*> statement_stack_;
   std::vector<const ForNode*> loop_stack_;
   std::vector<const BlockNode*> block_stack_;
@@ -662,6 +789,11 @@ class CrossThreadReductionTransformer : public StmtMutator {
   std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
   Map<Var, Range> loop_range_map_;
   arith::Analyzer analyzer_;
+
+  int block_idx_depth = 0;
+  int thread_idx_depth = 0;
+  std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range_;
+  std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
 };
 
 PrimFunc LowerCrossThreadReduction(PrimFunc f) {
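For reviewers who want to see the new broadcasting rewrite end to end, the following stand-alone snippet is a minimal sketch; it is not part of the patch, and the `before` function and the invocation below are assumptions that simply mirror the `thread_broadcast_1` test case added further down.

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def before(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
    temp_local = T.alloc_buffer((256,), scope="local")
    for i in T.thread_binding(256, thread="blockIdx.x"):
        for k in T.thread_binding(256, thread="threadIdx.x"):
            with T.block("sum"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    temp_local[vi] = T.float32(0)
                temp_local[vi] = temp_local[vi] + A[vi, vk]
        with T.block("add"):
            # `threadIdx.x` is a free thread var for this block: it consumes the
            # reduction result but is not nested under the threadIdx.x loop.
            vi = T.axis.spatial(256, i)
            B[vi] = temp_local[vi] + T.float32(1)


mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
# After the pass, block "add" is wrapped in a new threadIdx.x-bound loop whose
# predicate pins it to `tx == 0`, matching `lowered_thread_broadcast_1` below.
print(tvm.tir.transform.LowerCrossThreadReduction()(mod).script())
```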
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 8b5c21224148..2334fe535076 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -1274,6 +1274,172 @@ def lowered_layer_norm_tuple_sum(
 ]
 
 
+@T.prim_func
+def thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
+    temp_local = T.alloc_buffer((256,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(temp_local[vi])
+                with T.init():
+                    temp_local[vi] = T.float32(0)
+                temp_local[vi] = temp_local[vi] + A[vi, vk]
+        with T.block("add"):
+            vi = T.axis.spatial(256, i)
+            T.reads(temp_local[vi])
+            T.writes(B[vi])
+            B[vi] = temp_local[vi] + T.float32(1)
+
+
+@T.prim_func
+def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
+    temp_local = T.alloc_buffer((256,), scope="local")
+    cross_thread_temp_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum_cross_thread"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(cross_thread_temp_local[0])
+                T.attr(
+                    T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret("handle", T.uint64(0)),
+                )
+                T.tvm_thread_allreduce(
+                    T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_local[0], k
+                )
+            with T.block("sum_write_back"):
+                vi = T.axis.spatial(256, i)
+                T.where(k == 0)
+                T.reads(cross_thread_temp_local[0])
+                T.writes(temp_local[vi])
+                temp_local[vi] = cross_thread_temp_local[0]
+        for tx in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("add"):
+                vi = T.axis.spatial(256, i)
+                T.where(tx == 0)
+                T.reads(temp_local[vi])
+                T.writes(B[vi])
+                B[vi] = temp_local[vi] + T.float32(1)
+
+
+# fmt: off
+@T.prim_func
+def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
+    n = T.int64()
+    lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+    lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
+    var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
+    var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"):
+        for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+            with T.block("NT_matmul_rf_init"):
+                vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads()
+                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0)
+            for ax2_fused_0 in range(T.int64(1)):
+                with T.block("NT_matmul_rf_update"):
+                    vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0)
+                    T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(128))
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1])
+                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                    var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1]
+        for ax1_ax2_fused in range(T.int64(1)):
+            for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                with T.block("NT_matmul"):
+                    vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                    T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1])
+                    with T.init():
+                        var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
+                    var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]
+        with T.block("compute"):
+            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+            v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+            T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+            T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1])
+            T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), v1])
+            var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
+
+
+@T.prim_func
+def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
+    n = T.int64()
+    lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+    lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
+    var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
+    var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    cross_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), "float16", strides=(1,), scope="local")
+    in_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), "float16", strides=(1,), scope="local")
+    for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"):
+        for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+            with T.block("NT_matmul_rf_init"):
+                vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads()
+                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0)
+            for ax2_fused_0 in range(T.int64(1)):
+                with T.block("NT_matmul_rf_update"):
+                    vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0)
+                    T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(128))
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1])
+                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                    var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1]
+        for ax1_ax2_fused in range(T.int64(1)):
+            for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                with T.block("NT_matmul_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_var_NT_matmul_intermediate_local[0])
+                    in_thread_var_NT_matmul_intermediate_local[0] = T.float16(0)
+                with T.block("NT_matmul_in_thread"):
+                    vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                    T.writes(in_thread_var_NT_matmul_intermediate_local[0])
+                    in_thread_var_NT_matmul_intermediate_local[0] = in_thread_var_NT_matmul_intermediate_local[0] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]
+                with T.block("NT_matmul_cross_thread"):
+                    T.reads(in_thread_var_NT_matmul_intermediate_local[0])
+                    T.writes(cross_thread_var_NT_matmul_intermediate_local[0])
+                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
+                    T.tvm_thread_allreduce(T.uint32(1), in_thread_var_NT_matmul_intermediate_local[0], T.bool(True), cross_thread_var_NT_matmul_intermediate_local[0], ax0_fused)
+                with T.block("NT_matmul_write_back"):
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(ax0_fused == T.int64(0))
+                    T.reads(cross_thread_var_NT_matmul_intermediate_local[0])
+                    T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1])
+                    var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0]
+        for tx in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+            with T.block("compute"):
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.where(tx == T.int64(0) and (T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n))
+                T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1])
+                T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), v1])
+                var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
+# fmt: on
+
 
 
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -1358,6 +1524,14 @@ def test_argmin_split_init_update_reordered():
     _check(argmin_split_init_update_reordered, lowered_argmin_split_init_update_reordered)
 
 
+def test_thread_broadcast_rewrite_1():
+    _check(thread_broadcast_1, lowered_thread_broadcast_1)
+
+
+def test_thread_broadcast_rewrite_2():
+    _check(thread_broadcast_2, lowered_thread_broadcast_2)
+
+
 def test_lower_te():
     a = te.placeholder((32, 2, 2))
     k1 = te.reduce_axis((0, 2), "k1")
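The `_check` helper used by the two new tests is defined earlier in this test file and is unchanged by this patch. As a rough sketch of what it presumably does (the name `_check_sketch` and the exact calls below are illustrative assumptions, not the actual helper), it applies the pass under test and asserts structural equality:

```python
import tvm


def _check_sketch(before, expected):
    # Assumed behavior of the existing `_check` helper: lower `before` with the
    # pass under test and compare the result against the expected PrimFunc.
    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
    after = tvm.tir.transform.LowerCrossThreadReduction()(mod)["main"]
    tvm.ir.assert_structural_equal(
        after, expected.with_attr("global_symbol", "main"), map_free_vars=True
    )
```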