Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,19 @@ class RewriteSimplifier {
* (a && b) || c => (a || c) && (b || c)
*/
kConvertBooleanToAndOfOrs = (1 << 1),

/* When simplifying a boolean AND or a boolean OR, simplify each
* branch under the assumption that the other branch does not
* already dominate the result. That is, simplify each branch of
* (A && B) under the assumption that the other branch is true,
* and simplify each branch of (A || B) under the assumption that
* the other branch is false.
*
* Example:
* (n < 10) && (n < 5) => (n < 10)
* (n < 10) || (n < 5) => (n < 5)
*/
kApplyConstraintsToBooleanBranches = (1 << 2),
};

/*! \brief Enable an optional extension or extensions
Expand Down
137 changes: 129 additions & 8 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,40 @@ using namespace tir;
// handled by CanonicalSimplifier.
//

/* Utility for rewriting only boolean portions of an expression
*
* Performs a subset of simplifications done by RewriteSimplifier,
* sufficient to negate a simplified expression. Intended for
* application on an expression that has previously been simplified.
*
* \param expr The boolean expression to be normalized
*
* \returns The normalized boolean expression
*/
PrimExpr NormalizeBooleanOperators(PrimExpr expr) {
PVar<PrimExpr> x, y;

while (true) {
if ((!!x).Match(expr)) {
expr = x.Eval();
} else if ((!(x || y)).Match(expr)) {
return NormalizeBooleanOperators(!x.Eval()) && NormalizeBooleanOperators(!y.Eval());
} else if ((!(x && y)).Match(expr)) {
return NormalizeBooleanOperators(!x.Eval()) || NormalizeBooleanOperators(!y.Eval());
} else if ((x >= y).Match(expr) || (!(x < y)).Match(expr) || (!(y > x)).Match(expr)) {
return y.Eval() <= x.Eval();
} else if ((x > y).Match(expr) || (!(x <= y)).Match(expr) || (!(y >= x)).Match(expr)) {
return y.Eval() < x.Eval();
} else if ((!(x == y)).Match(expr)) {
return x.Eval() != y.Eval();
} else if ((!(x != y)).Match(expr)) {
return x.Eval() == y.Eval();
} else {
return expr;
}
}
}

CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) {
CompareResult output = CompareResult::kUnknown;

Expand Down Expand Up @@ -261,17 +295,17 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
// We could apply this during TryMatchLiteralConstraint, but
// that would require performing a rewrite of each expression
// being checked. This way, we only apply a rewrite for each
// constraint being applied.
PrimExpr negation;
if (subconstraint.dtype().is_bool()) {
negation = Not(subconstraint);
// We could apply NormalizeBooleanOperators during
// TryMatchLiteralConstraint, but that would require
// performing a rewrite of each expression being checked.
// This way, we only apply a rewrite for each constraint being
// applied.
negation = NormalizeBooleanOperators(Not(subconstraint));
} else {
negation = subconstraint == make_zero(subconstraint.dtype());
}
negation = operator()(negation);
literal_constraints_.push_back(Not(negation));
}
}
Expand Down Expand Up @@ -1557,7 +1591,50 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
PrimExpr ret = [&]() -> PrimExpr {
// If this extension isn't enabled, just delegate out.
if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}

PrimExpr a = op->a;
PrimExpr b = op->b;

// Alternate which branch is used as the constraint, and which is
// being simplified. Because some sub-analyzers expect their
// constraints to already be simplified, each branch may require
// more than one update. The loop condition allows each branch to
// be visited up to twice, but only performs the second visit if
// necessary.
size_t iterations_since_update = 0;
for (size_t i = 0; i < 4; i++) {
PrimExpr& to_update = (i % 2 == 0) ? a : b;
const PrimExpr& constraint = (i % 2 == 0) ? b : a;

With<ConstraintContext> context(analyzer_, constraint);
PrimExpr updated = VisitExpr(to_update);

if (!to_update.same_as(updated)) {
to_update = updated;
iterations_since_update = 0;
} else {
iterations_since_update++;
if (iterations_since_update >= 2) {
break;
}
}
}

// Only construct a new object if a change has been made.
// Otherwise, follow ExprMutator's convention of returning the
// original object.
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return And(a, b);
}
}();

op = ret.as<AndNode>();

if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
Expand Down Expand Up @@ -1601,7 +1678,51 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
PrimExpr orig = GetRef<PrimExpr>(op);

PrimExpr ret = [&]() -> PrimExpr {
// If this extension isn't enabled, just delegate out.
if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}

PrimExpr a = op->a;
PrimExpr b = op->b;

// Alternate which branch is used as the constraint, and which
// is being simplified. Because some sub-analyzers expect their
// constraints to already be simplified, each branch may require
// more than update. The loop condition allows each branch to be
// visited up to twice, but only if performs the second visit if
// necessary.
size_t iterations_since_update = 0;
for (size_t i = 0; i < 4; i++) {
PrimExpr& to_update = (i % 2 == 0) ? a : b;
const PrimExpr& constraint = (i % 2 == 0) ? b : a;

With<ConstraintContext> context(analyzer_, NormalizeBooleanOperators(Not(constraint)));
PrimExpr updated = VisitExpr(to_update);

if (!to_update.same_as(updated)) {
to_update = updated;
iterations_since_update = 0;
} else {
iterations_since_update++;
if (iterations_since_update >= 2) {
break;
}
}
}

