From 1d96587091053af543831a078e919b5baf3b5b25 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 3 Nov 2023 22:32:55 +0800 Subject: [PATCH 1/2] [TIR] Fix pass RenewDefs error in gather/take case Pervious implementation of RenewDefs pass will fail in the case of the gather/take function. This is because the pass visit and renew the read/write regions twice. This PR fixes it and adds a regression test. --- src/tir/transforms/renew_defs.cc | 10 ++++++--- tests/python/unittest/test_tir_renew_defs.py | 22 +++++++++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index fd2c27dcd154..8a122f892204 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -120,9 +120,11 @@ class RenewDefMutator : public StmtExprMutator { std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); + Optional init = NullOpt; + if (op->init.defined()) { + init = this->VisitStmt(op->init.value()); + } + Stmt body = this->VisitStmt(op->body); // Step 4. Revisit access region Array reads = @@ -137,6 +139,8 @@ class RenewDefMutator : public StmtExprMutator { n->match_buffers = std::move(match_buffers); n->reads = std::move(reads); n->writes = std::move(writes); + n->body = std::move(body); + n->init = std::move(init); return Stmt(n); } diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py index 3f286a241cfc..22f7b65ca17b 100644 --- a/tests/python/unittest/test_tir_renew_defs.py +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -15,9 +15,6 @@ # specific language governing permissions and limitations # under the License. -import sys - -import pytest import tvm import tvm.testing from tvm.script import tir as T @@ -187,5 +184,24 @@ def main(a: T.handle, b: T.handle): assert f1.buffer_map[f1.params[1]].shape[0] != f2.buffer_map[f2.params[1]].shape[0] +def test_gather(): + @T.prim_func(private=True) + def take( + A: T.Buffer((4096, 4096), "float16"), + B: T.Buffer((1,), "int32"), + T_take: T.Buffer((1, 4096), "float16"), + ): + for ax0, ax1 in T.grid(1, 4096): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[B[v_ax0], v_ax1], B[v_ax0]) + T.writes(T_take[v_ax0, v_ax1]) + T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] + + f1 = take + f2 = tvm.tir.stmt_functor.renew_defs(take) + tvm.ir.assert_structural_equal(f1, f2) + + if __name__ == "__main__": tvm.testing.main() From d0ea27fa5d6e585f6e94342d187294277ee9caf8 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 3 Nov 2023 22:35:10 +0800 Subject: [PATCH 2/2] [Unity] Fix FuseTIR pass for gather/take cases The current implementation of FuseTIR pass does not handle the buffer access region of the blocks, which may fail when the function is in gather or take pattern. This PR fixes the issue. --- src/relax/transform/fuse_tir.cc | 28 ++++-- tests/python/relax/test_transform_fuse_tir.py | 90 ++++++++++++++++++- 2 files changed, 112 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 2fb3f1d8cee4..df3c85c05ce1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -202,21 +202,27 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { const Buffer& src_buffer = SubstituteBuffer(match_buffer->source->buffer); const Buffer& tgt_buffer = SubstituteAllocatedBuffer(match_buffer->buffer); + Region region = MutateRegion(match_buffer->source->region); if (src_buffer.same_as(match_buffer->source->buffer) && - tgt_buffer.same_as(match_buffer->buffer)) { + tgt_buffer.same_as(match_buffer->buffer) && + region.same_as(match_buffer->source->region)) { return match_buffer; } else { auto n = make_object(*match_buffer.get()); n->buffer = tgt_buffer; - n->source = BufferRegion(src_buffer, match_buffer->source->region); + n->source = BufferRegion(src_buffer, region); return MatchBufferRegion(n); } }; auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { - auto it = buffer_remap_.find(buffer_region->buffer); - return it == buffer_remap_.end() ? buffer_region - : BufferRegion((*it).second, buffer_region->region); + const Buffer& buffer = SubstituteBuffer(buffer_region->buffer); + const Region& region = MutateRegion(buffer_region->region); + if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(buffer, region); + } }; // Step 1. Mutate `match_buffers`. @@ -285,6 +291,18 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { return buffer; } } + + inline Region MutateRegion(const Region& region) { + return MutateArray(region, [this](const Range& range) { + const PrimExpr& min = this->VisitExpr(range->min); + const PrimExpr& extent = this->VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(min, extent); + } + }); + } }; /*! \brief A mutator which detect block name duplication and deduplicate the names. */ diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 556b673e61e9..dc2421d64afe 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1538,7 +1538,6 @@ def sum_1d( Y: T.Buffer([T.int64(1)], "float32"), num_elements: T.int64, ): - X = T.match_buffer(X_handle, [num_elements], "float32") for i in range(num_elements): @@ -1603,5 +1602,94 @@ def main( _check(Before, Expected) +def test_gather(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + ): + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Out[vi, vj] = A[vi, vj] + T.float16(1.0) + + @T.prim_func(private=True) + def take( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + B: T.Buffer((T.int64(1),), "int32"), + T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), + ): + for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] + + @R.function + def main( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = Before + with R.dataflow(): + gv: R.Tensor((1, 4096), dtype="float16") = cls.fused_func(input_ids, input_embeds) + R.output(gv) + return gv + + @R.function(private=True) + def fused_func( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + R.func_attr({"Primitive": 1}) + cls = Before + with R.dataflow(): + lv = R.call_tir( + cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + ) + gv = R.call_tir( + cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def fused_func( + input_ids: T.Buffer((T.int64(1),), "int32"), + input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + Out_handle_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Out_handle_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T_take[v_ax0, v_ax1] = Out_handle_intermediate[input_ids[v_ax0], v_ax1] + + @R.function + def main( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = After + with R.dataflow(): + gv = R.call_tir( + cls.fused_func, + (input_ids, input_embeds), + out_sinfo=R.Tensor((1, 4096), dtype="float16"), + ) + R.output(gv) + return gv + + _check(Before, After) + + if __name__ == "__main__": tvm.testing.main()