From 4df3068e8662e9ba0e80def514ff909f3b29601f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 30 Jan 2023 15:03:18 -0500 Subject: [PATCH] [Fix][Arith] Analyzer simplification starts with canonical This PR updates the order of arithmetic analyzer simplification, by adding a stage of canonical simplification at the very beginning so that every simplification always starts with a canonical round. This is because the rewrite simplification may destroy some PrimExpr property that the canonical simplification can make use of. Therefore, adding the canonical one in the front can maximize the use of canonical simplification. --- src/arith/analyzer.cc | 4 ++ src/arith/canonical_simplify.cc | 12 ++++-- .../unittest/test_arith_canonical_simplify.py | 14 +++++++ tests/python/unittest/test_arith_intset.py | 11 +----- tests/python/unittest/test_arith_simplify.py | 38 +++++++++++++++++++ tests/python/unittest/test_tir_buffer.py | 2 +- .../unittest/test_tir_schedule_analysis.py | 6 +-- 7 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 tests/python/unittest/test_arith_simplify.py diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 921f8ac7094b..4714cf1df59f 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -129,6 +129,10 @@ bool Analyzer::CanProve(const PrimExpr& expr) { PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { PrimExpr res = expr; + // Always starts with a canonical simplification, as some structural property + // of an expression might be destroyed by rewrite simplification. + res = this->canonical_simplify(res); + for (int i = 0; i < steps; ++i) { if (tir::is_const_int(res)) { return res; diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 39d626aaf2b4..11fb041511f9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -335,6 +335,8 @@ class SumExprNode : public CanonicalExprNode { * \return whether the cast can be safely pushed to children */ bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { + bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() + : base == -(1LL << (dtype.bits() - 1)); // cast(dtype, arg_1 + arg_2 + ... arg_n) == // cast(dtype, arg_1) + ... + cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of @@ -351,7 +353,7 @@ class SumExprNode : public CanonicalExprNode { } } } - if (base > 0) { + if (base > 0 || is_min_value) { res = res + make_const(dtype, base); if (!CastIsSafe(dtype, res, analyzer)) { return false; @@ -366,7 +368,7 @@ class SumExprNode : public CanonicalExprNode { } } } - if (base < 0) { + if (base < 0 && !is_min_value) { res = res - make_const(dtype, -base); if (!CastIsSafe(dtype, res, analyzer)) { return false; @@ -497,6 +499,8 @@ class SumExprNode : public CanonicalExprNode { return args; } static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { + bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() + : base == -(1LL << (dtype.bits() - 1)); // Positive scales first PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { @@ -504,7 +508,7 @@ class SumExprNode : public CanonicalExprNode { res = res + args[i]->Normalize(); } } - if (base > 0) { + if (base > 0 || is_min_value) { res = res + make_const(dtype, base); } // negative scales follows using sub. @@ -513,7 +517,7 @@ class SumExprNode : public CanonicalExprNode { res = res - args[i]->NormalizeWithScale(-1); } } - if (base < 0) { + if (base < 0 && !is_min_value) { res = res - make_const(dtype, -base); } return res; diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 9db3035fd944..59143056000a 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -372,5 +372,19 @@ def test_simplify_cast(): ck.verify(res, 2) +def test_simplify_normalize_min_value_expr(): + ck = CanonicalChecker() + x = te.var("x", "int32") + + ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) + ck.verify(te.min_value("int32") + x == 0, False) + ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) + ck.verify(0 == te.min_value("int32") + x, False) + ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) + ck.verify(x + te.min_value("int32") == 0, False) + ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) + ck.verify(0 == x + te.min_value("int32"), False) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 24228fb52703..da3fd94f8192 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -182,7 +182,6 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None): expect_begin, expect_end = expect_desc[binding] result_begin = analyzer.simplify(intset.min_value, 3) result_end = analyzer.simplify(intset.max_value + 1, 3) - print(result_end) assert analyzer.can_prove_equal( result_begin - expect_begin, 0 ), f"{result_begin} vs {expect_begin}" @@ -306,10 +305,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h2: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) @@ -333,10 +329,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h1: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), ((h3, 10),): (h3 * 8, h3 * 8 + 10), diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py new file mode 100644 index 000000000000..aa9d5179aa3f --- /dev/null +++ b/tests/python/unittest/test_arith_simplify.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import tir + + +def test_simplify_reshape_flattened_index(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + ana.bind(i0, tvm.ir.Range(0, 8)) + ana.bind(i1, tvm.ir.Range(0, 3)) + + i_flattened = i0 * 3 + i1 + assert tvm.ir.structural_equal( + ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4), + i_flattened, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 55c83167392f..95ad81db889b 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -150,7 +150,7 @@ def assert_simplified_equal(index_simplified, index_direct): index_simplified = A.offset_of( (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)) ) - index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) + index_direct = A.offset_of((0, idxm(k0, idxd(k1, s)) + idxm(k0, k1))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 index_simplified = A.offset_of( diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 38bd4bba1418..349c4734c9ee 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -126,7 +126,7 @@ def test_suggest_index_map_winograd(): floordiv(i0, 2), floordiv(i1, 2), floormod(i0, 2), - floormod(((i1 * 4) + floordiv(i2, 32)), 8), + floormod(i1, 2) * 4 + floordiv(i2, 32), floormod(i2, 32), floordiv(i3, 32), floormod(i3, 32), @@ -137,8 +137,8 @@ def test_suggest_index_map_winograd(): expected_inverse_index_map = IndexMap.from_func( lambda i0, i1, i2, i3, i4, i5, i6: ( ((i0 * 2) + i2), - ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)), - floormod(((i3 * 32) + i4), 128), + i1 * 2 + floordiv(i3, 4), + floormod(i3, 4) * 32 + i4, ((i5 * 32) + i6), ) )