From 460736cf3a5414b2da49cbf30d9500e54fb9ebf1 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sat, 1 Jul 2023 21:20:11 -0400
Subject: [PATCH] [TIR] Support cross-thread reduction lowering with
 thread-broadcasting rewrite

This PR enhances the LowerCrossThreadReduction pass with a
thread-broadcasting block rewrite. Previously, whenever a TIR block had
thread-broadcast behavior (i.e., there exists some thread var which is a
free variable of the block), we never inserted a predicate for the block,
and therefore the generated code had a race condition, which sometimes
led to wrong computation results. This PR enhances the pass by collecting
thread-var information during the transformation and rewriting the
thread-broadcast TIR blocks with additional predicate clauses that bound
the thread vars and effectively state "only execute the block when
`thread_var == 0`". The race condition in such blocks is thereby resolved.
---
 .../lower_cross_thread_reduction.cc           | 158 ++++++++++++++--
 ..._transform_lower_cross_thread_reduction.py | 174 ++++++++++++++++++
 2 files changed, 319 insertions(+), 13 deletions(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index cc402017e6bc..79dbfb2a022d 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -25,12 +25,31 @@
 #include 
 #include 
 
+#include "../../runtime/thread_storage_scope.h"
+#include "../../support/utils.h"
 #include "../schedule/analysis.h"
 #include "./ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::ThreadScope;
+using support::StartsWith;
+
+// Implement a hash and an equality function for ThreadScope so that
+// ThreadScope can serve as a map key
+struct ThreadScopeHash {
+  size_t operator()(const ThreadScope& scope) const {
+    return static_cast<size_t>(scope.rank * 30 + scope.dim_index);
+  }
+};
+
+struct ThreadScopeEqual {
+  bool operator()(const ThreadScope& a, const ThreadScope& b) const {
+    return a.rank == b.rank && a.dim_index == b.dim_index;
+  }
+};
+
 /*!
  * \brief Checks if a loop is bound to threadIdx.x/y/z
  * \brief loop The loop to be checked
@@ -478,6 +497,27 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
+  // Check if the input block needs a thread-broadcast rewrite.
+  // A block needs the broadcast rewrite when there exist one or more thread
+  // vars which are free variables to this block.
+  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+      const BlockRealizeNode* realize) {
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
+        thread2range_;
+    for (const ForNode* loop : loop_stack_) {
+      if (loop->thread_binding.defined()) {
+        ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+        unbound_thread2range.erase(scope);
+      }
+    }
+
+    std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
+    for (auto [scope, range] : unbound_thread2range) {
+      unbound_thread2range_list.emplace_back(scope, range);
+    }
+    return unbound_thread2range_list;
+  }
+
   /*!
* \brief Given that the input block needs cross-thread reduction, check if cross-thread reduction * can be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread @@ -578,9 +618,39 @@ class CrossThreadReductionTransformer : public StmtMutator { Stmt VisitStmt_(const ForNode* loop) final { loop_stack_.push_back(loop); loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + + // Collect loop-thread information: + // - when encountering a threadIdx loop, we keep note of its domain and + // the "loop var -> thread scope" relation, in order to collect all existing + // threads within a thread block. + // - we are careful about thread block boundary for safety. + bool is_block_idx = false; + bool is_thread_idx = false; + if (loop->kind == ForKind::kThreadBinding) { + ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag); + if (scope.rank == 1 && scope.dim_index >= 0) { + is_thread_idx = true; + ++thread_idx_depth; + thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent); + thread_loop_var2scope_[loop->loop_var.get()] = scope; + } else if (scope.rank == 0) { + is_block_idx = true; + ++block_idx_depth; + } + } + Stmt result = StmtMutator::VisitStmt_(loop); loop_stack_.pop_back(); loop_range_map_.erase(loop->loop_var); + if (is_thread_idx) { + --thread_idx_depth; + } + if (is_block_idx) { + --block_idx_depth; + } + if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) { + thread2range_.clear(); + } // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`. auto it = loop2new_stmt_.find(loop); @@ -613,14 +683,11 @@ class CrossThreadReductionTransformer : public StmtMutator { return std::move(new_block); } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + void MakeCrossThreadReduction(const BlockRealizeNode* realize, + const std::vector reduction_loops) { const BlockNode* block = realize->block.get(); - // Step 1. Check whether cross-thread reduction is needed. If no, skip this block. - std::vector reduction_loops = NeedCrossThreadReduction(realize); - if (reduction_loops.empty()) { - return StmtMutator::VisitStmt_(realize); - } - // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on + + // Step 1. Check whether cross-thread reduction can be applied. If no, throw an exception on // which condition the block violates. int n_bound_reduction_loops = 0; CommReducer reducer{nullptr}; @@ -629,13 +696,13 @@ class CrossThreadReductionTransformer : public StmtMutator { Array wb_indices{nullptr}; std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) = CheckCanApplyCrossThreadReduction(block, reduction_loops); - // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when + // Step 2. Before doing the cross-thread reduction, in-thread reduction is needed when // - not all the reduction-related loops are bound to thread axes, or // - the block-realize has a non-constant-true predicate. bool need_in_thread_reduction = n_bound_reduction_loops < static_cast(reduction_loops.size()) || !is_one(realize->predicate); - // Step 4. Create intermediate buffers, storing them in `ct_buffers` and + // Step 3. Create intermediate buffers, storing them in `ct_buffers` and // `it_buffers`. Let the scope block allocate these new buffers. 
Array& new_buffers = block2new_buffers_[block_stack_.back()]; Array ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); @@ -645,16 +712,76 @@ class CrossThreadReductionTransformer : public StmtMutator { it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false); new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end()); } - // Step 5. Transform. + // Step 4. Transform. loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices, reducer, combiner_rhs, reduction_loops); - // Step 6. Return an empty statement, because the transformation result will be inserted when - // returning to the first reduction-related loop. - return Stmt{nullptr}; + } + + Stmt MakeCrossThreadBroadcast( + const BlockRealizeNode* realize, + const std::vector>& unbound_thread2range) { + // Step 1. Generate loop var for each unbound thread. + // Update the block predicate with clauses of `thread_var == min`. + PrimExpr predicate = realize->predicate; + Array loop_vars; + loop_vars.reserve(unbound_thread2range.size()); + for (auto [scope, range] : unbound_thread2range) { + std::string dim_index(1, static_cast(scope.dim_index + 'x')); + Var loop_var("t" + dim_index, range->min->dtype); + loop_vars.push_back(loop_var); + predicate = (loop_var == range->min) && predicate; + } + + // Step 2. Update the BlockRealize with the new predicate. + ObjectPtr p_realize = make_object(*realize); + p_realize->predicate = std::move(predicate); + + // Step 3. Wrap the updated BlockRealize with the new loops. + Stmt body(p_realize); + for (int i = 0; i < static_cast(unbound_thread2range.size()); ++i) { + std::string dim_index(1, static_cast(unbound_thread2range[i].first.dim_index + 'x')); + body = For( + /*loop_var=*/loop_vars[i], // + /*min=*/unbound_thread2range[i].second->min, // + /*extent=*/unbound_thread2range[i].second->extent, // + /*kind=*/ForKind::kThreadBinding, // + /*body=*/body, // + /*thread_binding=*/ + IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, + "threadIdx." + dim_index)); + } + return body; + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // Part 1. Check if the block needs cross-thread reduction rewrite. + std::vector reduction_loops = NeedCrossThreadReduction(realize); + if (!reduction_loops.empty()) { + // Return an empty statement, because the transformation result will + // be inserted when returning to the first reduction-related loop. + has_cross_thread_reduction_ = true; + MakeCrossThreadReduction(realize, reduction_loops); + return Stmt{nullptr}; + } + + if (!has_cross_thread_reduction_) { + return StmtMutator::VisitStmt_(realize); + } + + // Part 2. Check if the block needs all-thread broadcasting rewrite. + // We only check this when cross-thread reduction was detected. 
+ std::vector> unbound_thread2range = + NeedCrossThreadBroadcast(realize); + if (!unbound_thread2range.empty()) { + return MakeCrossThreadBroadcast(realize, unbound_thread2range); + } + + return StmtMutator::VisitStmt_(realize); } private: + bool has_cross_thread_reduction_ = false; std::vector statement_stack_; std::vector loop_stack_; std::vector block_stack_; @@ -662,6 +789,11 @@ class CrossThreadReductionTransformer : public StmtMutator { std::unordered_map loop2new_stmt_; Map loop_range_map_; arith::Analyzer analyzer_; + + int block_idx_depth = 0; + int thread_idx_depth = 0; + std::unordered_map thread2range_; + std::unordered_map thread_loop_var2scope_; }; PrimFunc LowerCrossThreadReduction(PrimFunc f) { diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 8b5c21224148..2334fe535076 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -1274,6 +1274,172 @@ def lowered_layer_norm_tuple_sum( ] +@T.prim_func +def thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): + temp_local = T.alloc_buffer((256,), scope="local") + for i in T.thread_binding(256, thread="blockIdx.x"): + for k in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads(A[vi, vk]) + T.writes(temp_local[vi]) + with T.init(): + temp_local[vi] = T.float32(0) + temp_local[vi] = temp_local[vi] + A[vi, vk] + with T.block("add"): + vi = T.axis.spatial(256, i) + T.reads(temp_local[vi]) + T.writes(B[vi]) + B[vi] = temp_local[vi] + T.float32(1) + + +@T.prim_func +def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): + temp_local = T.alloc_buffer((256,), scope="local") + cross_thread_temp_local = T.alloc_buffer((1,), strides=(1,), scope="local") + for i in T.thread_binding(256, thread="blockIdx.x"): + for k in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum_cross_thread"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads(A[vi, vk]) + T.writes(cross_thread_temp_local[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_local[0], k + ) + with T.block("sum_write_back"): + vi = T.axis.spatial(256, i) + T.where(k == 0) + T.reads(cross_thread_temp_local[0]) + T.writes(temp_local[vi]) + temp_local[vi] = cross_thread_temp_local[0] + for tx in T.thread_binding(256, thread="threadIdx.x"): + with T.block("add"): + vi = T.axis.spatial(256, i) + T.where(tx == 0) + T.reads(temp_local[vi]) + T.writes(B[vi]) + B[vi] = temp_local[vi] + T.float32(1) + + +# fmt: off +@T.prim_func +def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle): + n = T.int64() + lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = 
T.alloc_buffer((T.int64(256), T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local") + for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("NT_matmul_rf_init"): + vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) + for ax2_fused_0 in range(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0) + T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(128)) + T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1]) + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1] + for ax1_ax2_fused in range(T.int64(1)): + for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) + T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1]) + with T.init(): + var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = T.float16(0) + var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + with T.block("compute"): + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) + T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1]) + T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), v1]) + var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1])) + + +@T.prim_func +def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle): + n = T.int64() + lv1606 = 
T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") + lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) + var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local") + cross_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), "float16", strides=(1,), scope="local") + in_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), "float16", strides=(1,), scope="local") + for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("NT_matmul_rf_init"): + vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) + for ax2_fused_0 in range(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0) + T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(128)) + T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1]) + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1] + for ax1_ax2_fused in range(T.int64(1)): + for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("NT_matmul_in_thread_init"): + T.reads() + T.writes(in_thread_var_NT_matmul_intermediate_local[0]) + in_thread_var_NT_matmul_intermediate_local[0] = T.float16(0) + with T.block("NT_matmul_in_thread"): + vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused) + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) + T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) + T.writes(in_thread_var_NT_matmul_intermediate_local[0]) + in_thread_var_NT_matmul_intermediate_local[0] = in_thread_var_NT_matmul_intermediate_local[0] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + with T.block("NT_matmul_cross_thread"): + T.reads(in_thread_var_NT_matmul_intermediate_local[0]) + T.writes(cross_thread_var_NT_matmul_intermediate_local[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) 
+ T.tvm_thread_allreduce(T.uint32(1), in_thread_var_NT_matmul_intermediate_local[0], T.bool(True), cross_thread_var_NT_matmul_intermediate_local[0], ax0_fused) + with T.block("NT_matmul_write_back"): + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.where(ax0_fused == T.int64(0)) + T.reads(cross_thread_var_NT_matmul_intermediate_local[0]) + T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1]) + var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0] + for tx in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("compute"): + v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) + v1 = T.axis.spatial(n, ax0_ax1_fused % n) + T.where(tx == T.int64(0) and (T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)) + T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1]) + T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), v1]) + var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1])) +# fmt: on + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1358,6 +1524,14 @@ def test_argmin_split_init_update_reordered(): _check(argmin_split_init_update_reordered, lowered_argmin_split_init_update_reordered) +def test_thread_broadcast_rewrite_1(): + _check(thread_broadcast_1, lowered_thread_broadcast_1) + + +def test_thread_broadcast_rewrite_2(): + _check(thread_broadcast_2, lowered_thread_broadcast_2) + + def test_lower_te(): a = te.placeholder((32, 2, 2)) k1 = te.reduce_axis((0, 2), "k1")
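For readers who want to see the new broadcast rewrite end to end, here is a standalone sketch (not part of the patch itself) that applies the pass to the same workload as thread_broadcast_1 above, mirroring the _check pattern used by this test file. It assumes a TVM build that already includes this patch; the names `before` and `mod` are illustrative only.

# Standalone sketch (not part of the patch): apply LowerCrossThreadReduction to
# the thread_broadcast_1 workload above and inspect the rewritten "add" block.
# Assumes a TVM build that already includes this patch; names are illustrative.
import tvm
from tvm.script import tir as T


@T.prim_func
def before(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
    temp_local = T.alloc_buffer((256,), scope="local")
    for i in T.thread_binding(256, thread="blockIdx.x"):
        for k in T.thread_binding(256, thread="threadIdx.x"):
            with T.block("sum"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    temp_local[vi] = T.float32(0)
                temp_local[vi] = temp_local[vi] + A[vi, vk]
            with T.block("add"):
                vi = T.axis.spatial(256, i)
                B[vi] = temp_local[vi] + T.float32(1)


mod = tvm.IRModule({"main": before.with_attr("global_symbol", "main")})
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
# The "sum" block is lowered to a T.tvm_thread_allreduce call, and the "add"
# block is now wrapped in a threadIdx.x loop guarded by T.where(tx == 0), so
# only a single thread performs the write to B[vi] and the race disappears.
print(mod.script())

The structural expectation for this run is spelled out in lowered_thread_broadcast_1 and checked by test_thread_broadcast_rewrite_1 above.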