From 7ff7811c0360cbab0b44bc41106d62091180dcc7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Sep 2022 09:08:31 -0500 Subject: [PATCH] [TE] Raise error for non-bijective transformation This is a fix for a bug introduced in https://github.com/apache/tvm/pull/12904. Prior to then, an exception was raised when the transformation wouldn't be bijective over the transformed buffer's shape. The PR replaced the bijective check done as part of `DetectIterMap` with a check done on the returned `padding_predicate`. However, this check was not equivalent, and some transformations could erroneously apply, rather than raising an exception as being non-bijective. This commit re-enables the bijectivity check in `DetectIterMap`, and adds a test case for this behavior. --- src/arith/iter_affine_map.cc | 4 ++- src/tir/ir/index_map.cc | 33 +++++++++++-------- .../python/unittest/test_transform_layout.py | 13 ++++++++ 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 182eada24d96..d41db2ff135e 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1739,7 +1739,9 @@ class IterMapToExprNormalizer : public ExprMutator { bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); - if (clhs && crhs) { + if (crhs && crhs->value == 0) { + return false; + } else if (clhs && crhs) { return clhs->value % crhs->value == 0; } diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 2c5349ab9941..44c35e63ad99 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -53,19 +53,21 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(A return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map)); } -std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges) const { - if ((*this)->inverse_index_map.defined()) { +std::pair IndexMapInverseImpl(const IndexMap& self, + const Array& initial_ranges, + arith::IterMapLevel check_level) { + if (self->inverse_index_map.defined()) { // return the pre-defined inverse index map if exists. In this // case, the user-defined inverse is assumed to be correct and // bijective. PrimExpr padding_predicate = Bool(false); - return {Downcast((*this)->inverse_index_map.value()), padding_predicate}; + return {Downcast(self->inverse_index_map.value()), padding_predicate}; } // Dummy variables to represent the inverse's inputs. Array output_vars; - for (size_t i = 0; i < (*this)->final_indices.size(); i++) { - PrimExpr index = (*this)->final_indices[i]; + for (size_t i = 0; i < self->final_indices.size(); i++) { + PrimExpr index = self->final_indices[i]; // TODO(Lunderberg): Better names for these variables. A variable // that is passed through unmodified (`index` is an element of // `initial_indices`) should use that input index's name. A pair @@ -79,16 +81,16 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Dummy ranges for the extent of each input. Map input_iters; - ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size()); + ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); for (size_t i = 0; i < initial_ranges.size(); i++) { - input_iters.Set((*this)->initial_indices[i], initial_ranges[i]); + input_iters.Set(self->initial_indices[i], initial_ranges[i]); } // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /*check_level=*/arith::IterMapLevel::NoCheck, &analyzer, + auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/check_level, &analyzer, /*simplify_trivial_iterators=*/false); CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " << "Error: " << padded_iter_map->errors[0]; @@ -100,8 +102,8 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Unpack the map to an array, maintaining the same parameter order. Array inverse_exprs; - for (int i = 0, n = (*this)->initial_indices.size(); i < n; ++i) { - Var index = (*this)->initial_indices[i]; + for (int i = 0, n = self->initial_indices.size(); i < n; ++i) { + Var index = self->initial_indices[i]; PrimExpr expr; if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) { expr = initial_ranges[i]->min; @@ -116,7 +118,7 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia padding_predicate = Substitute(padding_predicate, inverse_exprs_map); { - auto output_ranges = (*this)->MapRanges(initial_ranges); + auto output_ranges = self->MapRanges(initial_ranges); ICHECK_EQ(output_ranges.size(), output_vars.size()); arith::Analyzer analyzer; @@ -131,8 +133,13 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia return {IndexMap(output_vars, inverse_exprs), padding_predicate}; } +std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges) const { + return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck); +} + IndexMap IndexMap::Inverse(Array initial_ranges) const { - auto [inverse, padding_predicate] = NonSurjectiveInverse(std::move(initial_ranges)); + auto [inverse, padding_predicate] = + IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective); arith::Analyzer analyzer; CHECK(analyzer.CanProve(!padding_predicate)) << "Bijective inverse should not contain padding, but inverse of " << *this << " over range " diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 18b37741765f..375fe4a24d57 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -575,5 +575,18 @@ def test_size_one_buffer(shape, transform): s[B].transform_layout(transform) +def test_non_divisible_transform_raises_error(): + A = te.placeholder([1, 3, 8, 8]) + B = te.compute(A.shape, lambda *indices: A[indices]) + s = te.create_schedule(B.op) + + transform = lambda n, c, h, w: [n, c // 4, h, w, c % 4] + # Error occurs here, because the transformation would introduce + # padding. Padded transforms are supported in TIR-based + # schedules. + with pytest.raises(tvm.TVMError): + s[B].transform_layout(transform) + + if __name__ == "__main__": tvm.testing.main()