diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc new file mode 100644 index 000000000000..1c8931d2dec4 --- /dev/null +++ b/src/arith/narrow_predicate_expression.cc @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_predicate_expression.cc + * \brief Utility to deduce bound of expression + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace arith { + +using namespace tir; + +/* \brief Given a true expression that includes free parameter, + * generate a true expression without the free parameters. + * + * This function provides two guarantees: + * + * 1. If the resulting expression evaluates to True, then the original + * expression also evaluates to True. + * + * 2. The resulting expression does not contain any of the free + * parameters. + * + */ +// Utility for generating a known true expression from an expression +// with free parameters, and the range of those parameters. +class ExpressionNarrower : public tir::ExprMutator { + public: + static PrimExpr Apply(PrimExpr expr, Map free_parameters) { + ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; + ExpressionNarrower mutator(free_parameters); + return mutator(expr); + } + + private: + explicit ExpressionNarrower(Map free_parameters) + : free_parameters_(free_parameters) {} + + using Parent = tir::ExprMutator; + using Parent::VisitExpr_; + + enum class Context { + Maximize, + Minimize, + }; + + template + PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { + PrimExpr a = [&]() { + WithContext context(this, a_ctx); + return this->VisitExpr(t->a); + }(); + + PrimExpr b = [&]() { + WithContext context(this, b_ctx); + return this->VisitExpr(t->b); + }(); + + if (contains_unknown_expr_ && t.dtype().is_bool()) { + contains_unknown_expr_ = false; + return Bool(CurrentContext() == Context::Minimize); + } else if (a.same_as(t->a) && b.same_as(t->b)) { + return std::move(t); + } else { + return T(a, b); + } + } + + PrimExpr VisitExpr_(const FloorModNode* op) override { + // FloorMod is non-monotonic, so inserting min/max won't remove + // the free parameters. + contains_unknown_expr_ = true; + return Parent::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const FloorDivNode* op) override { + auto res_a = this->VisitExpr(op->a); + auto res_b = this->VisitExpr(op->b); + if (is_zero(res_b)) { + contains_unknown_expr_ = true; + return IntImm(op->dtype, 0); + } else { + return floordiv(res_a, res_b); + } + } + + PrimExpr VisitExpr_(const GTNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), OppositeContext(current), current); + } + + PrimExpr VisitExpr_(const GENode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), OppositeContext(current), current); + } + + PrimExpr VisitExpr_(const LTNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const LENode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const EQNode* op) override { + auto res_a = this->VisitExpr(op->a <= op->b); + auto res_b = this->VisitExpr(op->b <= op->a); + return res_a && res_b; + } + + PrimExpr VisitExpr_(const NENode* op) override { + auto res_a = this->VisitExpr(op->a < op->b); + auto res_b = this->VisitExpr(op->b < op->a); + return res_a || res_b; + } + + PrimExpr VisitExpr_(const SubNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const NotNode* op) override { + auto current = CurrentContext(); + WithContext context(this, OppositeContext(current)); + return !VisitExpr(op->a); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + contains_unknown_expr_ = true; + return GetRef(op); + } + + PrimExpr VisitExpr_(const VarNode* op) override { + auto it = free_parameters_.find(GetRef(op)); + if (it == free_parameters_.end()) { + return Parent::VisitExpr_(op); + } + + Range range = (*it).second; + + switch (CurrentContext()) { + case Context::Minimize: + return range->min; + + case Context::Maximize: + return range->min + range->extent - 1; + } + + return Parent::VisitExpr_(op); + } + + Context CurrentContext() const { + if (context_stack_.size()) { + return context_stack_.back(); + } else { + return Context::Maximize; + } + } + + Context OppositeContext(Context context) const { + switch (context) { + case Context::Minimize: + return Context::Maximize; + + case Context::Maximize: + return Context::Minimize; + + default: + LOG(FATAL) << "Unhandled Context, all legal values should be handled"; + return Context::Maximize; + } + } + + struct WithContext { + WithContext(ExpressionNarrower* self, Context context) : self(self) { + self->context_stack_.push_back(context); + } + ~WithContext() { self->context_stack_.pop_back(); } + ExpressionNarrower* self; + }; + + std::vector context_stack_; + Map free_parameters_; + bool contains_unknown_expr_{false}; +}; + +PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters) { + return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); +} + +TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h new file mode 100644 index 000000000000..1e452e3ad493 --- /dev/null +++ b/src/arith/narrow_predicate_expression.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_predicate_expression.h + * \brief Utility for extracting and interacting with buffer touch points + */ + +#include +#include + +#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ +#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ + +namespace tvm { +namespace arith { + +/* \brief Narrow a true expression to remove free parameters + * + * This function provides two guarantees: + * + * 1. If the resulting expression evaluates to True, then the original + * expression also evaluates to True. + * + * 2. The resulting expression does not contain any of the free + * parameters. + * + * 3. The resulting expression does not contain any BufferLoad + * + * \param expr The expression to be examined. + * + * \param ranges The variables to be removed from the expression + * + * \returns An expression that, if true, implies that the original + * expression is also true. + */ +PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters); + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ diff --git a/tests/python/unittest/test_arith_narrow_predicate_expression.py b/tests/python/unittest/test_arith_narrow_predicate_expression.py new file mode 100644 index 000000000000..d38fe70f6b5c --- /dev/null +++ b/tests/python/unittest/test_arith_narrow_predicate_expression.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing + +from tvm import tir +from tvm.runtime import convert + + +i = tir.Var("i", "int32") +j = tir.Var("j", "int32") +n = tir.Var("n", "int32") +m = tir.Var("m", "int32") +b = tir.Var("b", "bool") +buf = tir.decl_buffer(16, "int32", "buf") + +tir_false = tir.IntImm("bool", False) +tir_true = tir.IntImm("bool", True) + +before, expected = tvm.testing.parameters( + # General arithmatic + [tir_true, tir_true], + [tir_false, tir_false], + [b, b], + [i > 5, i > 5], + [i > n, i > 7], + [i < n, i < 0], + [i <= n, i <= 0], + [i >= n, i >= 7], + [n > i, convert(0) > i], + [n < i, convert(7) < i], + [n <= i, convert(7) <= i], + [n >= i, convert(0) >= i], + [i == n, tir.all(i <= 0, convert(7) <= i)], + [n == i, tir.all(convert(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, convert(7) < i)], + [n != i, tir.any(convert(7) < i, i < 0)], + [i // 4 > n, i // 4 > 7], + [n < i // 4, convert(7) < i // 4], + [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], + [i + n < 10, i + 7 < 10], + [i - n < 10, tir.Sub(i, 0) < 10], + [tir.Not(i < n), tir.Not(i < 7)], + # Use of FloorMod should make the narrowing strategy bail out, as + # it is non-monotonic. + [i % 8 == n, tir_false], + # Ensure that dividing by a free parameter doesn't generate a + # divide-by-zero to be triggered later. + [i // n == 0, tir_false], + ### Buffer handling + [buf.vload(0) > 0, tir_false], + [buf.vload(0) > i, tir_false], + [buf.vload(i) > 0, tir_false], + [tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)], + [tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)], + [tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)], +) + + +def test_narrow_expression(before, expected): + ranges = {n: tvm.ir.Range(0, 8)} + after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges) + + if expected is None: + assert after is None + else: + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main()