diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 2fb3f1d8cee4..df3c85c05ce1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -202,21 +202,27 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { const Buffer& src_buffer = SubstituteBuffer(match_buffer->source->buffer); const Buffer& tgt_buffer = SubstituteAllocatedBuffer(match_buffer->buffer); + Region region = MutateRegion(match_buffer->source->region); if (src_buffer.same_as(match_buffer->source->buffer) && - tgt_buffer.same_as(match_buffer->buffer)) { + tgt_buffer.same_as(match_buffer->buffer) && + region.same_as(match_buffer->source->region)) { return match_buffer; } else { auto n = make_object<MatchBufferRegionNode>(*match_buffer.get()); n->buffer = tgt_buffer; - n->source = BufferRegion(src_buffer, match_buffer->source->region); + n->source = BufferRegion(src_buffer, region); return MatchBufferRegion(n); } }; auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { - auto it = buffer_remap_.find(buffer_region->buffer); - return it == buffer_remap_.end() ? buffer_region - : BufferRegion((*it).second, buffer_region->region); + const Buffer& buffer = SubstituteBuffer(buffer_region->buffer); + const Region& region = MutateRegion(buffer_region->region); + if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(buffer, region); + } }; // Step 1. Mutate `match_buffers`. 
@@ -285,6 +291,18 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { return buffer; } } + + inline Region MutateRegion(const Region& region) { + return MutateArray(region, [this](const Range& range) { + const PrimExpr& min = this->VisitExpr(range->min); + const PrimExpr& extent = this->VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(min, extent); + } + }); + } }; /*! \brief A mutator which detect block name duplication and deduplicate the names. */ diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index fd2c27dcd154..8a122f892204 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -120,9 +120,11 @@ class RenewDefMutator : public StmtExprMutator { std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as<BlockNode>(); - ICHECK(op); + Optional<Stmt> init = NullOpt; + if (op->init.defined()) { + init = this->VisitStmt(op->init.value()); + } + Stmt body = this->VisitStmt(op->body); // Step 4. 
Revisit access region Array<BufferRegion> reads = @@ -137,6 +139,8 @@ class RenewDefMutator : public StmtExprMutator { n->match_buffers = std::move(match_buffers); n->reads = std::move(reads); n->writes = std::move(writes); + n->body = std::move(body); + n->init = std::move(init); return Stmt(n); } diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 556b673e61e9..dc2421d64afe 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1538,7 +1538,6 @@ def sum_1d( Y: T.Buffer([T.int64(1)], "float32"), num_elements: T.int64, ): - X = T.match_buffer(X_handle, [num_elements], "float32") for i in range(num_elements): @@ -1603,5 +1602,94 @@ def main( _check(Before, Expected) +def test_gather(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + ): + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Out[vi, vj] = A[vi, vj] + T.float16(1.0) + + @T.prim_func(private=True) + def take( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + B: T.Buffer((T.int64(1),), "int32"), + T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), + ): + for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] + + @R.function + def main( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = Before + with R.dataflow(): + gv: R.Tensor((1, 4096), dtype="float16") = cls.fused_func(input_ids, input_embeds) + R.output(gv) + return gv + + @R.function(private=True) + def fused_func( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + 
) -> R.Tensor((1, 4096), dtype="float16"): + R.func_attr({"Primitive": 1}) + cls = Before + with R.dataflow(): + lv = R.call_tir( + cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + ) + gv = R.call_tir( + cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def fused_func( + input_ids: T.Buffer((T.int64(1),), "int32"), + input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + Out_handle_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Out_handle_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1) + for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T_take[v_ax0, v_ax1] = Out_handle_intermediate[input_ids[v_ax0], v_ax1] + + @R.function + def main( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = After + with R.dataflow(): + gv = R.call_tir( + cls.fused_func, + (input_ids, input_embeds), + out_sinfo=R.Tensor((1, 4096), dtype="float16"), + ) + R.output(gv) + return gv + + _check(Before, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py index 3f286a241cfc..22f7b65ca17b 100644 --- a/tests/python/unittest/test_tir_renew_defs.py +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -15,9 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import sys - -import pytest import tvm import tvm.testing from tvm.script import tir as T @@ -187,5 +184,24 @@ def main(a: T.handle, b: T.handle): assert f1.buffer_map[f1.params[1]].shape[0] != f2.buffer_map[f2.params[1]].shape[0] +def test_gather(): + @T.prim_func(private=True) + def take( + A: T.Buffer((4096, 4096), "float16"), + B: T.Buffer((1,), "int32"), + T_take: T.Buffer((1, 4096), "float16"), + ): + for ax0, ax1 in T.grid(1, 4096): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[B[v_ax0], v_ax1], B[v_ax0]) + T.writes(T_take[v_ax0, v_ax1]) + T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] + + f1 = take + f2 = tvm.tir.stmt_functor.renew_defs(take) + tvm.ir.assert_structural_equal(f1, f2) + + if __name__ == "__main__": tvm.testing.main()