diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 788f6fddfa50..044e5d6f6ca9 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -334,6 +334,35 @@ class RewriteSimplifier { * (n < 10) || (n < 5) => (n < 5) */ kApplyConstraintsToBooleanBranches = (1 << 2), + + /* Special handling for expressions `(A+B)*C < (A*B)*D` + * + * Expressions of the form `(A+B)*C < (A*B)*D` can occur occur + * when comparing the number of operations required for two + * different orderings in which matrix multiplications can be + * performed. Proving or disproving this conditional allows an + * optimal order of execution to be selected, even for dynamic + * argument shapes. + * + * The default behavior of `ConstIntBounds` assumes that each term + * in an expression is independent, and is insufficient to prove + * these inequalities. For example, the maximum value of `(A+B)*C + * - (A*B)*D` is determined by taking the maximum value of + * `(A+B)*C` and subtracting the minimum value of `(A*B)*D`. + * While this algorithm can be applied in all cases, the bound it + * provides is looser than strictly required. + * + * This extension adds a check for this case. When `A`, `B`, `C`, + * and `D` are all positive values, as is the case for tensor + * shapes, the inequality can be written as `1/A + 1/B < D/C`. If + * this inequality holds for the minimum values of `A`, `B`, and + * `D`, along with the maximum value of `C`, then the inequality + * holds for all values. + * + * This extension requires little to no performance overhead, and + * may be enabled by default in future releases. + */ + kComparisonOfProductAndSum = (1 << 3), }; /*! \brief Enable an optional extension or extensions diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 87801fd781b1..791fed27cb5e 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -24,7 +24,7 @@ estimate_region_strict_bound, estimate_region_upper_bound, ) -from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength +from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index b2bad2ec0646..22555e0fb3a4 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name """Arithmetic data structure and utility""" -from enum import IntEnum +import enum from typing import Union import tvm._ffi @@ -26,13 +26,26 @@ from . import _ffi_api -class ProofStrength(IntEnum): +class ProofStrength(enum.IntEnum): """Proof strength of the analysis""" DEFAULT = 0 SYMBOLIC_BOUND = 1 +class Extension(enum.Flag): + """Extensions enabled for RewriteSimplifier + + Values should match `RewriteSimplifier::Extensions` + """ + + NoExtensions = 0 + TransitivelyProveInequalities = 1 << 0 + ConvertBooleanToAndOfOrs = 1 << 1 + ApplyConstraintsToBooleanBranches = 1 << 2 + ComparisonOfProductAndSum = 1 << 3 + + @tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -107,6 +120,8 @@ def __init__(self): self._enter_constraint_context = _mod("enter_constraint_context") self._can_prove_equal = _mod("can_prove_equal") self._can_prove = _mod("can_prove") + self._get_enabled_extensions = _mod("get_enabled_extensions") + self._set_enabled_extensions = _mod("set_enabled_extensions") def const_int_bound(self, expr): """Find constant integer bound for expr. @@ -311,3 +326,22 @@ def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"): Whether we can prove that lhs == rhs """ return self._can_prove_equal(lhs, rhs) + + @property + def enabled_extensions(self) -> Extension: + """Return the currently enabled extensions""" + value = self._get_enabled_extensions() + return Extension(value) + + @enabled_extensions.setter + def enabled_extensions(self, flags: Union[int, Extension]): + """Enable extensions for the analyzer + + Parameters + ---------- + flags: Union[int,Extension] + + The extensions to enable. + """ + flags = Extension(flags).value + self._set_enabled_extensions(flags) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3e5b8834ebca..b0d240cc40a2 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -317,6 +317,16 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu } else if (name == "can_prove_equal") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); }); + } else if (name == "get_enabled_extensions") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + std::int64_t flags = args[0]; + self->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); } return PackedFunc(); }; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 5eed998384e1..8d41f0f2c6e7 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -240,10 +240,6 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, b.min_value); ret.max_value = InfAwareAdd(a.max_value, b.max_value); - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); - } - return ret; } @@ -254,12 +250,6 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, -b.max_value); ret.max_value = InfAwareAdd(a.max_value, -b.min_value); - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); - } - if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { - ret = Intersect(ret, Negative(bound.value())); - } return ret; } @@ -775,164 +765,6 @@ class ConstIntBoundAnalyzer::Impl std::ceil(std::log2(arg_bounds.max_value))); } } - - std::optional BoundUsingReciprocal(PrimExpr expr) { - // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on - // previous simplifications, the exact form of the expression may vary. - auto opt_special_case = [&]() -> std::optional> { - PVar A, B, C, D; - - if (PMatchesOneOf{ - (A + B) * C - (A * B) * D, - (A + B) * C - (B * A) * D, - } - .Match(expr)) { - return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), - VisitExpr(D.Eval())}; - } else if (PMatchesOneOf{ - (A + B) * C - A * B, - (A + B) * C - B * A, - } - .Match(expr)) { - return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), - MakeBound(1, 1)}; - } else if (PMatchesOneOf{ - (A * B) * D - (A + B) * C, - (B * A) * D - (A + B) * C, - } - .Match(expr)) { - return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), - Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))}; - } else if (PMatchesOneOf{ - A * B - (A + B) * C, - B * A - (A + B) * C, - } - .Match(expr)) { - return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), - Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)}; - } else if (PMatchesOneOf{ - (A * B) * D + (A + B) * C, - (B * A) * D + (A + B) * C, - (A + B) * C + (A * B) * D, - (A + B) * C + (B * A) * D, - } - .Match(expr)) { - return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), - VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))}; - } else if (PMatchesOneOf{ - (A * B) + (A + B) * C, - (B * A) + (A + B) * C, - (A + B) * C + (A * B), - (A + B) * C + (B * A), - } - .Match(expr)) { - return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), - VisitExpr(C.Eval()), MakeBound(-1, -1)}; - } else { - return std::nullopt; - } - }(); - - if (!opt_special_case.has_value()) { - return std::nullopt; - } - // Unpacking the tuple would be cleaner with a structured binding. - // However, until C++20, structured bindings cannot be captured for - // use in a lambda function. - auto A_bound = std::get<0>(*opt_special_case); - auto B_bound = std::get<1>(*opt_special_case); - auto C_bound = std::get<2>(*opt_special_case); - auto D_bound = std::get<3>(*opt_special_case); - - // If C and D have different signs, flip the signs of A/B/C so - // that C will match the sign of D. - if ((D_bound.max_value < 0 && C_bound.min_value > 0) || - (D_bound.min_value > 0 && C_bound.max_value < 0)) { - A_bound = Negative(A_bound); - B_bound = Negative(B_bound); - C_bound = Negative(C_bound); - } - - // If all terms are negative, then we'll be providing an upper bound - // rather than a lower bound. To avoid code duplication, flip all the - // signs here, find a lower bound, then flip the sign to produce the - // upper bound of the original expression. - bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 && - C_bound.max_value < 0 && D_bound.max_value < 0); - if (all_terms_negative) { - A_bound = Negative(A_bound); - B_bound = Negative(B_bound); - C_bound = Negative(C_bound); - D_bound = Negative(D_bound); - } - - bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 && - C_bound.min_value > 0 && D_bound.min_value > 0); - if (!all_terms_positive) { - return std::nullopt; - } - - // (A + B) * C - (A * B) * D - // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C ) - // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C ) - // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C) - // - // The constant (A*B*C*D) is positive, and its minimum value is the - // product of the minimum values of A, B, C, and D. If the reciprocal - // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can - // be used to provide a lower bound on the expression. - - bool reciprocal_term_is_positive = [&]() { - if (D_bound.max_value == ConstIntBound::kPosInf) { - // If D can grow without bound, the `1/(A*D)` and `1/(B*D)` - // terms will approach zero, at which point the `-1/C` term - // will determine the sign the sign. - return false; - } - - if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) { - // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D). - // Since each term is positive, this condition can hold if either - // A*D <= C or B*D <= C. - return true; - } - if (A_bound.max_value != ConstIntBound::kPosInf && - B_bound.max_value != ConstIntBound::kPosInf) { - // Even if neither term is sufficient on its own, if both A and B - // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D) - // may still be provable. - // - // The maximum value of the LHS is found when C is minimized. The - // minimum value of the RHS is found when A, B, and D are - // maximized. If the condition holds in this case, then it holds - // in all cases. - // - // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max) - // A_max*B_max*D_max < C_min*B_max + C_min*A_max - // A_max*B_max*D_max < C_min*(A_max + B_max) - // - if (A_bound.max_value * B_bound.max_value * D_bound.max_value < - C_bound.min_value * (A_bound.max_value + B_bound.max_value)) { - return true; - } - } - return false; - }(); - - if (!reciprocal_term_is_positive) { - return std::nullopt; - } - - auto ret = Everything(expr->dtype); - ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value; - - // If we flipped the sign of the original expression, flip the sign of - // the resulting set of possible values. - if (all_terms_negative) { - ret = Negative(ret); - } - return ret; - } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d063b872e938..e7e58a80fc08 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -156,10 +156,12 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimE }; output = CompareResult(output & TryCompareUsingConstIntBounds(x, y)); - if (is_finished()) return output; output = CompareResult(output & TryCompareUsingKnownInequalities(x, y)); + if (is_finished()) return output; + + output = CompareResult(output & TryComparisonOfProductAndSum(x, y)); return output; } @@ -175,6 +177,149 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr return analyzer_->transitive_comparisons.TryCompare(x, y, propagate_inequalities); } +CompareResult RewriteSimplifier::Impl::TryComparisonOfProductAndSum(const PrimExpr& x, + const PrimExpr& y) { + bool check_comparison_of_product_and_sum = enabled_extensions_ & kComparisonOfProductAndSum; + if (!check_comparison_of_product_and_sum) { + return CompareResult::kUnknown; + } + + auto opt_special_case = + [&]() -> std::optional> { + // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on + // previous simplifications, the exact form of the expression may vary. + PVar A, B, C, D; + + // diff is `(A+B)*C - (A*B)*D`. + PrimExpr diff = this->VisitExpr(x - y); + + if (PMatchesOneOf{ + (A + B) * C + (A * B) * D, + (A + B) * C + (B * A) * D, + (A * B) * D + (A + B) * C, + (B * A) * D + (A + B) * C, + } + .Match(diff)) { + return std::tuple{A.Eval(), B.Eval(), C.Eval(), -D.Eval()}; + } else if (PMatchesOneOf{ + (A + B) * C + (A * B), + (A + B) * C + (B * A), + (A * B) + (A + B) * C, + (B * A) + (A + B) * C, + } + .Match(diff)) { + return std::tuple{A.Eval(), B.Eval(), C.Eval(), Integer(-1)}; + } else { + return std::nullopt; + } + }(); + + if (!opt_special_case.has_value()) { + return CompareResult::kUnknown; + } + auto [A, B, C, D] = *opt_special_case; + + auto A_bound = analyzer_->const_int_bound(A); + auto B_bound = analyzer_->const_int_bound(B); + auto C_bound = analyzer_->const_int_bound(C); + auto D_bound = analyzer_->const_int_bound(D); + + auto negate = [](ConstIntBound bound) { + return ConstIntBound(-bound->max_value, -bound->min_value); + }; + auto is_negative = [](const ConstIntBound& bound) { return bound->max_value < 0; }; + auto is_positive = [](const ConstIntBound& bound) { return bound->min_value > 0; }; + + // If D is negative, then we'll be providing an upper bound for + // `(A*B)*D`, rather than a lower bound. To avoid code duplication, + // flip all the signs here, find a lower bound, then flip the sign + // to produce the upper bound of the original expression. + // + // Before: (A+B)*C < (A*B)*D + // After: (A*B)*(-D) < (A + B)*(-C) + bool is_upper_bound = is_negative(D_bound); + if (is_upper_bound) { + C_bound = negate(C_bound); + D_bound = negate(D_bound); + } + + // Before: (A+B)*C < (A*B)*D + // After: ((-A) + (-B))*(-C) < ((-A)*(-B))*D + if (is_negative(C_bound)) { + A_bound = negate(A_bound); + B_bound = negate(B_bound); + C_bound = negate(C_bound); + } + + bool all_terms_positive = (is_positive(A_bound) && is_positive(B_bound) && is_positive(C_bound) && + is_positive(D_bound)); + if (!all_terms_positive) { + return CompareResult::kUnknown; + } + + // (A + B) * C < (A * B) * D + // (A + B) * C / (A*B*C*D) < (A * B) * D / (A*B*C*D) + // 1/(A*D) + 1/(B*D) < 1/C + // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C ) + // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C ) + // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C) + // + // The constant (A*B*C*D) is positive, and its minimum value is the + // product of the minimum values of A, B, C, and D. If the reciprocal + // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can + // be used to provide a lower bound on the expression. + + bool reciprocal_term_is_positive = [&]() { + if (D_bound->max_value == ConstIntBound::kPosInf) { + // If D can grow without bound, the `1/(A*D)` and `1/(B*D)` + // terms will approach zero, at which point the `-1/C` term + // will determine the sign the sign. + return false; + } + + if (std::min(A_bound->max_value, B_bound->max_value) * D_bound->max_value <= + C_bound->min_value) { + // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D). + // Since each term is positive, this condition can hold if either + // A*D <= C or B*D <= C. + return true; + } + if (A_bound->max_value != ConstIntBound::kPosInf && + B_bound->max_value != ConstIntBound::kPosInf) { + // Even if neither term is sufficient on its own, if both A and B + // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D) + // may still be provable. + // + // The maximum value of the LHS is found when C is minimized. The + // minimum value of the RHS is found when A, B, and D are + // maximized. If the condition holds in this case, then it holds + // in all cases. + // + // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max) + // A_max*B_max*D_max < C_min*B_max + C_min*A_max + // A_max*B_max*D_max < C_min*(A_max + B_max) + // + if (A_bound->max_value * B_bound->max_value * D_bound->max_value < + C_bound->min_value * (A_bound->max_value + B_bound->max_value)) { + return true; + } + } + return false; + }(); + + if (!reciprocal_term_is_positive) { + return CompareResult::kUnknown; + } + + if (is_upper_bound) { + // If we flipped the sign of the original expression, flip the sign of + // the resulting set of possible values. + return CompareResult::kLT; + } else { + return CompareResult::kGT; + } +} + // try to prove x equals val CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) { // NOTE on implementation: this function can be called many times and can be a bottleneck, diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 7c4b0eab2224..e488024ec348 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -216,6 +216,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { private: CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y); CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y); + CompareResult TryComparisonOfProductAndSum(const PrimExpr& x, const PrimExpr& y); // Whether x >= val bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index c22e1dcb787c..e9b764c5f402 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -17,6 +17,8 @@ import contextlib +import pytest + import tvm import tvm.testing @@ -96,6 +98,7 @@ class TestAddSubBound(BaseCompare): ) +@pytest.mark.xfail(reason="Not currently supported") class TestBoundsUsingReciprocals(BaseCompare): """Special handling for differences of reciprocals diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 5d2c3aa283cf..8645e5b26a28 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -51,15 +51,17 @@ def __name__(self): class BaseCompare: + extensions = tvm.arith.Extension.NoExtensions + def test_simplify(self, test_case): analyzer = tvm.arith.Analyzer() + analyzer.enabled_extensions = self.extensions if inspect.isclass(test_case.expected) and issubclass(test_case.expected, Exception): with pytest.raises(test_case.expected): with analyzer.constraint_scope(test_case.constraint): analyzer.rewrite_simplify(test_case.before) else: - with analyzer.constraint_scope(test_case.constraint): after = analyzer.rewrite_simplify(test_case.before) @@ -983,6 +985,15 @@ class TestComparisons(BaseCompare): TestCase(y * y >= 0, tvm.tir.const(1, "bool"), y <= 0), TestCase(x * 6 <= -3, tvm.tir.const(0, "bool"), x >= 0), TestCase(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0), + ) + + +class TestComparisonOfProductAndSum(BaseCompare): + extensions = tvm.arith.Extension.ComparisonOfProductAndSum + + x, y, z = te.var("x"), te.var("y"), te.var("z") + + test_case = tvm.testing.parameter( # Special inequality cases TestCase( x * y < (x + y) * 2048,