From f8c213211f194c014ffd7c69b644ae1b2875a9e9 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 28 Apr 2024 04:47:45 +0000 Subject: [PATCH] make compute-ated block simple when the predicate could be merged as static loop domain --- src/tir/schedule/primitive/compute_at.cc | 4 + .../schedule/primitive/decompose_padding.cc | 9 --- .../test_tir_schedule_compute_at.py | 74 +++++++++++++++++++ 3 files changed, 78 insertions(+), 9 deletions(-) diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index fc388b004843..56d85318d7bc 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -224,6 +224,10 @@ struct BlockVarDomainInfo { analyzer->CanProveEqual(bound.max(), intersect.max())) { dom = bound; bound = arith::IntSet::Nothing(); + } else if (is_const_int(intersect.min()) && is_const_int(intersect.max())) { + // if the bound induces a constant iter range, merge the bound into the loop domain + dom = intersect; + bound = arith::IntSet::Nothing(); } } }; diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 50b978f0127b..299bc9a62d5a 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -393,15 +393,6 @@ class DecomposePaddingBlockReplacer : public StmtMutator { return std::move(new_loop); } - Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array<Stmt> new_stmts; - new_stmts.reserve(seq->seq.size()); - for (const Stmt& old_stmt : seq->seq) { - new_stmts.push_back(VisitStmt(old_stmt)); - } - return SeqStmt::Flatten(new_stmts); - } - private: const ReplaceDesc& desc_; }; diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_at.py b/tests/python/tir-schedule/test_tir_schedule_compute_at.py index 963d9586bcaa..2c44c9b29569 100644 --- a/tests/python/tir-schedule/test_tir_schedule_compute_at.py +++ b/tests/python/tir-schedule/test_tir_schedule_compute_at.py @@ -1915,5 
+1915,79 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): ) +def test_compute_at_sliced_concatenate(): + @T.prim_func + def before(): + X = T.alloc_buffer((1, 16, 28, 64), "float32") + Y = T.alloc_buffer((1, 32, 28, 64), "float32") + Z = T.alloc_buffer((1, 53, 28, 64), "float32") + Concat = T.alloc_buffer((1, 101, 28, 64), "float32") + Slice = T.alloc_buffer((1, 87, 28, 64), "float32") + for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 64): + with T.block("compute"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + X[v_ax0, v_ax1, v_ax2, v_ax3] = 1.0 + for ax0, ax1, ax2, ax3 in T.grid(1, 101, 28, 64): + with T.block("T_concat"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else( + 85 <= v_ax1, + X[v_ax0, v_ax1 - 85, v_ax2, v_ax3], + T.if_then_else( + 53 <= v_ax1, + Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3], + Z[v_ax0, v_ax1, v_ax2, v_ax3], + ), + ) + for ax0, ax1, ax2, ax3 in T.grid(1, 87, 28, 64): + with T.block("T_strided_slice"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3] + + @T.prim_func + def expect(): + X = T.alloc_buffer((1, 16, 28, 64)) + Y = T.alloc_buffer((1, 32, 28, 64)) + Z = T.alloc_buffer((1, 53, 28, 64)) + Concat = T.alloc_buffer((1, 101, 28, 64)) + Slice = T.alloc_buffer((1, 87, 28, 64)) + for ax0 in range(1): + for ax0_1, ax1, ax2 in T.grid(2, 28, 64): + with T.block("compute"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(16, ax0_1) + v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2]) + X[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(1) + for ax0_1, ax1, ax2 in T.grid(87, 28, 64): + with T.block("T_concat"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(101, ax0_1) + v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2]) + Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else( + 85 <= v_ax1, + X[v_ax0, v_ax1 - 85, v_ax2, 
v_ax3], + T.if_then_else( + 53 <= v_ax1, + Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3], + Z[v_ax0, v_ax1, v_ax2, v_ax3], + ), + ) + for ax1, ax2, ax3 in T.grid(87, 28, 64): + with T.block("T_strided_slice"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3] + + sch = tir.Schedule(before, debug_mask="all") + blk1 = sch.get_block("compute") + blk2 = sch.get_block("T_concat") + blk3 = sch.get_block("T_strided_slice") + loop = sch.get_loops(blk3)[0] + sch.compute_at(blk2, loop) + sch.compute_at(blk1, loop) + after = sch.mod["main"] + assert_structural_equal_ignore_global_symbol(expect, after) + verify_trace_roundtrip(sch=sch, mod=before) + + if __name__ == "__main__": tvm.testing.main()