diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 79b01d0cc859..b80d75a17058 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -297,6 +297,14 @@ class RewriteSimplifier { * if_then_else(i if_then_else(i (a || c) && (b || c) + */ + kConvertBooleanToAndOfOrs = (1 << 1), }; /*! \brief Enable an optional extension or extensions diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc new file mode 100644 index 000000000000..19d6a234e6ad --- /dev/null +++ b/src/arith/conjunctive_normal_form.cc @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/conjunctive_normal_form.cc + */ + +#include "conjunctive_normal_form.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "pattern_match.h" +#include "rewrite_simplify.h" + +namespace tvm { +namespace arith { + +namespace { +/* \brief A utility for simplifying expressions using conjunctive/disjuctive normal forms */ +class AndOfOrs { + public: + /*! \brief Construct the simplifier + * + * Convert a PrimExpr to the internal representation. + * + * \param expr The PrimExpr to be simplified. + */ + explicit AndOfOrs(const PrimExpr& expr); + + /*! \brief Convert internal representation to PrimExpr */ + PrimExpr AsPrimExpr() const; + + /*! \brief Simplify the internal representation */ + void Simplify(Analyzer* analyzer); + + private: + /*! \brief Internal utility, simplify within each group of expressions + * + * For each pair of values within a chunk, attempt to simplify them into + * a single expression. + * + * For example, + * before = (a == 5) && ((b < 10) || (b > 10)) + * after = (a == 5) && ((b != 10) || false) + */ + void SimplifyWithinChunks(Analyzer* analyzer); + + /*! \brief Internal utility, simplify across groups of expressions + * + * For each pair of chunks, if the two chunks differ by only a single + * term, attempt to simplify those differing terms. + * + * For example, + * before = ((a == 5) || (b <= 10)) && ((a == 5) || (b >= 10)) + * after = ((a == 5) || (b == 10)) && ((a == 5) || true) + */ + void SimplifyAcrossChunks(Analyzer* analyzer); + + /*! \brief Remove instances of true/false from internal representation + * + * To avoid invalidating iterators, `SimplifyWithinChunks` and + * `SimplifyAcrossChunks` may replace keys, but may not remove keys + * from the internal representation. For example, `(a < 5) && (a < + * 10)` would be simplified to `(a < 5) && true`. The + * `RemoveTrueFalse` function removes these leftover instances of + * true/false. + */ + void RemoveTrueFalse(); + + /*! \brief Internal utility function used to convert to internal form */ + static void VisitAndExpressions(const PrimExpr& expr, + std::function callback); + /*! \brief Internal utility function used to convert to internal form */ + static void VisitOrExpressions(const PrimExpr& expr, + std::function callback); + + /* \brief Type-safe wrapper class that represents an PrimExpr + * + * Because integer indices are used frequently through this class, + * maintaining a separation between integer indices used to access + * specific elements of the internal representation, and unique + * identifiers used to represent expressions PrimExpr, is useful. + */ + enum class Key : size_t {}; + + /*! \brief Convert a PrimExpr to a Key */ + Key GetKey(const PrimExpr& expr); + + /*! \brief Convert a Key to a PrimExpr */ + PrimExpr GetExpr(Key key) const; + + /*! \brief Attempt to simplify (a && b) + * + * If successful, will overwrite the parameters `a` and `b` with the + * simplified form. + */ + void TrySimplifyOr(Key* a, Key* b, Analyzer* analyzer); + + /*! \brief Attempt to simplify (a || b) + * + * If successful, will overwrite the parameters `a` and `b` with the + * simplified form. + */ + void TrySimplifyAnd(Key* a, Key* b, Analyzer* analyzer); + + /*! \brief The internal representation + * + * `chunks[i][j]` is the j-th expression in the i-th OR-group. + */ + std::vector> chunks_; + + /*! \brief Mapping from internal Key to PrimExpr */ + std::unordered_map key_to_expr_; + + /*! \brief Mapping from PrimExpr to internal Key */ + std::unordered_map expr_to_key_; + + /*! \brief Cached key representing tir::Bool(true) */ + Key key_true_; + + /*! \brief Cached key representing tir::Bool(false) */ + Key key_false_; +}; + +AndOfOrs::AndOfOrs(const PrimExpr& expr) + : key_true_(GetKey(Bool(true))), key_false_(GetKey(Bool(false))) { + VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) { + std::vector or_components; + VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) { + Key key = GetKey(inner_expr); + bool is_duplicate = std::any_of(or_components.begin(), or_components.end(), + [&](Key prev) { return prev == key; }); + if (!is_duplicate) { + or_components.push_back(key); + } + }); + + bool is_permutation = + std::any_of(chunks_.begin(), chunks_.end(), [&](const std::vector& prev_components) { + return or_components.size() == prev_components.size() && + std::is_permutation(prev_components.begin(), prev_components.end(), + or_components.begin()); + }); + if (!is_permutation) { + chunks_.push_back(std::move(or_components)); + } + }); +} + +void AndOfOrs::VisitAndExpressions(const PrimExpr& expr, + std::function callback) { + PVar x, y, z; + if ((x && y).Match(expr)) { + // These are separate AND conditions, recurse into them in case + // they contain AND internally. + VisitAndExpressions(x.Eval(), callback); + VisitAndExpressions(y.Eval(), callback); + } else if ((x || y).Match(expr)) { + // This may be the bottom-most breakdown, but either x or y may + // themselves contain AND. (e.g. (A && B) || (C && D) should be + // split into (A || C), (A || D), (B || C), and (B || D).) + // Recurse into each, then reconstruct an OR condition. + VisitAndExpressions(x.Eval(), [&](const PrimExpr& x_part) { + VisitAndExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part || y_part); }); + }); + } else { + // This is bottom-most breakdown. + callback(expr); + } +} + +void AndOfOrs::VisitOrExpressions(const PrimExpr& expr, + std::function callback) { + PVar x, y, z; + if ((x || y).Match(expr)) { + // These are separate OR conditions, recurse into them in case + // they contain OR internally. + VisitOrExpressions(x.Eval(), callback); + VisitOrExpressions(y.Eval(), callback); + } else if ((x && y).Match(expr)) { + // This may be the bottom-most breakdown, but either x or y may + // themselves contain OR. (e.g. (A || B) && (C || D) should be + // split into (A && C), (A && D), (B && C), and (B && D).) + // Recurse into each, then reconstruct an AND condition. + VisitOrExpressions(x.Eval(), [&](const PrimExpr& x_part) { + VisitOrExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part && y_part); }); + }); + } else { + // This is bottom-most breakdown. + callback(expr); + } +} + +AndOfOrs::Key AndOfOrs::GetKey(const PrimExpr& expr) { + auto it = expr_to_key_.find(expr); + if (it != expr_to_key_.end()) { + return it->second; + } + + Key key{expr_to_key_.size()}; + expr_to_key_[expr] = key; + key_to_expr_[key] = expr; + return key; +} + +PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { + auto it = key_to_expr_.find(key); + ICHECK(it != key_to_expr_.end()); + return it->second; +} + +PrimExpr AndOfOrs::AsPrimExpr() const { + PrimExpr expr = Bool(true); + for (const auto& chunk : chunks_) { + PrimExpr chunk_expr = Bool(false); + for (Key j : chunk) { + chunk_expr = chunk_expr || GetExpr(j); + } + expr = expr && chunk_expr; + } + return expr; +} + +void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { + Key& a = *a_ptr; + Key& b = *b_ptr; + PrimExpr joint = GetExpr(a) || GetExpr(b); + PrimExpr simplified = analyzer->Simplify(joint); + if (!ExprDeepEqual()(simplified, joint)) { + if (auto* simplified_or = simplified.as()) { + a = GetKey(simplified_or->a); + b = GetKey(simplified_or->b); + } else { + a = GetKey(simplified); + b = key_false_; + } + } +} + +void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { + Key& a = *a_ptr; + Key& b = *b_ptr; + PrimExpr joint = GetExpr(a) && GetExpr(b); + PrimExpr simplified = analyzer->Simplify(joint); + if (!ExprDeepEqual()(simplified, joint)) { + if (auto* simplified_and = simplified.as()) { + a = GetKey(simplified_and->a); + b = GetKey(simplified_and->b); + } else { + a = GetKey(simplified); + b = key_true_; + } + } +} + +void AndOfOrs::Simplify(Analyzer* analyzer) { + SimplifyWithinChunks(analyzer); + RemoveTrueFalse(); + SimplifyAcrossChunks(analyzer); + RemoveTrueFalse(); +} + +void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) { + for (auto& chunk : chunks_) { + for (size_t expr_i = 0; expr_i < chunk.size(); expr_i++) { + for (size_t expr_j = expr_i + 1; expr_j < chunk.size(); expr_j++) { + Key& key_i = chunk[expr_i]; + Key& key_j = chunk[expr_j]; + + TrySimplifyOr(&key_i, &key_j, analyzer); + } + } + } +} + +void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) { + for (size_t i_and = 0; i_and < chunks_.size(); i_and++) { + for (size_t j_and = i_and + 1; j_and < chunks_.size(); j_and++) { + auto& i_chunk = chunks_[i_and]; + auto& j_chunk = chunks_[j_and]; + + if (i_chunk.size() == 1 && j_chunk.size() == 1) { + auto& key_i = i_chunk[0]; + auto& key_j = j_chunk[0]; + TrySimplifyAnd(&key_i, &key_j, analyzer); + continue; + } + std::unordered_set j_set(j_chunk.begin(), j_chunk.end()); + + std::optional i_distinct_index; + for (size_t i = 0; i < i_chunk.size(); i++) { + if (!j_set.count(i_chunk[i])) { + i_distinct_index = i; + break; + } + } + + if (!i_distinct_index.has_value()) { + // I = (i_0 || i_1 || ... || i_N) + // J = (i_0 || i_1 || ... || i_N || j_0 || ... || j_N) + // I && J == I == I && true + + j_chunk = {key_true_}; + continue; + } + + std::unordered_set i_set(i_chunk.begin(), i_chunk.end()); + + std::optional j_distinct_index; + for (size_t j = 0; j < j_chunk.size(); j++) { + if (!i_set.count(j_chunk[j])) { + j_distinct_index = j; + break; + } + } + + if (!j_distinct_index.has_value()) { + // I = (i_0 || ... || i_N || j_0 || ... || j_N) + // J = (j_0 || ... || j_N) + // I && J == J == true && J + + i_chunk = {key_true_}; + continue; + } + + if (i_chunk.size() == j_chunk.size()) { + size_t num_shared_exprs = 0; + for (const auto& j_key : j_chunk) { + if (i_set.count(j_key)) { + ++num_shared_exprs; + } + } + + if (num_shared_exprs + 1 == i_chunk.size()) { + // All but one of the expressions are shared. If the AND + // of the distinct expressions can be simplified, we can + // replace. + // + // (A or B) and (A or C) => A or (B and C) + auto& key_i = i_chunk[i_distinct_index.value()]; + auto& key_j = j_chunk[j_distinct_index.value()]; + TrySimplifyAnd(&key_i, &key_j, analyzer); + } + } + } + } +} + +void AndOfOrs::RemoveTrueFalse() { + for (auto& chunk : chunks_) { + // Any occurrence of True inside an OR makes the entire expression True. + if (std::any_of(chunk.begin(), chunk.end(), [&](Key key) { return key == key_true_; })) { + chunk = {key_true_}; + } else { + // Any occurrence of False inside an OR can be removed + chunk.erase( + std::remove_if(chunk.begin(), chunk.end(), [&](Key key) { return key == key_false_; }), + chunk.end()); + } + } + + // Any occurence of False inside an AND makes the entire expression False. + if (std::any_of(chunks_.begin(), chunks_.end(), + [&](const std::vector& chunk) { return chunk.size() == 0; })) { + chunks_ = {{}}; + } else { + // Any occurrence of True inside an AND can be removed. + chunks_.erase(std::remove_if(chunks_.begin(), chunks_.end(), + [&](const std::vector& chunk) { + return chunk.size() == 1 && chunk[0] == key_true_; + }), + chunks_.end()); + } +} + +// Helper utility for temporarily disabling the +// kConvertBooleanToAndOfOrs flag on an analyzer, to prevent infinite +// recursion. +class DisableAndOfOrRecursion { + public: + explicit DisableAndOfOrRecursion(Analyzer* analyzer) + : analyzer_(analyzer), cached_flags_(analyzer->rewrite_simplify.GetEnabledExtensions()) { + auto new_flags = static_cast( + cached_flags_ & (~RewriteSimplifier::kConvertBooleanToAndOfOrs)); + analyzer->rewrite_simplify.SetEnabledExtensions(new_flags); + } + ~DisableAndOfOrRecursion() { analyzer_->rewrite_simplify.SetEnabledExtensions(cached_flags_); } + + DisableAndOfOrRecursion(const DisableAndOfOrRecursion&) = delete; + DisableAndOfOrRecursion& operator=(const DisableAndOfOrRecursion&) = delete; + + private: + Analyzer* analyzer_; + RewriteSimplifier::Extension cached_flags_; +}; + +} // namespace + +PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer) { + DisableAndOfOrRecursion context(analyzer); + AndOfOrs repr(analyzer->Simplify(expr)); + repr.Simplify(analyzer); + return repr.AsPrimExpr(); +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/conjunctive_normal_form.h b/src/arith/conjunctive_normal_form.h new file mode 100644 index 000000000000..84ee972d030e --- /dev/null +++ b/src/arith/conjunctive_normal_form.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file conjunctive_normal_form.h + * + * \brief Centralized location for simplifying into specific forms + */ + +#ifndef TVM_ARITH_CONJUNCTIVE_NORMAL_FORM_H_ +#define TVM_ARITH_CONJUNCTIVE_NORMAL_FORM_H_ + +#include + +namespace tvm { +namespace arith { + +class Analyzer; + +/*! \brief Convert boolean expression to AND of ORs and simplify + * + * \param expr The PrimExpr to be simplified + * + * \param analyzer The analyzer with which to simplify + * + * \return The simplified expression + */ +PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_CONJUNCTIVE_NORMAL_FORM_H_ diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 019b8cd5d353..5e565d7e36c6 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -31,6 +31,7 @@ #include #include "../target/datatype/registry.h" +#include "conjunctive_normal_form.h" #include "const_fold.h" #include "constraint_extract.h" #include "pattern_match.h" @@ -1558,8 +1559,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && + !recursively_visiting_boolean_) { + return SimplifyAsAndOfOrs(ret, analyzer_); + } // Pattern var to match any expression PVar x, y; @@ -1596,9 +1602,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + op = ret.as(); if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && + !recursively_visiting_boolean_) { + return SimplifyAsAndOfOrs(ret, analyzer_); + } // Pattern var to match any expression PVar x, y; diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 00c60e21ee42..02c54902153a 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -98,6 +98,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // Optionally enabled extensions Extension enabled_extensions_{kNone}; + /*! Whether the simplifier is current + */ + bool recursively_visiting_boolean_{false}; + // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 2a7c0f5a3585..894dfb8ca09f 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -38,12 +38,17 @@ using namespace tir; struct SimplifyConfigNode : public tvm::AttrsNode { bool transitively_prove_inequalities; + bool convert_boolean_to_and_of_ors; TVM_DECLARE_ATTRS(SimplifyConfigNode, "tir.transform.SimplifyConfig") { TVM_ATTR_FIELD(transitively_prove_inequalities) .describe( "If true, simplify conditionals with transitive combinations of scoped constraints") .set_default(false); + + TVM_ATTR_FIELD(convert_boolean_to_and_of_ors) + .describe("If true, simplify conditionals into an AND of ORs") + .set_default(false); } RewriteSimplifier::Extension GetEnabledExtensions() const { @@ -52,6 +57,9 @@ struct SimplifyConfigNode : public tvm::AttrsNode { flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kTransitivelyProveInequalities); } + if (convert_boolean_to_and_of_ors) { + flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); + } return flags; } }; @@ -202,6 +210,5 @@ Pass Simplify() { TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); } // namespace transform - } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 0a1263f70287..2eb9c3546ee5 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -138,12 +138,14 @@ def sls(n, d): class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): transitively_prove_inequalities = False + convert_boolean_to_and_of_ors = False def transform(self): def inner(mod): config = { "tir.Simplify": { "transitively_prove_inequalities": self.transitively_prove_inequalities, + "convert_boolean_to_and_of_ors": self.convert_boolean_to_and_of_ors, } } with tvm.transform.PassContext(config=config): @@ -686,5 +688,133 @@ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): expected = before +class TestRewriteAsAndOfOrs(BaseBeforeAfter): + """If enabled, rewrite boolean expressions into AND of OR""" + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[3, "bool"]): + T.evaluate(A[0] or (A[1] and A[2])) + + def expected(A: T.Buffer[3, "bool"]): + T.evaluate((A[0] or A[1]) and (A[0] or A[2])) + + +class TestSuppressRewriteAsAndOfOrs(BaseBeforeAfter): + """Only rewrite into AND of OR when allowed""" + + convert_boolean_to_and_of_ors = False + + def before(A: T.Buffer[3, "bool"]): + T.evaluate(A[0] or (A[1] and A[2])) + + expected = before + + +class TestRewriteAsAndOfOrsWithTopLevelAnd(BaseBeforeAfter): + """The expression being rewritten may start with an AND + + Like TestRewriteAsAndOfOrs, but with an AndNode as the outermost + booelan operator. Even though it is primarily OR nodes that are + being rewritten, the call to SimplifyAsAndOfOrs should apply to + the outermost AndNode or OrNode in order to enable better + simplification. + """ + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[4, "bool"]): + T.evaluate((A[0] or A[1]) and (A[1] or (A[0] and A[2] and A[3]))) + + def expected(A: T.Buffer[4, "bool"]): + # If the simplification is applied to the OrNode, then a + # redundant `(A[1] or A[0])` would't be canceled out. When + # applying SimplifyAsAndOfOrs to the top-level AndNode, the + # internal representation is `[[0,1], [1,0], [1,2], [1,3]]`, and + # the redundant `[1,0]` can be removed. + # + # If the simplification were only applied when encountering an + # OrNode, the internal representation would be `[[0,1]]` during + # the first call and `[[1,0], [1,2], [1,3]]` during the second + # call. As a result, the `[0,1]` and `[1,0]` representations + # wouldn't occur within the same call, and the redundant `[1,0]` + # wouldn't be removed. + T.evaluate((A[0] or A[1]) and (A[1] or A[2]) and (A[1] or A[3])) + + +class TestRewriteAsAndOfOrsWithSimplificationBetweenGroups(BaseBeforeAfter): + """Apply rewrite rules between OR groups that differ by a single element + + The expression `(k==20 and k!=30)` could be rewritten into `(k==20)`. + However, by default these two terms must appear as part of an explict part + of the simplified expression. The AndOfOr simplification checks for + rewrite patterns of the form `(A or B) and (A or C)`, where `(B and C)` can + simplify to a single expression `D`. These can be rewritten to `(A or D)`. + """ + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (i == 0 or j == 10 or k == 20) and (i == 0 or j == 10 or k != 30) + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = i == 0 or j == 10 or k == 20 + + +class TestRewriteAsAndOfOrsWithSimplificationBetweenReorderedGroups(BaseBeforeAfter): + """Rewrite rules between OR groups do not depend on order + + Like TestRewriteAsAndOfOrsWithSimplificationBetweenGroups, but the groups + are ordered differently. If this removes a group entirely, the result is + ordered according to the first group in the expression. + """ + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (i == 0 or j == 10 or k == 20) and (j == 10 or k != 30 or i == 0) + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = i == 0 or j == 10 or k == 20 + + +class TestRewriteAsAndOfOrUsingSimplificationAcrossAnd(BaseBeforeAfter): + """Apply AndNode rewrites to non-adjacent expressions + + The RewriteSimplifier rules only check for simplifications between + left/right branches of an And/Or node. Simplifications that would require + rearranging components in a chain of And/Or nodes are not performed. + """ + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (k == 20) and ((i == 0 or j == 10) and (k != 30)) + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (k == 20) and (i == 0 or j == 10) + + +class TestRewriteAsAndOfOrUsingSimplificationWithinOr(BaseBeforeAfter): + """Rewrite rules between OR groups do not depend on order + + The RewriteSimplifier rules only check for simplifications between + left/right branches of an And/Or node. Simplifications that would require + rearranging components in a chain of And/Or nodes are not performed. + + This test validates that `(i == 20) or (i != 30)` can be rewritten to + `(i != 30)`, even when there's an intervening clause between the + clauses being simplified. + """ + + convert_boolean_to_and_of_ors = True + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (i == 20) or (j == 0) or (i != 30) + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): + A[0] = (i != 30) or (j == 0) + + if __name__ == "__main__": tvm.testing.main()