diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index b80d75a17058..e2d60684da7b 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -305,6 +305,19 @@ class RewriteSimplifier { * (a && b) || c => (a || c) && (b || c) */ kConvertBooleanToAndOfOrs = (1 << 1), + + /* When simplifying a boolean AND or a boolean OR, simplify each + * branch under the assumption that the other branch does not + * already dominate the result. That is, simplify each branch of + * (A && B) under the assumption that the other branch is true, + * and simplify each branch of (A || B) under the assumption that + * the other branch is false. + * + * Example: + * (n < 10) && (n < 5) => (n < 10) + * (n < 10) || (n < 5) => (n < 5) + */ + kApplyConstraintsToBooleanBranches = (1 << 2), }; /*! \brief Enable an optional extension or extensions diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 5e565d7e36c6..6cc2aa9e4591 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -72,6 +72,40 @@ using namespace tir; // handled by CanonicalSimplifier. // +/* Utility for rewriting only boolean portions of an expression + * + * Performs a subset of simplifications done by RewriteSimplifier, + * sufficient to negate a simplified expression. Intended for + * application on an expression that has previously been simplified. + * + * \param expr The boolean expression to be normalized + * + * \returns The normalized boolean expression + */ +PrimExpr NormalizeBooleanOperators(PrimExpr expr) { + PVar x, y; + + while (true) { + if ((!!x).Match(expr)) { + expr = x.Eval(); + } else if ((!(x || y)).Match(expr)) { + return NormalizeBooleanOperators(!x.Eval()) && NormalizeBooleanOperators(!y.Eval()); + } else if ((!(x && y)).Match(expr)) { + return NormalizeBooleanOperators(!x.Eval()) || NormalizeBooleanOperators(!y.Eval()); + } else if ((x >= y).Match(expr) || (!(x < y)).Match(expr) || (!(y > x)).Match(expr)) { + return y.Eval() <= x.Eval(); + } else if ((x > y).Match(expr) || (!(x <= y)).Match(expr) || (!(y >= x)).Match(expr)) { + return y.Eval() < x.Eval(); + } else if ((!(x == y)).Match(expr)) { + return x.Eval() != y.Eval(); + } else if ((!(x != y)).Match(expr)) { + return x.Eval() == y.Eval(); + } else { + return expr; + } + } +} + CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) { CompareResult output = CompareResult::kUnknown; @@ -261,17 +295,17 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) { if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); - // We could apply this during TryMatchLiteralConstraint, but - // that would require performing a rewrite of each expression - // being checked. This way, we only apply a rewrite for each - // constraint being applied. PrimExpr negation; if (subconstraint.dtype().is_bool()) { - negation = Not(subconstraint); + // We could apply NormalizeBooleanOperators during + // TryMatchLiteralConstraint, but that would require + // performing a rewrite of each expression being checked. + // This way, we only apply a rewrite for each constraint being + // applied. + negation = NormalizeBooleanOperators(Not(subconstraint)); } else { negation = subconstraint == make_zero(subconstraint.dtype()); } - negation = operator()(negation); literal_constraints_.push_back(Not(negation)); } } @@ -1557,7 +1591,50 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { - PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = [&]() -> PrimExpr { + // If this extension isn't enabled, just delegate out. + if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + PrimExpr a = op->a; + PrimExpr b = op->b; + + // Alternate which branch is used as the constraint, and which is + // being simplified. Because some sub-analyzers expect their + // constraints to already be simplified, each branch may require + // more than one update. The loop condition allows each branch to + // be visited up to twice, but only performs the second visit if + // necessary. + size_t iterations_since_update = 0; + for (size_t i = 0; i < 4; i++) { + PrimExpr& to_update = (i % 2 == 0) ? a : b; + const PrimExpr& constraint = (i % 2 == 0) ? b : a; + + With context(analyzer_, constraint); + PrimExpr updated = VisitExpr(to_update); + + if (!to_update.same_as(updated)) { + to_update = updated; + iterations_since_update = 0; + } else { + iterations_since_update++; + if (iterations_since_update >= 2) { + break; + } + } + } + + // Only construct a new object if a change has been made. + // Otherwise, follow ExprMutator's convention of returning the + // original object. + if (a.same_as(op->a) && b.same_as(op->b)) { + return GetRef(op); + } else { + return And(a, b); + } + }(); + op = ret.as(); if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); @@ -1601,7 +1678,51 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { - PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr orig = GetRef(op); + + PrimExpr ret = [&]() -> PrimExpr { + // If this extension isn't enabled, just delegate out. + if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + PrimExpr a = op->a; + PrimExpr b = op->b; + + // Alternate which branch is used as the constraint, and which + // is being simplified. Because some sub-analyzers expect their + // constraints to already be simplified, each branch may require + // more than update. The loop condition allows each branch to be + // visited up to twice, but only if performs the second visit if + // necessary. + size_t iterations_since_update = 0; + for (size_t i = 0; i < 4; i++) { + PrimExpr& to_update = (i % 2 == 0) ? a : b; + const PrimExpr& constraint = (i % 2 == 0) ? b : a; + + With context(analyzer_, NormalizeBooleanOperators(Not(constraint))); + PrimExpr updated = VisitExpr(to_update); + + if (!to_update.same_as(updated)) { + to_update = updated; + iterations_since_update = 0; + } else { + iterations_since_update++; + if (iterations_since_update >= 2) { + break; + } + } + } + + // Only construct a new object if a change has been made. + // Otherwise, follow ExprMutator's convention of returning the + // original object. + if (a.same_as(op->a) && b.same_as(op->b)) { + return GetRef(op); + } else { + return Or(a, b); + } + }(); op = ret.as(); if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 894dfb8ca09f..b6e3581aa614 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -39,6 +39,7 @@ using namespace tir; struct SimplifyConfigNode : public tvm::AttrsNode { bool transitively_prove_inequalities; bool convert_boolean_to_and_of_ors; + bool apply_constraints_to_boolean_branches; TVM_DECLARE_ATTRS(SimplifyConfigNode, "tir.transform.SimplifyConfig") { TVM_ATTR_FIELD(transitively_prove_inequalities) @@ -49,6 +50,12 @@ struct SimplifyConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(convert_boolean_to_and_of_ors) .describe("If true, simplify conditionals into an AND of ORs") .set_default(false); + + TVM_ATTR_FIELD(apply_constraints_to_boolean_branches) + .describe( + "If true, simplify each branch of AND/OR " + "under a constraints provided by the other branch") + .set_default(false); } RewriteSimplifier::Extension GetEnabledExtensions() const { @@ -60,6 +67,10 @@ struct SimplifyConfigNode : public tvm::AttrsNode { if (convert_boolean_to_and_of_ors) { flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); } + if (apply_constraints_to_boolean_branches) { + flags = RewriteSimplifier::Extension(flags | + RewriteSimplifier::kApplyConstraintsToBooleanBranches); + } return flags; } }; diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 46b6858ec773..91ef60f9d3f1 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -139,6 +139,7 @@ def sls(n, d): class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): transitively_prove_inequalities = False convert_boolean_to_and_of_ors = False + apply_constraints_to_boolean_branches = False def transform(self): def inner(mod): @@ -146,6 +147,7 @@ def inner(mod): "tir.Simplify": { "transitively_prove_inequalities": self.transitively_prove_inequalities, "convert_boolean_to_and_of_ors": self.convert_boolean_to_and_of_ors, + "apply_constraints_to_boolean_branches": self.apply_constraints_to_boolean_branches, } } with tvm.transform.PassContext(config=config): @@ -845,5 +847,147 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32): A[0] = True +class TestSimplifyRHSOfBooleanAndUsingLHS(BaseBeforeAfter): + """Boolean expressions can introduce contexts. + + In `A and B`, the result of `B` only matters when `A` is + true, and can be simplified under that context. This test + simplifies `n < 10` under the assumption that `n < 5`. + """ + + apply_constraints_to_boolean_branches = True + + def before(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 5 and n < 10 + + def expected(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 5 + + +class TestSimplifyLHSOfBooleanAndUsingRHS(BaseBeforeAfter): + """Boolean expressions can introduce contexts for their arguments. + + Like TestSimplifyRHSOfBooleanAndUsingLHS, but using the RHS to + simplify the LHS. + """ + + apply_constraints_to_boolean_branches = True + + def before(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 10 and n < 5 + + def expected(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 5 + + +class TestSimplifyRHSOfBooleanOrUsingLHS(BaseBeforeAfter): + """Boolean expressions can introduce contexts. + + In `A or B`, the result of `B` only matters when `A` is false, so + `B` can be simplified under the assumption that `A` is false. + This test simplifies `n < 5` under the assumption that `!(n < 10)` + """ + + apply_constraints_to_boolean_branches = True + + def before(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 10 or n < 5 + + def expected(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 10 + + +class TestSimplifyLHSOfBooleanOrUsingRHS(BaseBeforeAfter): + """Boolean expressions can introduce contexts for their arguments. + + Like TestSimplifyRHSOfBooleanOrUsingLHS, but using the RHS to + simplify the LHS. + """ + + apply_constraints_to_boolean_branches = True + + def before(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 5 or n < 10 + + def expected(A: T.Buffer[1, "bool"], n: T.int32): + A[0] = n < 10 + + +class TestSimplifyRHSOfBooleanAndUsingLHSWithoutConst(BaseBeforeAfter): + """Boolean expressions can introduce contexts. + + Like TestSimplifyRHSOfBooleanAndUsingLHS, but with variables in + the conditions, preventing ConstIntBoundAnalyzer from handling it. + This proof requires the extension to transitively prove + inequalities. + """ + + apply_constraints_to_boolean_branches = True + transitively_prove_inequalities = True + + def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 5 and n < m + 10 + + def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 5 + + +class TestSimplifyLHSOfBooleanAndUsingRHSWithoutConst(BaseBeforeAfter): + """Boolean expressions can introduce contexts for their arguments. + + Like TestSimplifyLHSOfBooleanAndUsingRHS, but with variables in + the conditions, preventing ConstIntBoundAnalyzer from handling it. + This proof requires the extension to transitively prove + inequalities. + """ + + apply_constraints_to_boolean_branches = True + transitively_prove_inequalities = True + + def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 10 and n < m + 5 + + def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 5 + + +class TestSimplifyRHSOfBooleanOrUsingLHSWithoutConst(BaseBeforeAfter): + """Boolean expressions can introduce contexts. + + Like TestSimplifyRHSOfBooleanOrUsingLHS, but with variables in the + conditions, preventing ConstIntBoundAnalyzer from handling it. + This proof requires the extension to transitively prove + inequalities. + """ + + apply_constraints_to_boolean_branches = True + transitively_prove_inequalities = True + + def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 10 or n < m + 5 + + def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 10 + + +class TestSimplifyLHSOfBooleanOrUsingRHSWithoutConst(BaseBeforeAfter): + """Boolean expressions can introduce contexts for their arguments. + + Like TestSimplifyLHSOfBooleanOrUsingRHS, but with variables in the + conditions, preventing ConstIntBoundAnalyzer from handling it. + This proof requires the extension to transitively prove + inequalities. + """ + + apply_constraints_to_boolean_branches = True + transitively_prove_inequalities = True + + def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 5 or n < m + 10 + + def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32): + A[0] = n < m + 10 + + if __name__ == "__main__": tvm.testing.main()