From 34a3434fef6343040a3ec417bca21fcc50b0824c Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 9 Mar 2023 12:43:05 -0800 Subject: [PATCH] fix reverse compute inline --- src/tir/schedule/primitive/compute_inline.cc | 4 +++- .../test_tir_schedule_compute_inline.py | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 99286c91b344..ad4aa9ef748e 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -844,9 +844,11 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); - // Step 3. Check if the consumer has a single complete producer + // Step 3. Check if the consumer has a single complete producer, and the producer is not an output + // block StmtSRef producer_block_sref = NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); + CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); // Step 4. 
Analyze the block body ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(), consumer_block_realize, scope_root_sref, self->mod); diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index a4c7344909c5..42eb2b6be4e9 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -503,6 +503,21 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) +@T.prim_func +def elementwise_output(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func def inline_block_with_init( A: T.Buffer((1, 512, 7, 7), "float32"), @@ -1027,6 +1042,15 @@ def test_output_block(use_block_name): with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block) + sch = tir.Schedule(elementwise_output, debug_mask="all") + block = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block) + + block = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(block) + def test_compute_inline_predicate(use_block_name): sch = tir.Schedule(elementwise_predicate, debug_mask="all")