From 6402e0312c3164576a2c268889a2b3bc01ab7740 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 May 2023 14:30:52 -0500 Subject: [PATCH] [TIR] Preserve object equality in Buffer::GetFlattenedBuffer If buffer is already flat, then `Buffer::GetFlattenedBuffer()` should return the same object. --- src/tir/ir/buffer.cc | 17 ++++++++++------- tests/python/unittest/test_tir_buffer.py | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index c2e6fad42dce..d71187922874 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -387,13 +387,16 @@ Buffer Buffer::GetFlattenedBuffer() const { output_axis_separators.push_back(IntImm(dtype, i + 1)); } - Buffer output = *this; - auto writer = output.CopyOnWrite(); - writer->shape = output_shape; - writer->axis_separators = output_axis_separators; - writer->strides = {}; - - return output; + if (output_shape.size() == self->shape.size() && self->strides.empty()) { + return *this; + } else { + Buffer output = *this; + auto writer = output.CopyOnWrite(); + writer->shape = output_shape; + writer->axis_separators = output_axis_separators; + writer->strides = {}; + return output; + } } PrimExpr Buffer::vload(Array begin, DataType value_dtype) const { diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index e3b63d931506..78185510fbab 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -253,5 +253,28 @@ def check_auto_bind(): check_auto_bind() +def test_buffer_flatten(): + """A buffer should flatten to a 1-d shape""" + buf = tvm.tir.decl_buffer([16, 32]) + flat = buf.get_flattened_buffer() + assert buf.data.same_as(flat.data) + tvm.ir.assert_structural_equal(flat.shape, [16 * 32]) + + +def test_buffer_flatten_preserves_identity(): + """Flattening a 1-d buffer should return the original""" + buf = tvm.tir.decl_buffer([16]) + flat = buf.get_flattened_buffer() + assert buf.same_as(flat) + + +def test_buffer_flatten_uses_axis_separators(): + """Flattening to N-d physical buffers uses the axis separators""" + buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) + flat = buf.get_flattened_buffer() + tvm.ir.assert_structural_equal(flat.axis_separators, [1]) + tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) + + if __name__ == "__main__": tvm.testing.main()