diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 921f8ac7094b..4714cf1df59f 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -129,6 +129,10 @@ bool Analyzer::CanProve(const PrimExpr& expr) { PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { PrimExpr res = expr; + // Always starts with a canonical simplification, as some structural property + // of an expression might be destroyed by rewrite simplification. + res = this->canonical_simplify(res); + for (int i = 0; i < steps; ++i) { if (tir::is_const_int(res)) { return res; diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 39d626aaf2b4..11fb041511f9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -335,6 +335,8 @@ class SumExprNode : public CanonicalExprNode { * \return whether the cast can be safely pushed to children */ bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { + bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() + : base == -(1LL << (dtype.bits() - 1)); // cast(dtype, arg_1 + arg_2 + ... arg_n) == // cast(dtype, arg_1) + ... + cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of @@ -351,7 +353,7 @@ class SumExprNode : public CanonicalExprNode { } } } - if (base > 0) { + if (base > 0 || is_min_value) { res = res + make_const(dtype, base); if (!CastIsSafe(dtype, res, analyzer)) { return false; @@ -366,7 +368,7 @@ class SumExprNode : public CanonicalExprNode { } } } - if (base < 0) { + if (base < 0 && !is_min_value) { res = res - make_const(dtype, -base); if (!CastIsSafe(dtype, res, analyzer)) { return false; @@ -497,6 +499,8 @@ class SumExprNode : public CanonicalExprNode { return args; } static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { + bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() + : base == -(1LL << (dtype.bits() - 1)); // Positive scales first PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { @@ -504,7 +508,7 @@ class SumExprNode : public CanonicalExprNode { res = res + args[i]->Normalize(); } } - if (base > 0) { + if (base > 0 || is_min_value) { res = res + make_const(dtype, base); } // negative scales follows using sub. @@ -513,7 +517,7 @@ class SumExprNode : public CanonicalExprNode { res = res - args[i]->NormalizeWithScale(-1); } } - if (base < 0) { + if (base < 0 && !is_min_value) { res = res - make_const(dtype, -base); } return res; diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 9db3035fd944..59143056000a 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -372,5 +372,19 @@ def test_simplify_cast(): ck.verify(res, 2) +def test_simplify_normalize_min_value_expr(): + ck = CanonicalChecker() + x = te.var("x", "int32") + + ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) + ck.verify(te.min_value("int32") + x == 0, False) + ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) + ck.verify(0 == te.min_value("int32") + x, False) + ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) + ck.verify(x + te.min_value("int32") == 0, False) + ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) + ck.verify(0 == x + te.min_value("int32"), False) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 24228fb52703..da3fd94f8192 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -182,7 +182,6 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None): expect_begin, expect_end = expect_desc[binding] result_begin = analyzer.simplify(intset.min_value, 3) result_end = analyzer.simplify(intset.max_value + 1, 3) - print(result_end) assert analyzer.can_prove_equal( result_begin - expect_begin, 0 ), f"{result_begin} vs {expect_begin}" @@ -306,10 +305,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h2: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) @@ -333,10 +329,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h1: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), ((h3, 10),): (h3 * 8, h3 * 8 + 10), diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py new file mode 100644 index 000000000000..aa9d5179aa3f --- /dev/null +++ b/tests/python/unittest/test_arith_simplify.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import tir + + +def test_simplify_reshape_flattened_index(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + ana.bind(i0, tvm.ir.Range(0, 8)) + ana.bind(i1, tvm.ir.Range(0, 3)) + + i_flattened = i0 * 3 + i1 + assert tvm.ir.structural_equal( + ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4), + i_flattened, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 55c83167392f..95ad81db889b 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -150,7 +150,7 @@ def assert_simplified_equal(index_simplified, index_direct): index_simplified = A.offset_of( (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)) ) - index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) + index_direct = A.offset_of((0, idxm(k0, idxd(k1, s)) + idxm(k0, k1))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 index_simplified = A.offset_of( diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 38bd4bba1418..349c4734c9ee 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -126,7 +126,7 @@ def test_suggest_index_map_winograd(): floordiv(i0, 2), floordiv(i1, 2), floormod(i0, 2), - floormod(((i1 * 4) + floordiv(i2, 32)), 8), + floormod(i1, 2) * 4 + floordiv(i2, 32), floormod(i2, 32), floordiv(i3, 32), floormod(i3, 32), @@ -137,8 +137,8 @@ def test_suggest_index_map_winograd(): expected_inverse_index_map = IndexMap.from_func( lambda i0, i1, i2, i3, i4, i5, i6: ( ((i0 * 2) + i2), - ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)), - floormod(((i3 * 32) + i4), 128), + i1 * 2 + floordiv(i3, 4), + floormod(i3, 4) * 32 + i4, ((i5 * 32) + i6), ) )