From 6332750eb6caeb5d2ec2e3ba01d8a1e973ffb650 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 7 Apr 2023 08:43:15 -0500 Subject: [PATCH] [Arith][Bugfix] Simplify "x - 1 < y" into "x <= y" This simplification was introduced in https://github.com/apache/tvm/pull/13217, and was erroneously removed in https://github.com/apache/tvm/pull/13933. This commit re-enables this simplification, and adds unit tests to prevent any future regression. --- src/arith/rewrite_simplify.cc | 54 ++++++++++++++----- .../unittest/test_arith_rewrite_simplify.py | 21 +++++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40a5977ec54c..e44ef31da1dd 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -29,6 +29,7 @@ #include #include +#include #include #include "../target/datatype/registry.h" @@ -120,6 +121,23 @@ PrimExpr NormalizeBooleanOperators(PrimExpr expr) { } } +std::tuple ExtractConstantOffset(const PrimExpr& expr) { + PVar x; + PVar c1; + + // Any (c1+x) terms are normalized into (x+c1), so we don't need to + // check for it. + if ((x + c1).Match(expr)) { + return {x.Eval(), c1.Eval()->value}; + } else if ((x - c1).Match(expr)) { + return {x.Eval(), -c1.Eval()->value}; + } else if ((c1 - x).Match(expr)) { + return {x.Eval(), c1.Eval()->value}; + } else { + return {expr, 0}; + } +} + CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) { CompareResult output = CompareResult::kUnknown; @@ -1664,20 +1682,28 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1); TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); - if ((x + c1 < y + c2).Match(ret)) { - int64_t diff = c2.Eval()->value - c1.Eval()->value; - PrimExpr out = [&]() { - if (diff == 0) { - return (x < y).Eval(); - } else if (diff == 1) { - return (x <= y).Eval(); - } else if (diff < 0) { - return (x + (-diff) < y).Eval(); - } else { - return (x < y + diff).Eval(); - } - }(); - return RecursiveRewrite(out); + auto merge_constants = [&]() -> Optional { + auto [lhs, lhs_offset] = ExtractConstantOffset(ret->a); + auto [rhs, rhs_offset] = ExtractConstantOffset(ret->b); + if (lhs_offset == 0 && rhs_offset == 0) { + return NullOpt; + } + + int64_t diff = rhs_offset - lhs_offset; + if (diff == 0) { + return lhs < rhs; + } else if (diff == 1) { + return lhs <= rhs; + } else if (diff < 0 && rhs_offset != 0) { + return lhs + make_const(lhs.dtype(), -diff) < rhs; + } else if (diff > 0 && lhs_offset != 0) { + return lhs < rhs + make_const(rhs.dtype(), diff); + } + + return NullOpt; + }(); + if (merge_constants) { + return RecursiveRewrite(merge_constants.value()); } } return std::move(ret); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 119c767ed408..7ecc34c385b6 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -750,6 +750,23 @@ class TestComparisons(BaseCompare): TestCase((x - 10).equal(0), x.equal(10)), TestCase((10 - x).equal(0), x.equal(10)), TestCase((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0))), + # Write LT as LE for integer arguments, if possible + TestCase(x - 1 < y, x <= y), + TestCase(x + (-1) < y, x <= y), + TestCase(x < y - (-1), x <= y), + TestCase(x < y + 1, x <= y), + TestCase(x + 2 < y + 3, x <= y), + TestCase(x - 3 < y - 2, x <= y), + TestCase(x - 3 < y + (-2), x <= y), + TestCase(x + (-3) < y - 2, x <= y), + # Merge constants on the LHS/RHS of a LT expression. + TestCase(x + 10 < y + 10, x < y), + TestCase(x + 5 < y + 10, x < y + 5), + TestCase(x + 10 < y + 5, x + 5 < y), + TestCase(x - 5 < y - 10, x + 5 < y), + TestCase(x - 10 < y - 5, x < y + 5), + TestCase(x < y - 10, x + 10 < y), + TestCase(x - 10 < y, x < y + 10), # cmp bound TestCase(x + y < x + z, y < z), TestCase(x + y < z + x, y < z), @@ -815,7 +832,7 @@ class TestComparisons(BaseCompare): TestCase(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4))), TestCase(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2)), TestCase(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2)), - TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4) + (-2), y)), + TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4), y + 2)), # floor div TestCase(fld(x, 2) < 3, x < 6), TestCase(3 < fld(x, 2), tvm.tir.LT(7, x)), @@ -833,7 +850,7 @@ class TestComparisons(BaseCompare): TestCase(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4))), TestCase(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2)), TestCase(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2)), - TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y)), + TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4), y + 2)), # End DivMod Rules # merging flm/fld into known value TestCase(tir.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28),