From 6cd514f60a64d301b46ad5520d1426921bbf93a6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 18 Jan 2024 20:33:42 +0000 Subject: [PATCH 1/4] [Arith] Provide tighter ConstIntBounds for special cases Expressions of the form `(A+B)*C < (A*B)*D` can occur occur when comparing the number of operations required for two different orderings in which matrix multiplications can be performed. Proving or disproving this conditional allows an optimal order of execution to be selected, even for dynamic argument shapes. The default behavior of `ConstIntBounds` assumes that each term in an expression is independent. For example, the maximum value of `(A+B)*C - (A*B)*D` is determined by taking the maximum value of `(A+B)*C` and subtracting the minimum value of `(A*B)*D`. This algorithm can be applied in all cases, but can provide a bound that is looser than strictly required. This commit adds a check for this case in `ConstIntBounds`, to provide a tighter bound of possible values. When `A`, `B`, `C`, and `D` are all positive values, as is the case for tensor shapes, the inequality can be written as `1/A + 1/B < D/C`. If this inequality holds for the minimum values of `A`, `B`, and `D`, along with the maximum value of `C`, then it holds for all values. --- src/arith/const_int_bound.cc | 200 ++++++++++++++++++ src/arith/rewrite_simplify.cc | 11 + .../arith/test_arith_const_int_bound.py | 93 ++++++++ .../arith/test_arith_rewrite_simplify.py | 24 +++ 4 files changed, 328 insertions(+) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8ce502523159..5eed998384e1 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -26,6 +26,7 @@ #include #include +#include #include "constraint_extract.h" #include "int_operator.h" @@ -81,6 +82,16 @@ struct ConstIntBoundAnalyzer::Entry { bool operator==(const Entry& other) const { return min_value == other.min_value && max_value == other.max_value; } + + friend std::ostream& operator<<(std::ostream& os, const Entry& entry) { + os << "Entry["; + PrintBoundValue(os, entry.min_value); + os << ", "; + PrintBoundValue(os, entry.max_value); + os << "]"; + + return os; + } }; class ConstIntBoundAnalyzer::Impl @@ -228,6 +239,11 @@ class ConstIntBoundAnalyzer::Impl Entry ret; ret.min_value = InfAwareAdd(a.min_value, b.min_value); ret.max_value = InfAwareAdd(a.max_value, b.max_value); + + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + return ret; } @@ -237,6 +253,13 @@ class ConstIntBoundAnalyzer::Impl Entry ret; ret.min_value = InfAwareAdd(a.min_value, -b.max_value); ret.max_value = InfAwareAdd(a.max_value, -b.min_value); + + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { + ret = Intersect(ret, Negative(bound.value())); + } return ret; } @@ -628,6 +651,25 @@ class ConstIntBoundAnalyzer::Impl ret.max_value = std::min(a.max_value, b.max_value); return ret; } + /*! + * \brief Flip the sign of a set. + * \param entry The set of values + */ + static Entry Negative(Entry entry) { + Entry ret; + if (entry.max_value == kPosInf) { + ret.min_value = kNegInf; + } else { + ret.min_value = -entry.max_value; + } + if (entry.min_value == kNegInf) { + ret.max_value = kPosInf; + } else { + ret.max_value = -entry.min_value; + } + + return ret; + } /*! * \brief return everything dtype can represent. * \param dtype The data type. @@ -733,6 +775,164 @@ class ConstIntBoundAnalyzer::Impl std::ceil(std::log2(arg_bounds.max_value))); } } + + std::optional BoundUsingReciprocal(PrimExpr expr) { + // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on + // previous simplifications, the exact form of the expression may vary. + auto opt_special_case = [&]() -> std::optional> { + PVar A, B, C, D; + + if (PMatchesOneOf{ + (A + B) * C - (A * B) * D, + (A + B) * C - (B * A) * D, + } + .Match(expr)) { + return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), + VisitExpr(D.Eval())}; + } else if (PMatchesOneOf{ + (A + B) * C - A * B, + (A + B) * C - B * A, + } + .Match(expr)) { + return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()), + MakeBound(1, 1)}; + } else if (PMatchesOneOf{ + (A * B) * D - (A + B) * C, + (B * A) * D - (A + B) * C, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))}; + } else if (PMatchesOneOf{ + A * B - (A + B) * C, + B * A - (A + B) * C, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)}; + } else if (PMatchesOneOf{ + (A * B) * D + (A + B) * C, + (B * A) * D + (A + B) * C, + (A + B) * C + (A * B) * D, + (A + B) * C + (B * A) * D, + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))}; + } else if (PMatchesOneOf{ + (A * B) + (A + B) * C, + (B * A) + (A + B) * C, + (A + B) * C + (A * B), + (A + B) * C + (B * A), + } + .Match(expr)) { + return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())), + VisitExpr(C.Eval()), MakeBound(-1, -1)}; + } else { + return std::nullopt; + } + }(); + + if (!opt_special_case.has_value()) { + return std::nullopt; + } + // Unpacking the tuple would be cleaner with a structured binding. + // However, until C++20, structured bindings cannot be captured for + // use in a lambda function. + auto A_bound = std::get<0>(*opt_special_case); + auto B_bound = std::get<1>(*opt_special_case); + auto C_bound = std::get<2>(*opt_special_case); + auto D_bound = std::get<3>(*opt_special_case); + + // If C and D have different signs, flip the signs of A/B/C so + // that C will match the sign of D. + if ((D_bound.max_value < 0 && C_bound.min_value > 0) || + (D_bound.min_value > 0 && C_bound.max_value < 0)) { + A_bound = Negative(A_bound); + B_bound = Negative(B_bound); + C_bound = Negative(C_bound); + } + + // If all terms are negative, then we'll be providing an upper bound + // rather than a lower bound. To avoid code duplication, flip all the + // signs here, find a lower bound, then flip the sign to produce the + // upper bound of the original expression. + bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 && + C_bound.max_value < 0 && D_bound.max_value < 0); + if (all_terms_negative) { + A_bound = Negative(A_bound); + B_bound = Negative(B_bound); + C_bound = Negative(C_bound); + D_bound = Negative(D_bound); + } + + bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 && + C_bound.min_value > 0 && D_bound.min_value > 0); + if (!all_terms_positive) { + return std::nullopt; + } + + // (A + B) * C - (A * B) * D + // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C ) + // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C ) + // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C) + // + // The constant (A*B*C*D) is positive, and its minimum value is the + // product of the minimum values of A, B, C, and D. If the reciprocal + // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can + // be used to provide a lower bound on the expression. + + bool reciprocal_term_is_positive = [&]() { + if (D_bound.max_value == ConstIntBound::kPosInf) { + // If D can grow without bound, the `1/(A*D)` and `1/(B*D)` + // terms will approach zero, at which point the `-1/C` term + // will determine the sign the sign. + return false; + } + + if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) { + // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D). + // Since each term is positive, this condition can hold if either + // A*D <= C or B*D <= C. + return true; + } + if (A_bound.max_value != ConstIntBound::kPosInf && + B_bound.max_value != ConstIntBound::kPosInf) { + // Even if neither term is sufficient on its own, if both A and B + // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D) + // may still be provable. + // + // The maximum value of the LHS is found when C is minimized. The + // minimum value of the RHS is found when A, B, and D are + // maximized. If the condition holds in this case, then it holds + // in all cases. + // + // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max) + // A_max*B_max*D_max < C_min*B_max + C_min*A_max + // A_max*B_max*D_max < C_min*(A_max + B_max) + // + if (A_bound.max_value * B_bound.max_value * D_bound.max_value < + C_bound.min_value * (A_bound.max_value + B_bound.max_value)) { + return true; + } + } + return false; + }(); + + if (!reciprocal_term_is_positive) { + return std::nullopt; + } + + auto ret = Everything(expr->dtype); + ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value; + + // If we flipped the sign of the original expression, flip the sign of + // the resulting set of possible values. + if (all_terms_negative) { + ret = Negative(ret); + } + return ret; + } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0eaaff5ba838..d063b872e938 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1768,6 +1768,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { if (merge_constants) { return RecursiveRewrite(merge_constants.value()); } + + auto common_factor = [&]() -> int64_t { + auto modular_a = analyzer_->modular_set(ret->a); + auto modular_b = analyzer_->modular_set(ret->b); + auto gcd_lhs = ZeroAwareGCD(modular_a->base, modular_a->coeff); + auto gcd_rhs = ZeroAwareGCD(modular_b->base, modular_b->coeff); + return ZeroAwareGCD(gcd_lhs, gcd_rhs); + }(); + if (common_factor > 1) { + return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor)); + } } return std::move(ret); } diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 5667c79aaced..044e4d9a2151 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -19,6 +19,7 @@ import tvm.testing from tvm import te +from tvm.arith import ConstIntBound def test_dtype_bound(): @@ -94,6 +95,98 @@ def test_add_sub_bound(): assert bd.max_value == bd.POS_INF +def test_lower_bound_using_difference_of_reciprocals(): + """Special handling for differences of reciprocals + + These terms can appear when comparing the number of operations for + different orderings of matrix multiplications, with A, B, and C + known to be positive values. + + In these cases, comparing `(A+B)*C < A*B` is equivalent to + `1/A + 1/B < 1/C`. Working in terms of the reciprocals + allows the ConstIntBound analyzer to provide a tighter + bound for these differences than would otherwise be + available. + + For `(A+B)*C - A*B`, the normal bottom-up integer bounds are unable to + provide the bounds required to provide these inequalities, because they + treat the terms as uncorrelated. That is, they assume that `(A+B)*C` may + achieve its minimum while `A*B` simultaneously achieves its maximum. + + """ + analyzer = tvm.arith.Analyzer() + A, B, C = [te.var(letter, "int64") for letter in "ABC"] + + analyzer.update(A, ConstIntBound(1, 4095)) + analyzer.update(B, ConstIntBound(1, 4095)) + analyzer.update(C, ConstIntBound(2048, 2048)) + + bd = analyzer.const_int_bound((A + B) * C - A * B) + assert bd.min_value == 2048 + + bd = analyzer.const_int_bound((A + B) * C - B * A) + assert bd.min_value == 2048 + + +def test_lower_bound_using_difference_of_reciprocals_with_dominant_term(): + """Like `test_lower_bound_using_difference_of_reciprocal`, with single term known + + If a single term is enough to know the sign of `1/A + 1/B - 1/C`, + then we can still provide a bound. + """ + analyzer = tvm.arith.Analyzer() + A, B, C = [te.var(letter, "int64") for letter in "ABC"] + + analyzer.update(A, ConstIntBound(1, 1024)) + analyzer.update(B, ConstIntBound(1, ConstIntBound.POS_INF)) + analyzer.update(C, ConstIntBound(2048, 2048)) + + bd = analyzer.const_int_bound((A + B) * C - A * B) + assert bd.min_value == 2048 + + bd = analyzer.const_int_bound((B + A) * C - A * B) + assert bd.min_value == 2048 + + +def test_upper_bound_using_difference_of_reciprocals(): + """Upper bound for known negative terms + + Like `test_lower_bound_using_difference_of_reciprocals`, but with the terms + reversed. + """ + analyzer = tvm.arith.Analyzer() + A, B, C = [te.var(letter, "int64") for letter in "ABC"] + + analyzer.update(A, ConstIntBound(1, 4095)) + analyzer.update(B, ConstIntBound(1, 4095)) + analyzer.update(C, ConstIntBound(2048, 2048)) + + bd = analyzer.const_int_bound(A * B - (A + B) * C) + assert bd.max_value == -2048 + bd = analyzer.const_int_bound(B * A - (A + B) * C) + assert bd.max_value == -2048 + + +def test_upper_bound_using_difference_of_reciprocals_with_dominant_term(): + """Upper bound for known negative terms + + Like `test_lower_bound_using_difference_of_reciprocals_with_dominant_term`, + but with the terms reversed. + """ + analyzer = tvm.arith.Analyzer() + A, B, C = [te.var(letter, "int64") for letter in "ABC"] + + analyzer.update(A, ConstIntBound(1, 1024)) + analyzer.update(B, ConstIntBound(1, ConstIntBound.POS_INF)) + analyzer.update(C, ConstIntBound(2048, 2048)) + + bd = analyzer.const_int_bound(A * B - (A + B) * C) + assert bd.max_value == -2048 + + bd = analyzer.const_int_bound(A * B - (B + A) * C) + assert bd.max_value == -2048 + + def test_mul_bound(): analyzer = tvm.arith.Analyzer() x, y = te.var("x"), te.var("y") diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 6433dc2dece9..5d2c3aa283cf 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -983,6 +983,30 @@ class TestComparisons(BaseCompare): TestCase(y * y >= 0, tvm.tir.const(1, "bool"), y <= 0), TestCase(x * 6 <= -3, tvm.tir.const(0, "bool"), x >= 0), TestCase(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0), + # Special inequality cases + TestCase( + x * y < (x + y) * 2048, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 2048], + ), + TestCase( + x * y < (x + y) * 2048, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), + TestCase( + # Both sides are divisible by 8192 + x * y * 8192 < (y + x) * 16777216, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), + TestCase( + # The two sides have co-prime factors, but the bounds are + # still sufficient to prove the inequality. + x * y * 59 < (y + x) * 176128, + tvm.tir.const(1, "bool"), + [x > 0, y > 0, x < 4096, y < 4096], + ), ) From b2aa44ff72de0b53eb7fa9715690266fe74df344 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Feb 2024 21:20:41 -0600 Subject: [PATCH 2/4] Parametrize ConstIntBound tests --- .../arith/test_arith_const_int_bound.py | 549 ++++++------------ 1 file changed, 183 insertions(+), 366 deletions(-) diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 044e4d9a2151..c22e1dcb787c 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -15,87 +15,88 @@ # specific language governing permissions and limitations # under the License. +import contextlib + import tvm import tvm.testing from tvm import te from tvm.arith import ConstIntBound +NEG_INF = ConstIntBound.NEG_INF +POS_INF = ConstIntBound.POS_INF -def test_dtype_bound(): - analyzer = tvm.arith.Analyzer() - x = te.var("x", dtype="int64") - bd = analyzer.const_int_bound(x) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF +class TestCase: + def __init__(self, expr, expected_bounds, known_bounds=None, constraint=None): + self.expr = expr + self.expected_bounds = expected_bounds + if known_bounds is None: + self.known_bounds = {} + else: + self.known_bounds = known_bounds - x = te.var("x", dtype="int8") - bd = analyzer.const_int_bound(x) - assert bd.min_value == -128 - assert bd.max_value == 127 + self.constraint = constraint + + @property + def __name__(self): + return str(self.expr) + + +class BaseCompare: + def test_const_bounds(self, test_case): + analyzer = tvm.arith.Analyzer() + + for var, bounds in test_case.known_bounds.items(): + analyzer.update(var, ConstIntBound(*bounds)) + + with contextlib.ExitStack() as stack: + if test_case.constraint is not None: + stack.enter_context(analyzer.constraint_scope(test_case.constraint)) + + bounds = analyzer.const_int_bound(test_case.expr) + + if test_case.expected_bounds[0] is None: + assert bounds.max_value == test_case.expected_bounds[1] + elif test_case.expected_bounds[1] is None: + assert bounds.min_value == test_case.expected_bounds[0] + else: + assert (bounds.min_value, bounds.max_value) == test_case.expected_bounds - x = te.var("x", dtype="uint8") - bd = analyzer.const_int_bound(x) - assert bd.min_value == 0 - assert bd.max_value == 255 +class TestDataType(BaseCompare): + test_case = tvm.testing.parameter( + TestCase(te.var("x", dtype="int64"), (NEG_INF, POS_INF)), + TestCase(te.var("x", dtype="int8"), (-128, 127)), + TestCase(te.var("x", dtype="uint8"), (0, 255)), + TestCase(te.size_var("x", dtype="int32"), (0, POS_INF)), + ) -def test_cast_bound(): - analyzer = tvm.arith.Analyzer() + +class TestCastBound(BaseCompare): x = te.var("x", dtype="int8") tmod = tvm.tir.truncmod - bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32")) - assert bd.min_value == 0 - assert bd.max_value == 2 - - bd = analyzer.const_int_bound(tmod(x, 3).astype("float32").astype("int32")) - assert bd.min_value == -2 - assert bd.max_value == 2 - - -def test_add_sub_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x", "int64"), te.var("y", "int64") - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(0, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(1, 10)) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == 1 - assert bd.max_value == 14 - - bd = analyzer.const_int_bound(x - y) - assert bd.min_value == -10 - assert bd.max_value == 3 - - analyzer.update(x, tvm.arith.ConstIntBound(0, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x - y) - assert bd.min_value == -10 - assert bd.max_value == bd.POS_INF - - bd = analyzer.const_int_bound(1 - x) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == 1 - - ## constants with negative or positive max(int64) occassionally show up - ## in models, this is to ensure we can handle those cases - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.NEG_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(bd.POS_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - bd = analyzer.const_int_bound(x + y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - -def test_lower_bound_using_difference_of_reciprocals(): + + test_case = tvm.testing.parameter( + TestCase(tmod(x, 3).astype("uint32"), (0, 2)), + TestCase(tmod(x, 3).astype("float32").astype("int32"), (-2, 2)), + ) + + +class TestAddSubBound(BaseCompare): + x = te.var("x", "int64") + y = te.var("y", "int64") + + test_case = tvm.testing.parameter( + TestCase(x + y, (NEG_INF, POS_INF)), + TestCase(x + y, (1, 14), known_bounds={x: (0, 4), y: (1, 10)}), + TestCase(x - y, (-10, 3), known_bounds={x: (0, 4), y: (1, 10)}), + TestCase(x - y, (-10, POS_INF), known_bounds={x: (0, POS_INF), y: (1, 10)}), + TestCase(1 - x, (NEG_INF, 1), known_bounds={x: (0, POS_INF), y: (1, 10)}), + ) + + +class TestBoundsUsingReciprocals(BaseCompare): """Special handling for differences of reciprocals These terms can appear when comparing the number of operations for @@ -112,369 +113,185 @@ def test_lower_bound_using_difference_of_reciprocals(): provide the bounds required to provide these inequalities, because they treat the terms as uncorrelated. That is, they assume that `(A+B)*C` may achieve its minimum while `A*B` simultaneously achieves its maximum. - """ - analyzer = tvm.arith.Analyzer() - A, B, C = [te.var(letter, "int64") for letter in "ABC"] - - analyzer.update(A, ConstIntBound(1, 4095)) - analyzer.update(B, ConstIntBound(1, 4095)) - analyzer.update(C, ConstIntBound(2048, 2048)) - - bd = analyzer.const_int_bound((A + B) * C - A * B) - assert bd.min_value == 2048 - - bd = analyzer.const_int_bound((A + B) * C - B * A) - assert bd.min_value == 2048 - - -def test_lower_bound_using_difference_of_reciprocals_with_dominant_term(): - """Like `test_lower_bound_using_difference_of_reciprocal`, with single term known - - If a single term is enough to know the sign of `1/A + 1/B - 1/C`, - then we can still provide a bound. - """ - analyzer = tvm.arith.Analyzer() - A, B, C = [te.var(letter, "int64") for letter in "ABC"] - - analyzer.update(A, ConstIntBound(1, 1024)) - analyzer.update(B, ConstIntBound(1, ConstIntBound.POS_INF)) - analyzer.update(C, ConstIntBound(2048, 2048)) - - bd = analyzer.const_int_bound((A + B) * C - A * B) - assert bd.min_value == 2048 - - bd = analyzer.const_int_bound((B + A) * C - A * B) - assert bd.min_value == 2048 - -def test_upper_bound_using_difference_of_reciprocals(): - """Upper bound for known negative terms - - Like `test_lower_bound_using_difference_of_reciprocals`, but with the terms - reversed. - """ - analyzer = tvm.arith.Analyzer() A, B, C = [te.var(letter, "int64") for letter in "ABC"] - analyzer.update(A, ConstIntBound(1, 4095)) - analyzer.update(B, ConstIntBound(1, 4095)) - analyzer.update(C, ConstIntBound(2048, 2048)) - - bd = analyzer.const_int_bound(A * B - (A + B) * C) - assert bd.max_value == -2048 - bd = analyzer.const_int_bound(B * A - (A + B) * C) - assert bd.max_value == -2048 + symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)} + asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)} + test_case = tvm.testing.parameter( + TestCase((A + B) * C - A * B, (2048, None), known_bounds=symmetric_bounds), + TestCase((A + B) * C - B * A, (2048, None), known_bounds=symmetric_bounds), + TestCase(A * B - (A + B) * C, (None, -2048), known_bounds=symmetric_bounds), + TestCase(B * A - (A + B) * C, (None, -2048), known_bounds=symmetric_bounds), + TestCase((A + B) * C - A * B, (2048, None), known_bounds=asymmetric_bounds), + TestCase((A + B) * C - B * A, (2048, None), known_bounds=asymmetric_bounds), + TestCase(A * B - (A + B) * C, (None, -2048), known_bounds=asymmetric_bounds), + TestCase(B * A - (A + B) * C, (None, -2048), known_bounds=asymmetric_bounds), + ) -def test_upper_bound_using_difference_of_reciprocals_with_dominant_term(): - """Upper bound for known negative terms - Like `test_lower_bound_using_difference_of_reciprocals_with_dominant_term`, - but with the terms reversed. - """ - analyzer = tvm.arith.Analyzer() - A, B, C = [te.var(letter, "int64") for letter in "ABC"] - - analyzer.update(A, ConstIntBound(1, 1024)) - analyzer.update(B, ConstIntBound(1, ConstIntBound.POS_INF)) - analyzer.update(C, ConstIntBound(2048, 2048)) - - bd = analyzer.const_int_bound(A * B - (A + B) * C) - assert bd.max_value == -2048 - - bd = analyzer.const_int_bound(A * B - (B + A) * C) - assert bd.max_value == -2048 - - -def test_mul_bound(): - analyzer = tvm.arith.Analyzer() +class TestMulBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-2, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(x * y + 20) - assert bd.min_value == 0 - assert bd.max_value == 60 - - analyzer.update(x, tvm.arith.ConstIntBound(-3, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) - bd = analyzer.const_int_bound(x * y) - assert bd.min_value == -32 - assert bd.max_value == 24 + test_case = tvm.testing.parameter( + TestCase(x * y + 20, (0, 60), {x: (-2, 4), y: (4, 10)}), + TestCase(x * y, (-32, 24), {x: (-3, 4), y: (-8, 2)}), + TestCase(x * y, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-8, 2)}), + ) - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True) - bd = analyzer.const_int_bound(x * y) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - -def test_truncdiv_bound(): - analyzer = tvm.arith.Analyzer() +class TestTruncDivBound(BaseCompare): x, y = te.var("x"), te.var("y") - tdiv = tvm.tir.truncdiv - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -2 - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -4 - assert bd.max_value == 9 - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF + expr = tvm.tir.truncdiv(x, y) - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) - bd = analyzer.const_int_bound(tdiv(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 + test_case = tvm.testing.parameter( + TestCase(expr, (-2, None), {x: (-9, 4), y: (4, 10)}), + TestCase(expr, (-4, 9), {x: (-9, 4), y: (-2, 0)}), + TestCase(expr, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-2, 1)}), + TestCase(expr, (-9, 9), {x: (-9, 4), y: (-4, 12)}), + ) -def test_truncmod_bound(): - analyzer = tvm.arith.Analyzer() +class TestTruncModBound(BaseCompare): x, y = te.var("x"), te.var("y") - tmod = tvm.tir.truncmod - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 4 + expr = tvm.tir.truncmod(x, y) - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 + test_case = tvm.testing.parameter( + TestCase(expr, (-9, 4), {x: (-9, 4), y: (4, 10)}), + TestCase(expr, (-9, 9), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(expr, (0, 9), {x: (1, POS_INF), y: (4, 10)}), + ) - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tmod(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 - -def test_floordiv_bound(): - analyzer = tvm.arith.Analyzer() - x, y = te.var("x"), te.var("y") - fld = tvm.te.floordiv - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -9 // 4 - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -4 - assert bd.max_value == 9 - - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == bd.POS_INF - - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(-4, 12), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 9 - - # Test handling unsigned integers well - x, y = te.var("x", dtype="uint32"), te.var("y", dtype="uint32") - analyzer.update(x, tvm.arith.ConstIntBound(1, 4), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(0, 12), override=True) - bd = analyzer.const_int_bound(fld(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 4 - - -def test_floormod_bound(): - analyzer = tvm.arith.Analyzer() +class TestFloorDivBound(BaseCompare): x, y = te.var("x"), te.var("y") - flm = tvm.te.floormod + ux = te.var("x", dtype="uint32") + uy = te.var("y", dtype="uint32") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 + test_case = tvm.testing.parameter( + TestCase(x // y, (-9 // 4, None), {x: (-9, 4), y: (4, 10)}), + TestCase(x // y, (-4, 9), {x: (-9, 4), y: (-2, 0)}), + TestCase(x // y, (NEG_INF, POS_INF), {x: (NEG_INF, 4), y: (-2, 1)}), + TestCase(x // y, (-9, 9), {x: (-9, 4), y: (-4, 12)}), + TestCase(ux // uy, (0, 4), {ux: (1, 4), uy: (0, 12)}), + ) - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(flm(x, y)) - assert bd.min_value == 0 - assert bd.max_value == 9 - - -def test_min_max_bound(): - analyzer = tvm.arith.Analyzer() +class TestFloorModBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - bd = analyzer.const_int_bound(tvm.te.min(x, y)) - assert bd.min_value == -9 - assert bd.max_value == 10 - - analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tvm.te.min(x, y)) - assert bd.min_value == bd.NEG_INF - assert bd.max_value == 10 - - bd = analyzer.const_int_bound(tvm.te.max(x, y)) - assert bd.min_value == 4 - assert bd.max_value == bd.POS_INF + test_case = tvm.testing.parameter( + TestCase(x % y, (0, 9), {x: (-9, 4), y: (4, 10)}), + TestCase(x % y, (0, 9), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(x % y, (0, 9), {x: (1, POS_INF), y: (4, 10)}), + ) - analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) - bd = analyzer.const_int_bound(tvm.te.max(x, y)) - assert bd.min_value == 4 - assert bd.max_value == bd.POS_INF - -def test_select_bound(): - analyzer = tvm.arith.Analyzer() +class TestMinMaxBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) - - bd = analyzer.const_int_bound(tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1)) - assert bd.min_value == 0 - assert bd.max_value == 11 + test_case = tvm.testing.parameter( + TestCase(tvm.te.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}), + TestCase(tvm.te.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.te.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 10)}), + ) -def test_shift_and_bound(): - analyzer = tvm.arith.Analyzer() +class TestSelectBound(BaseCompare): x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(-9, 11)) - analyzer.update(y, tvm.arith.ConstIntBound(2, 10)) + test_case = tvm.testing.parameter( + TestCase( + tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1), + (0, 11), + {x: (-9, 11), y: (4, 10)}, + ), + ) - bd = analyzer.const_int_bound(x >> y) - assert bd.min_value == -3 - assert bd.max_value == 2 - bd = analyzer.const_int_bound(x & y) - assert bd.min_value == 0 - assert bd.max_value == 10 +class TestShiftAndBound(BaseCompare): + x, y = te.var("x"), te.var("y") - analyzer.update(x, tvm.arith.ConstIntBound(10, 11), override=True) - bd = analyzer.const_int_bound(x & y) - assert bd.min_value == 0 - assert bd.max_value == 10 + test_case = tvm.testing.parameter( + TestCase(x >> y, (-3, 2), {x: (-9, 11), y: (2, 10)}), + TestCase(x & y, (0, 10), {x: (-9, 11), y: (2, 10)}), + TestCase(x & y, (0, 10), {x: (10, 11), y: (2, 10)}), + ) -def test_mix_index_bound(): - analyzer = tvm.arith.Analyzer() +class TestMixIndexBound(BaseCompare): x, y = te.var("x"), te.var("y") tdiv = tvm.tir.truncdiv tmod = tvm.tir.truncmod - analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1)) - analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1)) - bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8) - assert bd.min_value == 0 - assert bd.max_value == 24 - 1 - - bd = analyzer.const_int_bound(y + x * 3) - assert bd.min_value == 0 - assert bd.max_value == 24 * 3 - 1 + test_case = tvm.testing.parameter( + TestCase(tmod(x, 8) + tdiv(x, 8) * 8, (0, 24 - 1), {x: (0, 24 - 1), y: (0, 3 - 1)}), + TestCase(y + x * 3, (0, 24 * 3 - 1), {x: (0, 24 - 1), y: (0, 3 - 1)}), + TestCase( + tmod(x, 7) + tdiv(x, 7) * 7, (0, (23 // 7) * 7 + 6), {x: (0, 24 - 1), y: (0, 3 - 1)} + ), + ) - bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7) - assert bd.min_value == 0 - assert bd.max_value == (23 // 7) * 7 + 6 - -def test_size_var_bound(): - analyzer = tvm.arith.Analyzer() - x = te.size_var("x") - bd = analyzer.const_int_bound(x) - assert bd.min_value == 0 - assert bd.max_value == bd.POS_INF - - -def test_let_bound(): - analyzer = tvm.arith.Analyzer() +class TestLetBound(BaseCompare): x = te.var("x") - bd = analyzer.const_int_bound(tvm.tir.Let(x, 1, x + 1)) - assert bd.min_value == 2 - assert bd.max_value == 2 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Let(x, 1, x + 1), (2, 2)), + ) -def test_floormod_negative_divisor(): - analyzer = tvm.arith.Analyzer() +class TestFloorModNegativeDivisor(BaseCompare): flm, fld = tvm.te.floormod, tvm.te.floordiv a, b = te.var("a"), te.var("b") - analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) - analyzer.update(b, tvm.arith.ConstIntBound(-5, 7)) - bd = analyzer.const_int_bound(flm(a, b)) - assert bd.min_value == -4 - assert bd.max_value == 6 + test_case = tvm.testing.parameter( + TestCase(a % b, (-4, 6), {a: (0, 6), b: (-5, 7)}), + ) + + +class TestDivModAssumeNoZeroDivisor(BaseCompare): + """Divmod non negative expression makes assumption that divide by + zero won't occur this assumption is important to get best result + from symbolic shape programs + """ -def test_divmod_assume_no_zero_divsor(): - # Divmod non negative expression makes assumption that divide by zero won't occur - # this assumption is important to get best result from symbolic shape programs - analyzer = tvm.arith.Analyzer() - flm, fld = tvm.te.floormod, tvm.te.floordiv a, b = te.var("a"), te.var("b") - analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) - analyzer.update(b, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF)) - bd = analyzer.const_int_bound(fld(a, b)) - assert bd.min_value == 0 - assert bd.max_value == 6 - bd = analyzer.const_int_bound(flm(a, b)) - assert bd.min_value == 0 - assert bd.max_value == 6 + test_case = tvm.testing.parameter( + TestCase(a // b, (0, 6), {a: (0, 6), b: (0, POS_INF)}), + TestCase(a % b, (0, 6), {a: (0, 6), b: (0, POS_INF)}), + ) -def test_multiple_condition(): - analyzer = tvm.arith.Analyzer() - flm, fld = tvm.te.floormod, tvm.te.floordiv +class TestMultipleCondition(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - with analyzer.constraint_scope(tvm.tir.all(1 <= flm(a, 58), flm(a, 58) < 57)): - bound = analyzer.const_int_bound(flm(a, 58) - 1) - assert bound.min_value == 0 + test_case = tvm.testing.parameter( + TestCase( + a % 58 - 1, + (0, None), + known_bounds={a: (0, 128)}, + constraint=tvm.tir.all(1 <= a % 58, a % 58 < 57), + ), + ) -def test_broadcast_bound(): - analyzer = tvm.arith.Analyzer() +class TestBroadcastBound(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - bound = analyzer.const_int_bound(tvm.tir.Broadcast(a, 4)) - assert bound.min_value == 0 - assert bound.max_value == 128 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Broadcast(a, 4), (0, 128), {a: (0, 128)}), + ) -def test_ramp_bound(): - analyzer = tvm.arith.Analyzer() +class TestRampBound(BaseCompare): a = te.var("a") - analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) - bound = analyzer.const_int_bound(tvm.tir.Ramp(a, 2, 4) + 2) - assert bound.min_value == 2 - assert bound.max_value == 128 + 2 * 3 + 2 + test_case = tvm.testing.parameter( + TestCase(tvm.tir.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0, 128)}), + ) if __name__ == "__main__": From 47a1fbd57f744447fcd032c7debe3cdb314b51e7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 20 Feb 2024 14:32:07 -0600 Subject: [PATCH 3/4] Benchmark with/without the BoundUsingReciprocal function --- src/arith/const_int_bound.cc | 19 +++++++----- .../arith/test_arith_const_int_bound.py | 29 +++++++++++++++-- .../arith/test_arith_rewrite_simplify.py | 31 +++++++++++++++++-- 3 files changed, 67 insertions(+), 12 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 5eed998384e1..d75fc46239df 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -28,6 +28,7 @@ #include #include +#include "../support/utils.h" #include "constraint_extract.h" #include "int_operator.h" #include "pattern_match.h" @@ -240,8 +241,10 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, b.min_value); ret.max_value = InfAwareAdd(a.max_value, b.max_value); - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); + if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) { + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } } return ret; @@ -254,11 +257,13 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, -b.max_value); ret.max_value = InfAwareAdd(a.max_value, -b.min_value); - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); - } - if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { - ret = Intersect(ret, Negative(bound.value())); + if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) { + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { + ret = Intersect(ret, Negative(bound.value())); + } } return ret; } diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index c22e1dcb787c..bb344d528aa1 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -43,8 +43,33 @@ def __name__(self): return str(self.expr) +with_reciprocal_pattern_match = tvm.testing.parameter( + by_dict={ + "with_updated_const_int_analyzer": True, + "without_updated_const_int_analyzer": False, + } +) + +import pytest + + +@pytest.fixture(autouse=True) +def set_reciprocal_pattern_match(with_reciprocal_pattern_match): + import os + + var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH" + old_value = os.environ.get(var_name) + os.environ[var_name] = str(int(with_reciprocal_pattern_match)) + yield + + if old_value is None: + del os.environ[var_name] + else: + os.environ = old_value + + class BaseCompare: - def test_const_bounds(self, test_case): + def test_const_bounds(self, test_case, benchmark): analyzer = tvm.arith.Analyzer() for var, bounds in test_case.known_bounds.items(): @@ -54,7 +79,7 @@ def test_const_bounds(self, test_case): if test_case.constraint is not None: stack.enter_context(analyzer.constraint_scope(test_case.constraint)) - bounds = analyzer.const_int_bound(test_case.expr) + bounds = benchmark(analyzer.const_int_bound, test_case.expr) if test_case.expected_bounds[0] is None: assert bounds.max_value == test_case.expected_bounds[1] diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 5d2c3aa283cf..a8bb0edafa22 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -50,18 +50,43 @@ def __name__(self): return str(self.before) +with_reciprocal_pattern_match = tvm.testing.parameter( + by_dict={ + "with_updated_const_int_analyzer": True, + "without_updated_const_int_analyzer": False, + } +) + + +@pytest.fixture(autouse=True) +def set_reciprocal_pattern_match(with_reciprocal_pattern_match): + import os + + var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH" + old_value = os.environ.get(var_name) + os.environ[var_name] = str(int(with_reciprocal_pattern_match)) + yield + + if old_value is None: + del os.environ[var_name] + else: + os.environ = old_value + + class BaseCompare: - def test_simplify(self, test_case): + def test_simplify(self, test_case, benchmark): analyzer = tvm.arith.Analyzer() if inspect.isclass(test_case.expected) and issubclass(test_case.expected, Exception): with pytest.raises(test_case.expected): with analyzer.constraint_scope(test_case.constraint): - analyzer.rewrite_simplify(test_case.before) + # analyzer.rewrite_simplify(test_case.before) + benchmark(analyzer.rewrite_simplify, test_case.before) else: with analyzer.constraint_scope(test_case.constraint): - after = analyzer.rewrite_simplify(test_case.before) + # after = analyzer.rewrite_simplify(test_case.before) + after = benchmark(analyzer.rewrite_simplify, test_case.before) assert tvm.ir.structural_equal(after, test_case.expected), ( f"Rewrite didn't match expected.\n" From 18778b84f2668115e8a8091bf87de75623128b07 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Mar 2024 17:01:35 -0500 Subject: [PATCH 4/4] Revert "Benchmark with/without the BoundUsingReciprocal function" This reverts commit 47a1fbd57f744447fcd032c7debe3cdb314b51e7. --- src/arith/const_int_bound.cc | 19 +++++------- .../arith/test_arith_const_int_bound.py | 29 ++--------------- .../arith/test_arith_rewrite_simplify.py | 31 ++----------------- 3 files changed, 12 insertions(+), 67 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index d75fc46239df..5eed998384e1 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -28,7 +28,6 @@ #include #include -#include "../support/utils.h" #include "constraint_extract.h" #include "int_operator.h" #include "pattern_match.h" @@ -241,10 +240,8 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, b.min_value); ret.max_value = InfAwareAdd(a.max_value, b.max_value); - if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) { - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); - } + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); } return ret; @@ -257,13 +254,11 @@ class ConstIntBoundAnalyzer::Impl ret.min_value = InfAwareAdd(a.min_value, -b.max_value); ret.max_value = InfAwareAdd(a.max_value, -b.min_value); - if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) { - if (auto bound = BoundUsingReciprocal(GetRef(op))) { - ret = Intersect(ret, bound.value()); - } - if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { - ret = Intersect(ret, Negative(bound.value())); - } + if (auto bound = BoundUsingReciprocal(GetRef(op))) { + ret = Intersect(ret, bound.value()); + } + if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) { + ret = Intersect(ret, Negative(bound.value())); } return ret; } diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index bb344d528aa1..c22e1dcb787c 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -43,33 +43,8 @@ def __name__(self): return str(self.expr) -with_reciprocal_pattern_match = tvm.testing.parameter( - by_dict={ - "with_updated_const_int_analyzer": True, - "without_updated_const_int_analyzer": False, - } -) - -import pytest - - -@pytest.fixture(autouse=True) -def set_reciprocal_pattern_match(with_reciprocal_pattern_match): - import os - - var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH" - old_value = os.environ.get(var_name) - os.environ[var_name] = str(int(with_reciprocal_pattern_match)) - yield - - if old_value is None: - del os.environ[var_name] - else: - os.environ = old_value - - class BaseCompare: - def test_const_bounds(self, test_case, benchmark): + def test_const_bounds(self, test_case): analyzer = tvm.arith.Analyzer() for var, bounds in test_case.known_bounds.items(): @@ -79,7 +54,7 @@ def test_const_bounds(self, test_case, benchmark): if test_case.constraint is not None: stack.enter_context(analyzer.constraint_scope(test_case.constraint)) - bounds = benchmark(analyzer.const_int_bound, test_case.expr) + bounds = analyzer.const_int_bound(test_case.expr) if test_case.expected_bounds[0] is None: assert bounds.max_value == test_case.expected_bounds[1] diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index a8bb0edafa22..5d2c3aa283cf 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -50,43 +50,18 @@ def __name__(self): return str(self.before) -with_reciprocal_pattern_match = tvm.testing.parameter( - by_dict={ - "with_updated_const_int_analyzer": True, - "without_updated_const_int_analyzer": False, - } -) - - -@pytest.fixture(autouse=True) -def set_reciprocal_pattern_match(with_reciprocal_pattern_match): - import os - - var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH" - old_value = os.environ.get(var_name) - os.environ[var_name] = str(int(with_reciprocal_pattern_match)) - yield - - if old_value is None: - del os.environ[var_name] - else: - os.environ = old_value - - class BaseCompare: - def test_simplify(self, test_case, benchmark): + def test_simplify(self, test_case): analyzer = tvm.arith.Analyzer() if inspect.isclass(test_case.expected) and issubclass(test_case.expected, Exception): with pytest.raises(test_case.expected): with analyzer.constraint_scope(test_case.constraint): - # analyzer.rewrite_simplify(test_case.before) - benchmark(analyzer.rewrite_simplify, test_case.before) + analyzer.rewrite_simplify(test_case.before) else: with analyzer.constraint_scope(test_case.constraint): - # after = analyzer.rewrite_simplify(test_case.before) - after = benchmark(analyzer.rewrite_simplify, test_case.before) + after = analyzer.rewrite_simplify(test_case.before) assert tvm.ir.structural_equal(after, test_case.expected), ( f"Rewrite didn't match expected.\n"