// Only construct a new object if a change has been made.
// Otherwise, follow ExprMutator's convention of returning the
// original object.
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return Or(a, b);
}
}();

op = ret.as<OrNode>();
if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();
Expand Down
11 changes: 11 additions & 0 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ using namespace tir;
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
bool transitively_prove_inequalities;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;

TVM_DECLARE_ATTRS(SimplifyConfigNode, "tir.transform.SimplifyConfig") {
TVM_ATTR_FIELD(transitively_prove_inequalities)
Expand All @@ -49,6 +50,12 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
.describe("If true, simplify conditionals into an AND of ORs")
.set_default(false);

TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
.describe(
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.set_default(false);
}

RewriteSimplifier::Extension GetEnabledExtensions() const {
Expand All @@ -60,6 +67,10 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
if (convert_boolean_to_and_of_ors) {
flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
}
if (apply_constraints_to_boolean_branches) {
flags = RewriteSimplifier::Extension(flags |
RewriteSimplifier::kApplyConstraintsToBooleanBranches);
}
return flags;
}
};
Expand Down
144 changes: 144 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,15 @@ def sls(n, d):
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
transitively_prove_inequalities = False
convert_boolean_to_and_of_ors = False
apply_constraints_to_boolean_branches = False

def transform(self):
def inner(mod):
config = {
"tir.Simplify": {
"transitively_prove_inequalities": self.transitively_prove_inequalities,
"convert_boolean_to_and_of_ors": self.convert_boolean_to_and_of_ors,
"apply_constraints_to_boolean_branches": self.apply_constraints_to_boolean_branches,
}
}
with tvm.transform.PassContext(config=config):
Expand Down Expand Up @@ -845,5 +847,147 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32):
A[0] = True


class TestSimplifyRHSOfBooleanAndUsingLHS(BaseBeforeAfter):
"""Boolean expressions can introduce contexts.

In `A and B`, the result of `B` only matters when `A` is
true, and can be simplified under that context. This test
simplifies `n < 10` under the assumption that `n < 5`.
"""

apply_constraints_to_boolean_branches = True

def before(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 5 and n < 10

def expected(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 5


class TestSimplifyLHSOfBooleanAndUsingRHS(BaseBeforeAfter):
"""Boolean expressions can introduce contexts for their arguments.

Like TestSimplifyRHSOfBooleanAndUsingLHS, but using the RHS to
simplify the LHS.
"""

apply_constraints_to_boolean_branches = True

def before(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 10 and n < 5

def expected(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 5


class TestSimplifyRHSOfBooleanOrUsingLHS(BaseBeforeAfter):
"""Boolean expressions can introduce contexts.

In `A or B`, the result of `B` only matters when `A` is false, so
`B` can be simplified under the assumption that `A` is false.
This test simplifies `n < 5` under the assumption that `!(n < 10)`
"""

apply_constraints_to_boolean_branches = True

def before(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 10 or n < 5

def expected(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 10


class TestSimplifyLHSOfBooleanOrUsingRHS(BaseBeforeAfter):
"""Boolean expressions can introduce contexts for their arguments.

Like TestSimplifyRHSOfBooleanOrUsingLHS, but using the RHS to
simplify the LHS.
"""

apply_constraints_to_boolean_branches = True

def before(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 5 or n < 10

def expected(A: T.Buffer[1, "bool"], n: T.int32):
A[0] = n < 10


class TestSimplifyRHSOfBooleanAndUsingLHSWithoutConst(BaseBeforeAfter):
"""Boolean expressions can introduce contexts.

Like TestSimplifyRHSOfBooleanAndUsingLHS, but with variables in
the conditions, preventing ConstIntBoundAnalyzer from handling it.
This proof requires the extension to transitively prove
inequalities.
"""

apply_constraints_to_boolean_branches = True
transitively_prove_inequalities = True

def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 5 and n < m + 10

def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 5


class TestSimplifyLHSOfBooleanAndUsingRHSWithoutConst(BaseBeforeAfter):
"""Boolean expressions can introduce contexts for their arguments.

Like TestSimplifyLHSOfBooleanAndUsingRHS, but with variables in
the conditions, preventing ConstIntBoundAnalyzer from handling it.
This proof requires the extension to transitively prove
inequalities.
"""

apply_constraints_to_boolean_branches = True
transitively_prove_inequalities = True

def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 10 and n < m + 5

def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 5


class TestSimplifyRHSOfBooleanOrUsingLHSWithoutConst(BaseBeforeAfter):
"""Boolean expressions can introduce contexts.

Like TestSimplifyRHSOfBooleanOrUsingLHS, but with variables in the
conditions, preventing ConstIntBoundAnalyzer from handling it.
This proof requires the extension to transitively prove
inequalities.
"""

apply_constraints_to_boolean_branches = True
transitively_prove_inequalities = True

def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 10 or n < m + 5

def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 10


class TestSimplifyLHSOfBooleanOrUsingRHSWithoutConst(BaseBeforeAfter):
"""Boolean expressions can introduce contexts for their arguments.

Like TestSimplifyLHSOfBooleanOrUsingRHS, but with variables in the
conditions, preventing ConstIntBoundAnalyzer from handling it.
This proof requires the extension to transitively prove
inequalities.
"""

apply_constraints_to_boolean_branches = True
transitively_prove_inequalities = True

def before(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 5 or n < m + 10

def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 10


if __name__ == "__main__":
tvm.testing.main()