diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8ce502523159..5eed998384e1 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -26,6 +26,7 @@ #include #include +#include #include "constraint_extract.h" #include "int_operator.h" @@ -81,6 +82,16 @@ struct ConstIntBoundAnalyzer::Entry { bool operator==(const Entry& other) const { return min_value == other.min_value && max_value == other.max_value; } + + friend std::ostream& operator<<(std::ostream& os, const Entry& entry) { + os << "Entry["; + PrintBoundValue(os, entry.min_value); + os << ", "; + PrintBoundValue(os, entry.max_value); + os << "]"; + + return os; + } }; class ConstIntBoundAnalyzer::Impl @@ -228,6 +239,11 @@ class ConstIntBoundAnalyzer::Impl Entry ret; ret.min_value = InfAwareAdd(a.min_value, b.min_value); ret.max_value = InfAwareAdd(a.max_value, b.max_value); + + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + return ret; } @@ -237,6 +253,13 @@ class ConstIntBoundAnalyzer::Impl Entry ret; ret.min_value = InfAwareAdd(a.min_value, -b.max_value); ret.max_value = InfAwareAdd(a.max_value, -b.min_value); + + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { + ret = Intersect(ret, Negative(bound.value())); + } return ret; } @@ -628,6 +651,25 @@ class ConstIntBoundAnalyzer::Impl ret.max_value = std::min(a.max_value, b.max_value); return ret; } + /*! + * \brief Flip the sign of a set. + * \param entry The set of values + */ + static Entry Negative(Entry entry) { + Entry ret; + if (entry.max_value == kPosInf) { + ret.min_value = kNegInf; + } else { + ret.min_value = -entry.max_value; + } + if (entry.min_value == kNegInf) { + ret.max_value = kPosInf; + } else { + ret.max_value = -entry.min_value; + } + + return ret; + } /*! * \brief return everything dtype can represent. * \param dtype The data type. 
@@ -733,6 +775,164 @@ class ConstIntBoundAnalyzer::Impl std::ceil(std::log2(arg_bounds.max_value))); } } + + std::optional BoundUsingReciprocal(PrimExpr expr) { + // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on + // previous simplifications, the exact form of the expression may vary. + auto opt_special_case = [&]() -> std::optional> { + PVar A, B, C, D; + + if (PMatchesOneOf{ + (A + B) * C - (A * B) * D, + (A + B) * C - (B * A) * D, + } + .Match(expr)) { + return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), + VisitExpr(D.Eval())}; + } else if (PMatchesOneOf{ + (A + B) * C - A * B, + (A + B) * C - B * A, + } + .Match(expr)) { + return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), + MakeBound(1, 1)}; + } else if (PMatchesOneOf{ + (A * B) * D - (A + B) * C, + (B * A) * D - (A + B) * C, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))}; + } else if (PMatchesOneOf{ + A * B - (A + B) * C, + B * A - (A + B) * C, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)}; + } else if (PMatchesOneOf{ + (A * B) * D + (A + B) * C, + (B * A) * D + (A + B) * C, + (A + B) * C + (A * B) * D, + (A + B) * C + (B * A) * D, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))}; + } else if (PMatchesOneOf{ + (A * B) + (A + B) * C, + (B * A) + (A + B) * C, + (A + B) * C + (A * B), + (A + B) * C + (B * A), + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + VisitExpr(C.Eval()), MakeBound(-1, -1)}; + } else { + return std::nullopt; + } + }(); + + if (!opt_special_case.has_value()) { + return std::nullopt; + } + // Unpacking the tuple would be 
cleaner with a structured binding. + // However, until C++20, structured bindings cannot be captured for + // use in a lambda function. + auto A_bound = std::get<0>(*opt_special_case); + auto B_bound = std::get<1>(*opt_special_case); + auto C_bound = std::get<2>(*opt_special_case); + auto D_bound = std::get<3>(*opt_special_case); + + // If C and D have different signs, flip the signs of A/B/C so + // that C will match the sign of D. + if ((D_bound.max_value < 0 && C_bound.min_value > 0) || + (D_bound.min_value > 0 && C_bound.max_value < 0)) { + A_bound = Negative(A_bound); + B_bound = Negative(B_bound); + C_bound = Negative(C_bound); + } + + // If all terms are negative, then we'll be providing an upper bound + // rather than a lower bound. To avoid code duplication, flip all the + // signs here, find a lower bound, then flip the sign to produce the + // upper bound of the original expression. + bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 && + C_bound.max_value < 0 && D_bound.max_value < 0); + if (all_terms_negative) { + A_bound = Negative(A_bound); + B_bound = Negative(B_bound); + C_bound = Negative(C_bound); + D_bound = Negative(D_bound); + } + + bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 && + C_bound.min_value > 0 && D_bound.min_value > 0); + if (!all_terms_positive) { + return std::nullopt; + } + + // (A + B) * C - (A * B) * D + // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C ) + // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C ) + // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C) + // + // The constant (A*B*C*D) is positive, and its minimum value is the + // product of the minimum values of A, B, C, and D. If the reciprocal + // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can + // be used to provide a lower bound on the expression. 
+ + bool reciprocal_term_is_positive = [&]() { + if (D_bound.max_value == ConstIntBound::kPosInf) { + // If D can grow without bound, the `1/(A*D)` and `1/(B*D)` + // terms will approach zero, at which point the `-1/C` term + // will determine the sign. + return false; + } + + if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) { + // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D). + // Since each term is positive, this condition can hold if either + // A*D <= C or B*D <= C. + return true; + } + if (A_bound.max_value != ConstIntBound::kPosInf && + B_bound.max_value != ConstIntBound::kPosInf) { + // Even if neither term is sufficient on its own, if both A and B + // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D) + // may still be provable. + // + // The maximum value of the LHS is found when C is minimized. The + // minimum value of the RHS is found when A, B, and D are + // maximized. If the condition holds in this case, then it holds + // in all cases. + // + // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max) + // A_max*B_max*D_max < C_min*B_max + C_min*A_max + // A_max*B_max*D_max < C_min*(A_max + B_max) + // + if (A_bound.max_value * B_bound.max_value * D_bound.max_value < + C_bound.min_value * (A_bound.max_value + B_bound.max_value)) { + return true; + } + } + return false; + }(); + + if (!reciprocal_term_is_positive) { + return std::nullopt; + } + + auto ret = Everything(expr->dtype); + ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value; + + // If we flipped the sign of the original expression, flip the sign of + // the resulting set of possible values. 
+ if (all_terms_negative) { + ret = Negative(ret); + } + return ret; + } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0eaaff5ba838..d063b872e938 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1768,6 +1768,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { if (merge_constants) { return RecursiveRewrite(merge_constants.value()); } + + auto common_factor = [&]() -> int64_t { + auto modular_a = analyzer_->modular_set(ret->a); + auto modular_b = analyzer_->modular_set(ret->b); + auto gcd_lhs = ZeroAwareGCD(modular_a->base, modular_a->coeff); + auto gcd_rhs = ZeroAwareGCD(modular_b->base, modular_b->coeff); + return ZeroAwareGCD(gcd_lhs, gcd_rhs); + }(); + if (common_factor > 1) { + return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor)); + } } return std::move(ret); } diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 5667c79aaced..c22e1dcb787c 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -15,373 +15,283 @@ # specific language governing permissions and limitations # under the License. 
+import contextlib + import tvm import tvm.testing from tvm import te +from tvm.arith import ConstIntBound +NEG_INF = ConstIntBound.NEG_INF +POS_INF = ConstIntBound.POS_INF -def test_dtype_bound(): - analyzer = tvm.arith.Analyzer() - x = te.var("x", dtype="int64") - bd = analyzer.const_int_bound(x) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF +class TestCase: + def __init__(self, expr, expected_bounds, known_bounds=None, constraint=None): + self.expr = expr + self.expected_bounds = expected_bounds + if known_bounds is None: + self.known_bounds = {} + else: + self.known_bounds = known_bounds - x = te.var("x", dtype="int8") - bd = analyzer.const_int_bound(x) - assert bd.min_value == -128 - assert bd.max_value == 127 + self.constraint = constraint - x = te.var("x", dtype="uint8") - bd = analyzer.const_int_bound(x) - assert bd.min_value == 0 - assert bd.max_value == 255 + @property + def __name__(self): + return str(self.expr) -def test_cast_bound(): - analyzer = tvm.arith.Analyzer() - x = te.var("x", dtype="int8") - tmod = tvm.tir.truncmod - bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32")) - assert bd.min_value == 0 - assert bd.max_value == 2 - - bd = analyzer.const_int_bound(tmod(x, 3).astype("float32").astype("int32")) - assert bd.min_value == -2 - assert bd.max_value == 2 - - -def test_add_sub_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x", "int64"), te.var("y", "int64") - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(0, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(1, 10)) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == 1 - assert bd.max_value == 14 - - bd = analyzer.const_int_bound(x - y) - assert bd.min_value == -10 - assert bd.max_value == 3 - - analyzer.update(x, tvm.arith.ConstIntBound(0, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x - y) - assert bd.min_value == 
-10 - assert bd.max_value == bd.POS_INF - - bd = analyzer.const_int_bound(1 - x) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == 1 - - ## constants with negative or positive max(int64) occassionally show up - ## in models, this is to ensure we can handle those cases - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.NEG_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(bd.POS_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - -def test_mul_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), te.var("y") +class BaseCompare: + def test_const_bounds(self, test_case): + analyzer = tvm.arith.Analyzer() - analyzer.update(x, tvm.arith.ConstIntBound(-2, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(x * y + 20) - assert bd.min_value == 0 - assert bd.max_value == 60 + for var, bounds in test_case.known_bounds.items(): + analyzer.update(var, ConstIntBound(*bounds)) - analyzer.update(x, tvm.arith.ConstIntBound(-3, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) - bd = analyzer.const_int_bound(x * y) - assert bd.min_value == -32 - assert bd.max_value == 24 + with contextlib.ExitStack() as stack: + if test_case.constraint is not None: + stack.enter_context(analyzer.constraint_scope(test_case.constraint)) - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) - bd = analyzer.const_int_bound(x * y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF + bounds = 
analyzer.const_int_bound(test_case.expr) + if test_case.expected_bounds[0] is None: + assert bounds.max_value == test_case.expected_bounds[1] + elif test_case.expected_bounds[1] is None: + assert bounds.min_value == test_case.expected_bounds[0] + else: + assert (bounds.min_value, bounds.max_value) == test_case.expected_bounds -def test_truncdiv_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), te.var("y") - tdiv = tvm.tir.truncdiv - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -2 +class TestDataType(BaseCompare): + test_case = tvm.testing.parameter( + TestCase(te.var("x", dtype="int64"), (NEG_INF, POS_INF)), + TestCase(te.var("x", dtype="int8"), (-128, 127)), + TestCase(te.var("x", dtype="uint8"), (0, 255)), + TestCase(te.size_var("x", dtype="int32"), (0, POS_INF)), + ) - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -4 - assert bd.max_value == 9 - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF +class TestCastBound(BaseCompare): + x = te.var("x", dtype="int8") + tmod = tvm.tir.truncmod - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 + test_case = tvm.testing.parameter( + TestCase(tmod(x, 3).astype("uint32"), (0, 2)), + TestCase(tmod(x, 3).astype("float32").astype("int32"), (-2, 2)), + ) -def test_truncmod_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), 
te.var("y") +class TestAddSubBound(BaseCompare): + x = te.var("x", "int64") + y = te.var("y", "int64") - tmod = tvm.tir.truncmod + test_case = tvm.testing.parameter( + TestCase(x + y, (NEG_INF, POS_INF)), + TestCase(x + y, (1, 14), known_bounds={x: (0, 4), y: (1, 10)}), + TestCase(x - y, (-10, 3), known_bounds={x: (0, 4), y: (1, 10)}), + TestCase(x - y, (-10, POS_INF), known_bounds={x: (0, POS_INF), y: (1, 10)}), + TestCase(1 - x, (NEG_INF, 1), known_bounds={x: (0, POS_INF), y: (1, 10)}), + ) - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 4 - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 +class TestBoundsUsingReciprocals(BaseCompare): + """Special handling for differences of reciprocals - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 + These terms can appear when comparing the number of operations for + different orderings of matrix multiplications, with A, B, and C + known to be positive values. + In these cases, comparing `(A+B)*C < A*B` is equivalent to + `1/A + 1/B < 1/C`. Working in terms of the reciprocals + allows the ConstIntBound analyzer to provide a tighter + bound for these differences than would otherwise be + available. 
-def test_floordiv_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), te.var("y") - fld = tvm.te.floordiv - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -9 // 4 - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -4 - assert bd.max_value == 9 - - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 - - # Test handling unsigned integers well - x, y = te.var("x", dtype="uint32"), te.var("y", dtype="uint32") - analyzer.update(x, tvm.arith.ConstIntBound(1, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(0, 12), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 4 - - -def test_floormod_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), te.var("y") - flm = tvm.te.floormod + For `(A+B)*C - A*B`, the normal bottom-up integer bounds are unable to + provide the bounds required to prove these inequalities, because they + treat the terms as uncorrelated. That is, they assume that `(A+B)*C` may + achieve its minimum while `A*B` simultaneously achieves its maximum. 
+ """ - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 + A, B, C = [te.var(letter, "int64") for letter in "ABC"] - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 + symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)} + asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)} - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 + test_case = tvm.testing.parameter( + TestCase((A + B) * C - A * B, (2048, None), known_bounds=symmetric_bounds), + TestCase((A + B) * C - B * A, (2048, None), known_bounds=symmetric_bounds), + TestCase(A * B - (A + B) * C, (None, -2048), known_bounds=symmetric_bounds), + TestCase(B * A - (A + B) * C, (None, -2048), known_bounds=symmetric_bounds), + TestCase((A + B) * C - A * B, (2048, None), known_bounds=asymmetric_bounds), + TestCase((A + B) * C - B * A, (2048, None), known_bounds=asymmetric_bounds), + TestCase(A * B - (A + B) * C, (None, -2048), known_bounds=asymmetric_bounds), + TestCase(B * A - (A + B) * C, (None, -2048), known_bounds=asymmetric_bounds), + ) -def test_min_max_bound(): - analyzer = tvm.arith.Analyzer() +class TestMulBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tvm.te.min(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 10 + test_case = tvm.testing.parameter( + TestCase(x * y + 20, (0, 60), {x: (-2, 4), y: (4, 
10)}), + TestCase(x * y, (-32, 24), {x: (-3, 4), y: (-8, 2)}), + TestCase(x * y, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-8, 2)}), + ) + - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tvm.te.min(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == 10 +class TestTruncDivBound(BaseCompare): + x, y = te.var("x"), te.var("y") - bd = analyzer.const_int_bound(tvm.te.max(x, y)) - assert bd.min_value == 4 - assert bd.max_value == bd.POS_INF + expr = tvm.tir.truncdiv(x, y) - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tvm.te.max(x, y)) - assert bd.min_value == 4 - assert bd.max_value == bd.POS_INF + test_case = tvm.testing.parameter( + TestCase(expr, (-2, None), {x: (-9, 4), y: (4, 10)}), + TestCase(expr, (-4, 9), {x: (-9, 4), y: (-2, 0)}), + TestCase(expr, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-2, 1)}), + TestCase(expr, (-9, 9), {x: (-9, 4), y: (-4, 12)}), + ) -def test_select_bound(): - analyzer = tvm.arith.Analyzer() +class TestTruncModBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) + expr = tvm.tir.truncmod(x, y) + + test_case = tvm.testing.parameter( + TestCase(expr, (-9, 4), {x: (-9, 4), y: (4, 10)}), + TestCase(expr, (-9, 9), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(expr, (0, 9), {x: (1, POS_INF), y: (4, 10)}), + ) - bd = analyzer.const_int_bound(tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1)) - assert bd.min_value == 0 - assert bd.max_value == 11 +class TestFloorDivBound(BaseCompare): + x, y = te.var("x"), te.var("y") + ux = te.var("x", dtype="uint32") + uy = te.var("y", dtype="uint32") + + test_case = tvm.testing.parameter( + TestCase(x // y, (-9 // 4, 
None), {x: (-9, 4), y: (4, 10)}), + TestCase(x // y, (-4, 9), {x: (-9, 4), y: (-2, 0)}), + TestCase(x // y, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-2, 1)}), + TestCase(x // y, (-9, 9), {x: (-9, 4), y: (-4, 12)}), + TestCase(ux // uy, (0, 4), {ux: (1, 4), uy: (0, 12)}), + ) -def test_shift_and_bound(): - analyzer = tvm.arith.Analyzer() + +class TestFloorModBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(2, 10)) + test_case = tvm.testing.parameter( + TestCase(x % y, (0, 9), {x: (-9, 4), y: (4, 10)}), + TestCase(x % y, (0, 9), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(x % y, (0, 9), {x: (1, POS_INF), y: (4, 10)}), + ) - bd = analyzer.const_int_bound(x >> y) - assert bd.min_value == -3 - assert bd.max_value == 2 - bd = analyzer.const_int_bound(x & y) - assert bd.min_value == 0 - assert bd.max_value == 10 +class TestMinMaxBound(BaseCompare): + x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(10, 11), override=True) - bd = analyzer.const_int_bound(x & y) - assert bd.min_value == 0 - assert bd.max_value == 10 + test_case = tvm.testing.parameter( + TestCase(tvm.te.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}), + TestCase(tvm.te.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 10)}), + ) -def test_mix_index_bound(): - analyzer = tvm.arith.Analyzer() +class TestSelectBound(BaseCompare): x, y = te.var("x"), te.var("y") - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod - analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1)) - analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1)) - bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8) - assert bd.min_value == 0 - assert bd.max_value == 24 - 1 + test_case = tvm.testing.parameter( + TestCase( + tvm.tir.Select(x > 1, (y < 
0).astype("int32"), y + 1), + (0, 11), + {x: (-9, 11), y: (4, 10)}, + ), + ) + + +class TestShiftAndBound(BaseCompare): + x, y = te.var("x"), te.var("y") - bd = analyzer.const_int_bound(y + x * 3) - assert bd.min_value == 0 - assert bd.max_value == 24 * 3 - 1 + test_case = tvm.testing.parameter( + TestCase(x >> y, (-3, 2), {x: (-9, 11), y: (2, 10)}), + TestCase(x & y, (0, 10), {x: (-9, 11), y: (2, 10)}), + TestCase(x & y, (0, 10), {x: (10, 11), y: (2, 10)}), + ) - bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7) - assert bd.min_value == 0 - assert bd.max_value == (23 // 7) * 7 + 6 +class TestMixIndexBound(BaseCompare): + x, y = te.var("x"), te.var("y") + tdiv = tvm.tir.truncdiv + tmod = tvm.tir.truncmod -def test_size_var_bound(): - analyzer = tvm.arith.Analyzer() - x = te.size_var("x") - bd = analyzer.const_int_bound(x) - assert bd.min_value == 0 - assert bd.max_value == bd.POS_INF + test_case = tvm.testing.parameter( + TestCase(tmod(x, 8) + tdiv(x, 8) * 8, (0, 24 - 1), {x: (0, 24 - 1), y: (0, 3 - 1)}), + TestCase(y + x * 3, (0, 24 * 3 - 1), {x: (0, 24 - 1), y: (0, 3 - 1)}), + TestCase( + tmod(x, 7) + tdiv(x, 7) * 7, (0, (23 // 7) * 7 + 6), {x: (0, 24 - 1), y: (0, 3 - 1)} + ), + ) -def test_let_bound(): - analyzer = tvm.arith.Analyzer() +class TestLetBound(BaseCompare): x = te.var("x") - bd = analyzer.const_int_bound(tvm.tir.Let(x, 1, x + 1)) - assert bd.min_value == 2 - assert bd.max_value == 2 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Let(x, 1, x + 1), (2, 2)), + ) -def test_floormod_negative_divisor(): - analyzer = tvm.arith.Analyzer() +class TestFloorModNegativeDivisor(BaseCompare): flm, fld = tvm.te.floormod, tvm.te.floordiv a, b = te.var("a"), te.var("b") - analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) - analyzer.update(b, tvm.arith.ConstIntBound(-5, 7)) - bd = analyzer.const_int_bound(flm(a, b)) - assert bd.min_value == -4 - assert bd.max_value == 6 + test_case = tvm.testing.parameter( + TestCase(a % b, (-4, 6), {a: (0, 6), b: 
(-5, 7)}), + ) + + +class TestDivModAssumeNoZeroDivisor(BaseCompare): + """Divmod of a non-negative expression assumes that division by + zero won't occur. This assumption is important to get the best result + from symbolic shape programs. + """ -def test_divmod_assume_no_zero_divsor(): - # Divmod non negative expression makes assumption that divide by zero won't occur - # this assumption is important to get best result from symbolic shape programs - analyzer = tvm.arith.Analyzer() - flm, fld = tvm.te.floormod, tvm.te.floordiv a, b = te.var("a"), te.var("b") - analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) - analyzer.update(b, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF)) - bd = analyzer.const_int_bound(fld(a, b)) - assert bd.min_value == 0 - assert bd.max_value == 6 - bd = analyzer.const_int_bound(flm(a, b)) - assert bd.min_value == 0 - assert bd.max_value == 6 + test_case = tvm.testing.parameter( + TestCase(a // b, (0, 6), {a: (0, 6), b: (0, POS_INF)}), + TestCase(a % b, (0, 6), {a: (0, 6), b: (0, POS_INF)}), + ) -def test_multiple_condition(): - analyzer = tvm.arith.Analyzer() - flm, fld = tvm.te.floormod, tvm.te.floordiv +class TestMultipleCondition(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - with analyzer.constraint_scope(tvm.tir.all(1 <= flm(a, 58), flm(a, 58) < 57)): - bound = analyzer.const_int_bound(flm(a, 58) - 1) - assert bound.min_value == 0 + test_case = tvm.testing.parameter( + TestCase( + a % 58 - 1, + (0, None), + known_bounds={a: (0, 128)}, + constraint=tvm.tir.all(1 <= a % 58, a % 58 < 57), + ), + ) -def test_broadcast_bound(): - analyzer = tvm.arith.Analyzer() +class TestBroadcastBound(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - bound = analyzer.const_int_bound(tvm.tir.Broadcast(a, 4)) - assert bound.min_value == 0 - assert bound.max_value == 128 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Broadcast(a, 4), (0, 128), {a: (0, 128)}), + ) 
-def test_ramp_bound(): - analyzer = tvm.arith.Analyzer() +class TestRampBound(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - bound = analyzer.const_int_bound(tvm.tir.Ramp(a, 2, 4) + 2) - assert bound.min_value == 2 - assert bound.max_value == 128 + 2 * 3 + 2 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0, 128)}), + ) if __name__ == "__main__": diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 6433dc2dece9..5d2c3aa283cf 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -983,6 +983,30 @@ class TestComparisons(BaseCompare): TestCase(y * y >= 0, tvm.tir.const(1, "bool"), y <= 0), TestCase(x * 6 <= -3, tvm.tir.const(0, "bool"), x >= 0), TestCase(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0), + # Special inequality cases + TestCase( + x * y < (x + y) * 2048, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 2048], + ), + TestCase( + x * y < (x + y) * 2048, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), + TestCase( + # Both sides are divisible by 8192 + x * y * 8192 < (y + x) * 16777216, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), + TestCase( + # The two sides have co-prime factors, but the bounds are + # still sufficient to prove the inequality. + x * y * 59 < (y + x) * 176128, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), )