From ad6c808f85d0f81bef2754bda63363640b29bab4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Jul 2024 10:42:47 -0500 Subject: [PATCH 1/3] [TIR] Validate tir::Buffer axis_separators on construction Prior to this commit, the `axis_separators` field of a TIR buffer wasn't validated until the `tir.FlattenBuffer` legalization pass. Delaying the error until this point makes it difficult to determine where the invalid `axis_separators` were initially defined. This commit updates the `tir::Buffer` constructor to validate the `axis_separators` field immediately, allowing these invalid values to be caught on construction. Closes https://github.com/apache/tvm/issues/17215 --- src/tir/ir/buffer.cc | 46 ++++++++++++++++-------- tests/python/tir-base/test_tir_buffer.py | 12 +++++-- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 025605333138..a75e9a16dd6f 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -334,24 +334,38 @@ inline Array BufferOffset(const BufferNode* n, Array index, return offsets; } -Buffer Buffer::GetFlattenedBuffer() const { - auto self = operator->(); - +static void ValidateAxisSeparators(const Array& axis_separators, size_t buffer_dim) { // These checks ensure that all output axes contain at least one // input axis. 
- for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) { - auto sep = self->axis_separators[i]->value; - auto next_sep = self->axis_separators[i + 1]->value; - ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order."; - } - if (self->axis_separators.size()) { - auto first_sep = self->axis_separators[0]->value; - ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, " - << "so that first output axis contains at least one input axis"; - auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value; - ICHECK_LT(last_sep, self->shape.size()) - << "Last output axis must contain at least one input axis."; + for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { + auto sep = axis_separators[i]->value; + auto next_sep = axis_separators[i + 1]->value; + CHECK_LT(sep, next_sep) << "ValueError: " + << "Axis separators must be in strictly increasing order, " + << "but axis_separators[" << i << "] = " << sep + << " is greater than or equal to axis_separators[" << (i + 1) + << "] = " << next_sep << "."; + } + if (axis_separators.size()) { + auto first_sep = axis_separators[0]->value; + CHECK_GT(first_sep, 0) << "ValueError: " + << "First axis separator must be strictly greater than 0, " + << "so that first output axis contains at least one input axis. " + << "However, the axis_separators[0] = " << first_sep; + auto last_sep = axis_separators[axis_separators.size() - 1]->value; + CHECK_LT(last_sep, buffer_dim) + << "ValueError: " + << "Last output axis must contain at least one input axis. 
" + << "However, the axis_separators[" << (axis_separators.size() - 1) << "] = " << last_sep + << " does not leave any input axes between it and the buffer's dimensionality " + << buffer_dim; } +} + +Buffer Buffer::GetFlattenedBuffer() const { + auto self = operator->(); + + ValidateAxisSeparators(self->axis_separators, self->shape.size()); Array output_shape; if (self->strides.size()) { @@ -565,6 +579,8 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array ICHECK(data->type_annotation.as()->element_type.as()) << "Variable " << data->name_hint << " does not point to a primitive."; + ValidateAxisSeparators(axis_separators, shape.size()); + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index 1ab7662b0b6b..7bfd2ae4a7c9 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod(): A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): - tvm.ir.assert_structural_equal( - index_simplified, index_direct - ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct) + ( + tvm.ir.assert_structural_equal(index_simplified, index_direct), + "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct), + ) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators(): tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) +def test_invalid_axis_separators_raises_exception(): + with pytest.raises(ValueError): + tvm.tir.decl_buffer([1], axis_separators=[1]) + + if __name__ == "__main__": tvm.testing.main() From 1fc972763ae0897cbb94d529cf07e564a386d4b6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Jul 2024 13:51:17 -0500 Subject: [PATCH 2/3] Update metaschedule primitive to only set 
axis_separators of alloc --- .../primitive/layout_transformation.cc | 1 - .../test_tir_schedule_set_axis_separator.py | 18 ++++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index f1e9106a635b..927733975d80 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1485,7 +1485,6 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (it != buffer_var_map_.end()) { const Buffer& new_source_buffer = it->second; Buffer new_target_buffer = match_buffer->buffer; - new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) { LOG(WARNING) << "Target buffer in match_buffer doesn't have the same dimensionality as its source " diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py index 76a6ade42f50..069a71c4b871 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py @@ -89,17 +89,31 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer @T.prim_func def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: + # The `set_axis_separator` scheduling primitive updates the + # backing allocation. B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1]) for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + # Buffer views do *NOT* receive any `axis_separator` + # annotations. 
Since the dimensions (and even + # dimensionality) of a view may be different than that of + # the backing allocation, only the backing allocation can + # define how logical dimensions are mapped into physical + # dimensions. + # + # When lowering, buffer views are resolved prior to buffer + # flattening. Since the view no longer exists when the + # buffer is flattened, this ensures that all flattening of + # a buffer uses the `axis_separator` field of the backing + # allocation. + B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1) C[vi, vj] = B_subregion1[()] + T.float32(1) From e99278ab437ae30a01426392edb9114c9a4a8fc5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Jul 2024 08:39:15 -0500 Subject: [PATCH 3/3] Allow axis separators to be increasing, rather than strictly increasing --- src/tir/ir/buffer.cc | 19 +++++++++---------- .../primitive/layout_transformation.cc | 14 ++++++++++---- tests/python/tir-base/test_tir_buffer.py | 2 +- .../test_tir_schedule_set_axis_separator.py | 18 ++---------------- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index a75e9a16dd6f..b7c4eb1d42ec 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -340,25 +340,24 @@ static void ValidateAxisSeparators(const Array& axis_separators, size_t for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { auto sep = axis_separators[i]->value; auto next_sep = axis_separators[i + 1]->value; - CHECK_LT(sep, next_sep) << "ValueError: " - << "Axis separators must be in strictly increasing order, " + CHECK_LE(sep, next_sep) << "ValueError: " + << "Axis separators must be in 
increasing order, " << "but axis_separators[" << i << "] = " << sep << " is greater than or equal to axis_separators[" << (i + 1) << "] = " << next_sep << "."; } if (axis_separators.size()) { auto first_sep = axis_separators[0]->value; - CHECK_GT(first_sep, 0) << "ValueError: " - << "First axis separator must be strictly greater than 0, " - << "so that first output axis contains at least one input axis. " + CHECK_GE(first_sep, 0) << "ValueError: " + << "All axis separators must be non-negative. " << "However, the axis_separators[0] = " << first_sep; auto last_sep = axis_separators[axis_separators.size() - 1]->value; - CHECK_LT(last_sep, buffer_dim) + CHECK_LE(last_sep, buffer_dim) << "ValueError: " - << "Last output axis must contain at least one input axis. " - << "However, the axis_separators[" << (axis_separators.size() - 1) << "] = " << last_sep - << " does not leave any input axes between it and the buffer's dimensionality " - << buffer_dim; + << "All axis separators must be within the range " + << "0 <= sep <= buffer_dim. " + << "However, the last axis_separators[" << (axis_separators.size() - 1) + << "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim; } } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 927733975d80..8b95e0dc622f 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1485,10 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (it != buffer_var_map_.end()) { const Buffer& new_source_buffer = it->second; Buffer new_target_buffer = match_buffer->buffer; - if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) { - LOG(WARNING) - << "Target buffer in match_buffer doesn't have the same dimensionality as its source " - "buffer. 
`axis_separators` for the target buffer might be incorrect."; + + if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) { + new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; + } else { + new_target_buffer.CopyOnWrite()->axis_separators = + Array(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); + LOG(WARNING) << "Buffer view " << new_target_buffer + << " has different dimensionality than backing buffer " << new_source_buffer + << ". The `axis_separators` for " << new_target_buffer << "." + << "`axis_separators` for the view might be incorrect."; } buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer; return MatchBufferRegion(new_target_buffer, diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index 7bfd2ae4a7c9..b4b773197b14 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -279,7 +279,7 @@ def test_buffer_flatten_uses_axis_separators(): def test_invalid_axis_separators_raises_exception(): with pytest.raises(ValueError): - tvm.tir.decl_buffer([1], axis_separators=[1]) + tvm.tir.decl_buffer([1], axis_separators=[1, 2]) if __name__ == "__main__": diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py index 069a71c4b871..788e17e77146 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py @@ -89,31 +89,17 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer @T.prim_func def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: - # The `set_axis_separator` scheduling primitive updates the - # backing allocation. 
B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1]) for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - # Buffer views do *NOT* receive any `axis_separator` - # annotations. Since the dimensions (and even - # dimensionality) of a view may be different than that of - # the backing allocation, only the backing allocation can - # define how logical dimensions are mapped into physical - # dimensions. - # - # When lowering, buffer views are resolved prior to buffer - # flattening. Since the view no longer exists when the - # buffer is flattened, this ensures that all flattening of - # a buffer uses the `axis_separator` field of the backing - # allocation. - B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1) + B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1) + B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) C[vi, vj] = B_subregion1[()] + T.float32(1)