diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h
index c2cdc5e7a923..b336caac3e96 100644
--- a/include/tvm/ir_operator.h
+++ b/include/tvm/ir_operator.h
@@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) {
  */
 inline bool is_const_int(const Expr& x, int64_t value);
 
+/*!
+ * \brief Check if the given expr is a const of any type equal to the given integer value.
+ * \param e The expression.
+ * \param value The value to compare to.
+ * \return Whether the expression is a const equal to the value.
+ * \tparam ValueType The value type
+ */
+template <typename ValueType>
+inline bool is_const_value(const Expr& e, ValueType value);
+
 /*!
  * \brief Check whether stmt is nop.
  * \param stmt The input statement
@@ -503,18 +513,31 @@ inline bool is_negative_const(const Expr& a) {
   }
 }
 
+template <typename ValueType>
+inline bool is_const_value(const Expr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  // This implementation was copy-pasted from HalideIR
+  if (const ir::IntImm* i = e.as<ir::IntImm>()) {
+    return i->value == value;
+  } else if (const ir::UIntImm* i = e.as<ir::UIntImm>()) {
+    return (value >= 0) && (i->value == static_cast<uint64_t>(value));
+  } else if (const ir::FloatImm* i = e.as<ir::FloatImm>()) {
+    return i->value == value;
+  } else if (const ir::Cast* c = e.as<ir::Cast>()) {
+    return is_const_value(c->value, value);
+  } else if (const ir::Broadcast* b = e.as<ir::Broadcast>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
 inline bool is_const_int(const Expr& x, int64_t value) {
-  if (const auto* op = x.as<ir::IntImm>()) {
-    return op->value == value;
-  } else if (const auto* op = x.as<ir::UIntImm>()) {
-    return op->value == static_cast<uint64_t>(value);
+  if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
+    return is_const_value(x, value);
   } else if (const auto* op = x.as<ir::Broadcast>()) {
-    const Expr& val = op->value;
-    if (const auto* opv = val.as<ir::IntImm>()) {
-      return opv->value == value;
-    } else if (const auto* opv = val.as<ir::UIntImm>()) {
-      return opv->value == static_cast<uint64_t>(value);
-    }
+    return !op->value.as<ir::Broadcast>() && is_const_int(op->value, value);
   }
   return false;
 }
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 1a6666bdee2a..2db7c164b1e8 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -1,6 +1,7 @@
 """ TVM testing utilities """
 import logging
 import numpy as np
+import tvm
 
 def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
     """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set
@@ -145,3 +146,200 @@ def compare_derivative(j, n_der, grad):
     logging.info("Numerical grad test wrt '%s' of shape %s passes, "
                  "dist = %f, max_diff = %f, avg_diff = %f",
                  x_name, grad.shape, dist, max_diff, avg_diff)
+
+
+class PerformanceEstimate:
+    """A result of static performance estimation.
+
+    Parameters
+    ----------
+    iterations : int
+        The total number of iterations of all the loops.
+
+    multiplications : int
+        The total number of expensive operations like multiplications.
+
+    memory : int
+        The amount of memory to allocate.
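+
+    Example
+    -------
+    An illustrative sketch (the numbers are hypothetical): estimates compose
+    with `+`, scale with `times`, and compare componentwise with `<=`.
+
+    .. code-block:: python
+
+        a = PerformanceEstimate(iterations=100, multiplications=10, memory=0)
+        b = a.times(2) + PerformanceEstimate(memory=64)
+        assert a <= b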
+ """ + def __init__(self, iterations=0, multiplications=0, memory=0): + self.iterations = iterations + self.multiplications = multiplications + self.memory = memory + + def as_tuple(self): + return (self.iterations, self.multiplications, self.memory) + + def __add__(self, other): + return PerformanceEstimate(iterations=self.iterations + other.iterations, + multiplications=self.multiplications + other.multiplications, + memory=self.memory + other.memory) + + def max(self, other): + return PerformanceEstimate( + iterations=max(self.iterations, other.iterations), + multiplications=max(self.multiplications, other.multiplications), + memory=max(self.memory, other.memory)) + + def times(self, iters): + return PerformanceEstimate(iterations=self.iterations*iters, + multiplications=self.multiplications*iters, + memory=self.memory) + + def __repr__(self): + return "PerformanceEstimate(iterations={}, multiplications={}, memory={})".format( + self.iterations, self.multiplications, self.memory) + + def __le__(self, other): + return \ + self.iterations <= other.iterations and \ + self.multiplications <= other.multiplications and \ + self.memory <= other.memory + + +def estimate_performance(s, param_values=None, _processed_ops=None): + """Statically estimate performance of statements, expressions and tensors. Note that the + estimate is very rough, it mustn't be used to predict future performance, its only purpose is + to detect possible performance regressions. + + Parameters + ---------- + s + A statement, an expression, a tensor, an operation, or a list + of any of the above. + + param_values : Dict[tvm.expr.Var, int], optional + Values for parameters (free variables), see the example. + + _processed_ops, optional + A dict mapping already processed operations to the corresponding estimations. + This parameter is used internally. + + Returns + ------- + estimate : PerformanceEstimate + + Example + ------- + .. 
code-block:: python + + m = tvm.var('m') + X = tvm.placeholder((10, m), name='X') + W = tvm.placeholder((m + 5, m), name='W') + A = topi.nn.dense(X, W) + tvm.testing.estimate_performance(A, param_values={m: 5}) + """ + from tvm import stmt + from tvm import expr + + if param_values is None: + param_values = {} + + if _processed_ops is None: + _processed_ops = {} + res = estimate_performance(s, param_values=param_values, _processed_ops=_processed_ops) + for op_est in _processed_ops.values(): + res += op_est + return res + + def est(expression, param_values=param_values, _processed_ops=_processed_ops): + return estimate_performance(expression, + param_values=param_values, + _processed_ops=_processed_ops) + + def _eval(expression, param_values=param_values): + return tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expression, param_values)).value + + def _prod(elems): + res = 1 + for x in elems: + res *= x + return res + + if s is None or isinstance(s, (stmt.AssertStmt, stmt.Free, stmt.Prefetch, + expr.ConstExpr, expr.Var, tvm.tensor.PlaceholderOp)): + return PerformanceEstimate() + elif isinstance(s, list): + res = PerformanceEstimate() + for item in s: + res += est(item) + return res + elif s in _processed_ops: + return PerformanceEstimate() + elif isinstance(s, stmt.Allocate): + mem = _prod([_eval(e) for e in s.extents]) + return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem) + elif isinstance(s, stmt.Block): + return est(s.first) + est(s.rest) + elif isinstance(s, stmt.Evaluate): + return est(s.value) + elif isinstance(s, stmt.For): + body_est = est(s.body) + body_est.iterations = max(1, body_est.iterations) + return body_est.times(_eval(s.extent)) + elif isinstance(s, stmt.IfThenElse): + return est(s.condition) + est(s.then_case) + est(s.else_case) + elif isinstance(s, stmt.LetStmt): + return est(s.value) + est(s.body) + elif isinstance(s, (stmt.ProducerConsumer, stmt.AttrStmt)): + return est(s.body) + elif isinstance(s, stmt.Provide): + return est(s.value) + elif isinstance(s, stmt.Realize): + return est(s.condition) + est(s.body) + elif isinstance(s, stmt.Store): + return est(s.value) + est(s.index) + est(s.predicate) + elif isinstance(s, (expr.Mul, expr.Div, expr.Mod)): + return est(s.a) + est(s.b) + PerformanceEstimate(multiplications=1) + elif isinstance(s, (expr.BinaryOpExpr, expr.CmpExpr, expr.LogicalExpr)): + if not hasattr(s, 'b'): + return est(s.a) + return est(s.a) + est(s.b) + elif isinstance(s, expr.Call): + res = PerformanceEstimate() + for a in s.args: + res += est(a) + if s.call_type == expr.Call.Halide: + # The estimate is added to _processed_ops, we don't need the result here + est(s.func) + elif s.name == "tvm_if_then_else": + pass + else: + # expr.If it is a non-halide call (e.g. 
+            # If it is a non-halide call (e.g. exp or log), consider it a mul
+            res += PerformanceEstimate(multiplications=1)
+        return res
+    elif isinstance(s, expr.Cast):
+        return est(s.value)
+    elif isinstance(s, expr.Load):
+        return est(s.index) + est(s.predicate)
+    elif isinstance(s, expr.Select):
+        return est(s.condition) + est(s.true_value) + est(s.false_value)
+    elif isinstance(s, expr.Reduce):
+        iterations = _prod([_eval(iv.dom.extent) for iv in s.axis])
+        res = PerformanceEstimate()
+        for id_elem in s.combiner.identity_element:
+            res += est(id_elem)
+        on_each_iter = est(s.condition)
+        for src in s.source:
+            on_each_iter += est(src)
+        for comb_res in s.combiner.result:
+            on_each_iter += est(comb_res)
+        on_each_iter.iterations = max(1, on_each_iter.iterations)
+        return res + on_each_iter.times(iterations)
+    elif isinstance(s, tvm.tensor.Tensor):
+        return est(s.op)
+    elif isinstance(s, tvm.tensor.ComputeOp):
+        iterations = _prod([_eval(iv.dom.extent) for iv in s.axis])
+        if s.reduce_axis:
+            res = est(s.body[0])
+        else:
+            res = PerformanceEstimate()
+            for b in s.body:
+                res += est(b)
+        res.iterations = max(1, res.iterations)
+        res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body))
+        _processed_ops[s] = res
+        return PerformanceEstimate()
+
+    raise ValueError("Don't know how to estimate performance of {} of type {}"
+                     .format(s, type(s)))
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index b18552d5c562..4231f336a01b 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -245,5 +245,48 @@ ir::ForType IterVarTypeToForType(IterVarType iter_type) {
   }
 }
 
+Tensor TensorFromExpr(const Expr& expr, const Array<IterVar>& axis,
+                      const std::string& name, const std::string& tag,
+                      const Map<std::string, NodeRef>& attrs) {
+  Array<Expr> new_bodies;
+  int new_value_index = 0;
+
+  // If this is a reduction then we have to clone its body
+  if (const Reduce* red = expr.as<Reduce>()) {
+    new_value_index = red->value_index;
+
+    for (size_t i = 0; i < red->source.size(); ++i) {
+      Expr ith_red = Reduce::make(red->combiner, red->source, red->axis, red->condition, i);
+      new_bodies.push_back(ith_red);
+    }
+  } else {
+    new_value_index = 0;
+    new_bodies.push_back(expr);
+  }
+
+  return ComputeOpNode::make(name, tag, attrs, axis, new_bodies).output(new_value_index);
+}
+
+Tensor TransformBody(const Tensor& tensor,
+                     std::function<Expr(const Expr&, const Array<IterVar>&)> func) {
+  if (const ComputeOpNode* op = tensor->op.as<ComputeOpNode>()) {
+    // Transform only one body
+    Expr new_body = func(op->body[tensor->value_index], op->axis);
+
+    // If the body didn't change then we can return the same tensor
+    if (new_body.same_as(op->body[tensor->value_index])) {
+      return tensor;
+    }
+
+    return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs);
+  } else {
+    return tensor;
+  }
+}
+
+Tensor TransformBody(const Tensor& tensor, std::function<Expr(const Expr&)> func) {
+  return TransformBody(tensor,
+                       [func](const Expr& e, const Array<IterVar>&) { return func(e); });
+}
+
 }  // namespace op
 }  // namespace tvm
diff --git a/src/op/op_util.h b/src/op/op_util.h
index de2e44c2ed59..da7987f7162f 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -11,6 +11,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include <string>
 #include "../pass/ir_util.h"
 #include "../pass/arg_binder.h"
 #include "../schedule/message_passing.h"
@@ -84,6 +85,45 @@ IterVarType ForTypeToIterVarType(ir::ForType for_type);
  */
 ir::ForType IterVarTypeToForType(IterVarType iter_type);
 
+/*!
+ * \brief Create a tensor from an expression. The expression may be a reduction, in which
+ * case its body will be correctly duplicated if it is a multi-valued reduction.
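+ *
+ * For example (an illustrative sketch): if expr is a Reduce with two sources and
+ * value_index == 1, the resulting ComputeOp gets two bodies, one Reduce per source
+ * (with value_index 0 and 1 respectively), and output(1) is returned.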
+ * + * \param expr The expr which will be the tensor's body. + * \param axis The input variables with ranges. + * \param name The tensor's name. + * \param tag The tensor's tag. + * \param attrs The tensor's attrs. + * \return A tensor. + */ +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name = "tensor", const std::string& tag = "", + const Map& attrs = {}); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function working on expressions and additionally taking + * the array of the tensor's itervars. + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, + std::function&)> func); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function (working on expressions). + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, std::function func); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc new file mode 100644 index 000000000000..775476d61fae --- /dev/null +++ b/src/pass/zero_elimination.cc @@ -0,0 +1,1819 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.cc + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#include "zero_elimination.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "arithmetic/ModulusRemainder.h" +#include "../op/op_util.h" + +namespace tvm { +namespace ir { + +using HalideIR::Internal::gcd; +using HalideIR::Internal::lcm; + +struct ExprLess { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) < 0; + } +}; + +struct ExprEq { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) == 0; + } +}; + +// Merge two maps, prefer the right one on conflict +template +Map Merge(Map original, const Map& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template +Array Concat(Array a, const Array& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template +Expr All(const container& c) { + Expr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +// Create a select statement of the form cond ? on_true : 0 +Expr SelectElseZero(const Expr& cond, const Expr& on_true) { + return Select::make(cond, on_true, make_zero(on_true.type())); +} + +// Simplify the expression as thoroughly as possible by using all available simplifiers. 
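+// For example (an illustrative sketch): with vranges = {i -> Range(min=n, extent=1)}
+// the variable i is first substituted with n, so an expression like i - n simplifies
+// to 0; the stock simplifiers alone tend to miss this single-value case.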
+Expr SuperSimplify(Expr e, const Map& vranges = Map()) { + // For some reason no simplifier can detect that there is only one value of the variable + std::unordered_map vmap; + for (const auto& var_range : vranges) { + if (is_const_int(var_range.second->extent, 1)) { + vmap[var_range.first.get()] = var_range.second->min; + } + } + if (!vmap.empty()) { + e = Substitute(e, vmap); + } + + return CanonicalSimplify(Simplify(CanonicalSimplify(e, vranges), vranges), vranges); +} + +// Provability check that uses SuperSimplify +bool CanProve(Expr e, const Map& vranges = Map()) { + return is_one(SuperSimplify(e, vranges)); +} + +class ExprFreeVarsVisitor : public IRVisitor { + public: + std::vector free_array; + std::unordered_set bound; + std::unordered_set free; + + virtual void Visit(const NodeRef& node) { + if (const Variable* v = node.as()) { + if (!bound.count(v) && !free.count(v)) { + free.insert(v); + free_array.push_back(Var(node.node_)); + } + } else { + IRVisitor::Visit(node); + } + } + + void Visit_(const Variable* op) { + CHECK(false) << "This case shouldn't happen"; + } + + void Visit_(const LetStmt* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const For* op) { + bound.insert(op->loop_var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Let* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Reduce* op) { + for (const auto& iv : op->axis) { + bound.insert(iv->var.get()); + } + IRVisitor::Visit_(op); + } + + void Visit_(const Store* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Allocate* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Free* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Load* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } +}; + +// Get free variables of an expression +Array ExprFreeVars(const Expr& expr) { + ExprFreeVarsVisitor visitor; + visitor.Visit(expr); + return visitor.free_array; +} + +// Clone iter vars and return both the new vars and the substitution from old to new. +std::pair, std::unordered_map> CloneIterVars( + const Array& vars) { + Array new_vars; + std::unordered_map vmap; + for (const IterVar& iv : vars) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_vars.push_back(new_v); + vmap[iv->var.get()] = new_v; + } + return std::make_pair(std::move(new_vars), std::move(vmap)); +} + +// Clone reduction by cloning the axis variables. 
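+// For example (an illustrative sketch): cloning sum(A[k], k in [0, 10)) yields an
+// equivalent Reduce whose axis uses a fresh Var node for k, so the clone shares no
+// axis variables with the original and both can safely appear in the same scope.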
+Expr CloneReduction(const Expr& expr) { + if (const Reduce* red = expr.as()) { + Array new_axis; + std::unordered_map vmap; + std::tie(new_axis, vmap) = CloneIterVars(red->axis); + + Array src_with_newaxis; + for (const auto& src : red->source) { + src_with_newaxis.push_back(Substitute(src, vmap)); + } + + return Reduce::make(red->combiner, src_with_newaxis, + new_axis, Substitute(red->condition, vmap), red->value_index); + } else { + return expr; + } +} + +// Convert an array of itervars to an array of inequalities +Array IterVarsToInequalities(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(GE::make(v->var, v->dom->min)); + res.push_back(LT::make(v->var, v->dom->min + v->dom->extent)); + } + return res; +} + +// Convert an array of itervars to a map from vars to ranges +Map IterVarsToMap(const Array& itervars) { + Map res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Convert an array of itervars to an array of vars +Array IterVarsToVars(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array IterVarsFromMap(const Array& vars, const Map& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v + << " was not provided in map " << vranges; + res.push_back(IterVarNode::make(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner, const Map& vranges) { + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(SuperSimplify(combiner->identity_element[0], vranges), 0)) { + return false; + } + + Expr should_be_zero = + SuperSimplify(combiner->result[0] - (combiner->lhs[0] + combiner->rhs[0]), vranges); + return is_const_value(should_be_zero, 0); +} + +// Return true if zero may be factored out of a reduction with this combiner. 
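+// For example (an illustrative sketch): for the usual sum combiner (identity 0,
+// result lhs + rhs) substituting lhs = rhs = 0 into the result gives 0, so zero can
+// be factored out; for a product combiner (identity 1) the first check already fails.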
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map& vranges) { + if (!is_const_value(SuperSimplify(combiner->identity_element[value_index], vranges), 0)) { + return false; + } + + Expr zero = make_zero(combiner->result[value_index].type()); + Expr in = Substitute(combiner->result[value_index], + {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = SuperSimplify(in, vranges); + + return is_const_value(in, 0); +} + +// If expr is a Call node, perform inlining, otherwise do nothing +Expr InlineThisCall(const Expr& expr) { + if (const Call* op = expr.as()) { + if (op->call_type == Call::CallType::Halide) { + if (const ComputeOpNode* op_comp = op->func.as()) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(expr), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return CloneReduction(ev->value); + } + } + } + } + + return expr; +} + +Tensor InlineTailCall(const Tensor& tensor) { + return op::TransformBody(tensor, InlineThisCall); +} + +// Implements InlineTensors by trying to inline every Call of the given Expr +class InlineTensorsMutator : public IRMutator { + public: + explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) + : inline_reductions_(inline_reductions) { + for (const Tensor& tensor : inlineable) { + inlineable_.emplace(tensor->op.operator->(), tensor->value_index); + } + } + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + if (const ComputeOpNode* op_comp = op->func.as()) { + // Inline only if the array of inlineable tensors is empty or contains this tensor + if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + // Inline only compute nodes that are not reductions (unless inline reductions is allowed) + if (inline_reductions_ || !op_comp->body[0].as()) { + // Inline this call and then try to perform further inlining + return Mutate(InlineThisCall(e)); + } + } + } + } + + // If we cannot inline this call, we should try to do inlining in its arguments + return IRMutator::Mutate_(op, e); + } + + private: + // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index) + std::set> inlineable_; + bool inline_reductions_; +}; + +Expr InlineTensors(const Expr& expr, const Array& inlineable, + bool inline_reductions) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(expr); +} + +Tensor InlineTensors(const Tensor& tensor, const Array& inlineable, + bool inline_reductions) { + auto transformation = + [inlineable, inline_reductions](const Expr& e) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(e); }; + return op::TransformBody(tensor, transformation); +} + + +struct NonzeronessConditionResult { + Expr cond; + Expr value; + + Expr to_expr() const { + return SelectElseZero(cond, value); + } +}; + +// The implementation of NonzeronessCondition +class NonzeronessConditionFunctor + : public ExprFunctor { + public: + NonzeronessConditionResult NonzeronessCondition(const Expr& e) { + if (e.type().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves + return {e, const_true()}; + } else { + return VisitExpr(e, e); + } + } + + // Most of the cases are implemented using helpers below + result_type VisitExpr_(const Variable*, const Expr& e) final 
{ return Default_(e); } + result_type VisitExpr_(const IntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const UIntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const FloatImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const StringImm*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const Add* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Sub* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Mul* op, const Expr& e) final { return BinOpMulLike_(op, e); } + result_type VisitExpr_(const Div* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } + + result_type VisitExpr_(const Cast* op, const Expr& e) final { + auto nz_a = NonzeronessCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, Cast::make(op->type, nz_a.value)}; + } + } + + result_type VisitExpr_(const Select* op, const Expr& e) final { + Expr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + Expr new_cond = SuperSimplify(nz_a.cond && cond); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + Expr new_cond = SuperSimplify(nz_b.cond && !cond); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, Select::make(cond, nz_a.value, nz_b.value)}; + } + } + + result_type VisitExpr_(const Call* op, const Expr& e) final { + if (op->name == intrinsic::tvm_if_then_else) { + Expr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } + } else { + return Default_(e); + } + } + + NonzeronessConditionResult Default_(const Expr& e) { + // This is always correct, so it's the default + return {const_true(), e}; + } + + template + NonzeronessConditionResult Const_(const TNode* op, const Expr& e) { + if (op->value == 0) { + return {const_false(), e}; + } else { + return {const_true(), e}; + } + } + + template + NonzeronessConditionResult BinOpAddLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + // For addition and similar ops the result may be nonzero if either of the arguments is + // 
nonzero, so we combine the conditions with Or. + + if (Equal(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; + } + } else { + // Otherwise use Or + Expr new_cond = SuperSimplify(nz_a.cond || nz_b.cond); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. + Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + Expr new_expr = TNode::make(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template + NonzeronessConditionResult BinOpMulLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + + Expr new_cond = SuperSimplify(nz_a.cond && nz_b.cond); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, e}; + } else { + return {new_cond, TNode::make(nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpDivLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + + // For Div we simply use the condition of the numerator. + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, op->b)}; + } + } +}; + +// Transform expr into a pair (condition, new_expr) such that the old expr is equivalent to +// `select(condition, new_expr, 0)`. The pair is represented as a struct for clarity. 
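+// For example (an illustrative sketch): for select(i > 2, A[i], 0.0)*B[j] the result
+// is the pair (i > 2, A[i]*B[j]), i.e. the original expression is equivalent to
+// select(i > 2, A[i]*B[j], 0.0).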
+NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { + return NonzeronessConditionFunctor().NonzeronessCondition(expr); +} + +Expr LiftNonzeronessCondition(const Expr& expr) { + return NonzeronessCondition(expr).to_expr(); +} + + +class NormalizeComparisonsMutator : public IRMutator { + public: + virtual Expr Mutate_(const EQ* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return Make(op->b, op->a); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return Make(op->b, op->a); } + + private: + template + Expr Make(const Expr& a, const Expr& b) { + // rewrite LT to LE for ints + if (std::is_same::value && (a.type().is_int() || a.type().is_uint())) { + return LE::make(SuperSimplify(a - b + 1), make_zero(a.type())); + } + return TNode::make(SuperSimplify(a - b), make_zero(a.type())); + } +}; + +// Rewrite every comparison into the form a == 0, a != 0, a <= 0, and sometimes for floats a < 0 +Expr NormalizeComparisons(const Expr& expr) { + return NormalizeComparisonsMutator().Mutate(expr); +} + + +struct FactorOutAtomicFormulasResult { + std::vector atomic_formulas; + Expr rest; + + Expr to_expr() const { + Expr res = rest; + for (const Expr& e : atomic_formulas) { + res = And::make(e, res); + } + return res; + } +}; + +// The implementation of FactorOutAtomicFormulas +class FactorOutAtomicFormulasFunctor + : public ExprFunctor { + public: + result_type Atomic_(const Expr& e) { + // For atomic expressions the result is the expr itself with True as the residual + return {{e}, make_const(e.type(), 1)}; + } + + // This is basically the list of expression kinds that are considered atomic + result_type VisitExpr_(const Variable*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const Call*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const IntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const UIntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const EQ*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const NE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LT*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GT*, const Expr& e) final { return Atomic_(e); } + + result_type VisitExpr_(const And* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + // For the And case we return the union of the sets of atomic formulas + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + // And the residuals are combined with && + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const Mul* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + // For multiplication we do the same thing as for And + std::vector 
res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest * res_b.rest}; + } + + result_type VisitExpr_(const Or* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + // For the Or case we intersect the sets of atomic formulas + std::vector res; + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + std::set_intersection(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + // Computing the residual is more complex: we have to compute the sets of atomic formulas + // which are left behind, and then combine them with the residuals into the new residual. + + std::vector new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res.size()); + std::set_difference(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_a), + ExprLess()); + + std::vector new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res.size()); + std::set_difference(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_b), + ExprLess()); + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + Expr new_rest = res_a.to_expr() || res_b.to_expr(); + + return {res, new_rest}; + } +}; + +// Transform the given formula into a conjunction of atomic formulas (represented as an array) +// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b, +// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level. 
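+// For example (an illustrative sketch): (i <= 5 && j == 0) || (i <= 5 && k > 0)
+// yields the atomic set {i <= 5} (the intersection of the two sides), with roughly
+// (j == 0) || (k > 0) left behind as the residual.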
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { + return FactorOutAtomicFormulasFunctor().VisitExpr(e, e); +} + + +class RemoveRedundantInequalitiesMutator : public IRMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const Expr& cond : known) { + known_.push_back(SuperSimplify(cond)); + } + } + + virtual Expr Mutate_(const Select* op, const Expr& e) { + bool has_side_effect = HasSideEffect(e); + Expr new_cond = SuperSimplify(Mutate(op->condition)); + if (is_one(new_cond) && !has_side_effect) { + return Mutate(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return Mutate(op->false_value); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); + } + } + + virtual Expr Mutate_(const Call* op, const Expr& e) { + if (op->name == intrinsic::tvm_if_then_else) { + Expr new_cond = SuperSimplify(Mutate(op->args[0])); + if (is_one(new_cond)) { + return Mutate(op->args[1]); + } else if (is_zero(new_cond)) { + return Mutate(op->args[2]); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); + } + } else { + return IRMutator::Mutate_(op, e); + } + } + + virtual Expr Mutate_(const Reduce* op, const Expr& e) { + Array known_with_axes = known_; + for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + Expr new_cond = mutator_with_axes.Mutate(op->condition); + + Array new_known = known_with_axes; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } + + virtual Expr Mutate_(const And* op, const Expr& e) { + return Mutate(op->a) && Mutate(op->b); + } + + private: + Expr MutateAtomic_(const Expr& e) { + Expr simplified = SuperSimplify(e); + for (const Expr& other : known_) { + if (Equal(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; +}; + +// Propagate information from 
conditions and remove redundant inequalities +// TODO(sgrechanik-h): This should be merged into standard simplifiers +Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known).Mutate(expr); +} + + +struct EliminateDivModResult { + Expr expr; + Map substitution; + Array new_variables; + Array conditions; + Map ranges; +}; + +class EliminateDivModMutator : public IRMutator { + public: + Map substitution; + Array new_variables; + Array conditions; + Map ranges; + + explicit EliminateDivModMutator(Map ranges) + : ranges(ranges) {} + + virtual Expr Mutate_(const Div* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find({op->a, imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + // Otherwise recursively mutate the left hand side, and create new variables + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().first; + } else { + return mutated_a / op->b; + } + } + + return Mutate(op->a) / Mutate(op->b); + } + + virtual Expr Mutate_(const Mod* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find({op->a, imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + // Otherwise recursively mutate the left hand side, and create new variables + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().second; + } else { + return mutated_a % op->b; + } + } + + return Mutate(op->a) % Mutate(op->b); + } + + private: + dmlc::optional> AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + using tresult = dmlc::optional>; + + // Try to find the variables using the mutated expressions + if (!e.same_as(mut)) { + auto it = expr_to_vars_.find({mut, val}); + if (it != expr_to_vars_.end()) { + return tresult(it->second); + } + } + + Expr val_e = make_const(e.type(), val); + idx_ += 1; + + // Convert `ranges` to IntSets + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + // Infer ranges for the expressions we want to replace with variables + Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); + Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + + // We don't want to add unbounded variables + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate div or mod of expr " << e + << " because its bounds cannot be inferred"; + return tresult(); + } + + // Create new variables for the expressions + auto div = Var("div" + std::to_string(idx_), e.type()); + auto mod = Var("mod" + std::to_string(idx_), e.type()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + substitution.Set(div, mut / val_e); + substitution.Set(mod, mut % val_e); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + // This additional condition works as a definition for the new variables + conditions.push_back(mut == div*val_e + mod); + + if (!CanProve(mod_range->extent <= val_e)) { + // Since we use the C/C++ definition of mod, there may be multiple values of `mod` + // satisfying the added condition if the expr 
`e` may change its sign, so we + // have to add another condition. + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod of expr " << e + << " (probably it may change its sign)"; + conditions.push_back(Select::make(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[{e, val}] = p; + if (!e.same_as(mut)) { + expr_to_vars_[{mut, val}] = p; + } + return tresult(p); + } + + // A custom comparison function for pairs of exprs and numbers. Compares exprs deeply. + struct Compare_ { + bool operator()(const std::pair& p1, const std::pair& p2) { + if (p1.second < p2.second) { + return true; + } else if (p1.second == p2.second) { + return Compare(p1.first, p2.first) < 0; + } else { + return false; + } + } + }; + + // A counter for naming new variables + int idx_{0}; + // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod) + // such that `div = e / n` and `mod = e % n` + std::map, std::pair, Compare_> + expr_to_vars_; +}; + +// Replace every subexpr of the form e/const and e % const with a new variable. +// Syntactically equal expressions will be mapped to the same variable. +EliminateDivModResult EliminateDivMod(const Expr& expr, Map ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator.Mutate(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +// run EliminateDivMod from the condition of a reduction +Expr EliminateDivModFromReductionCondition(const Expr& expr, + Map vranges = Map()) { + if (const Reduce* red = expr.as()) { + for (const IterVar& iv : red->axis) { + vranges.Set(iv->var, iv->dom); + } + + auto elim_res = EliminateDivMod(red->condition, vranges); + + vranges = elim_res.ranges; + + Array new_axis = + Concat(red->axis, IterVarsFromMap(elim_res.new_variables, vranges, kCommReduce)); + + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + return Reduce::make(red->combiner, red->source, new_axis, new_cond, red->value_index); + } else { + return expr; + } +} + + +VarBounds VarBounds::substitute(const Map& subst) const { + auto apply_fun = [&subst](const Expr& e) { return Substitute(e, subst); }; + return {Substitute(coef, subst), + UpdateArray(lower, apply_fun), + UpdateArray(equal, apply_fun), + UpdateArray(upper, apply_fun)}; +} + +Array SolveSystemOfInequalitiesResult::as_conditions() const { + Array res; + for (const Var& v : variables) { + auto it = bounds.find(v.get()); + CHECK(it != bounds.end()); + const VarBounds& bnds = it->second; + Expr lhs = bnds.coef * v; + for (const Expr& rhs : bnds.equal) { + res.push_back(EQ::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.lower) { + res.push_back(GE::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.upper) { + res.push_back(LE::make(lhs, rhs)); + } + } + for (const Expr& e : other_conditions) { + res.push_back(e); + } + return res; +} + +// Rewrite the system of inequalities using Fourier-Motzkin elimination +// Note that variable ranges help a lot, so this parameter is even non-optional +SolveSystemOfInequalitiesResult SolveSystemOfInequalities(const Array& inequalities, + const Array& variables, + const Map& vranges) { + SolveSystemOfInequalitiesResult res; + res.variables = variables; + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current` and classify them according 
to polarity wrt v + // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity + // - Put the resulting combinations into `new_current` along with unclassifiable formulas + // - Replace `current` with `new_current` and move to the next variable + + // current and new_current are sorted to enable some heuristics + std::set current; + std::set new_current; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // A helper that adds an inequality to new_current if it's not obviously redundant + auto add_to_new_current = [&new_current, &vranges] (const Expr& new_ineq) { + if (CanProve(new_ineq, vranges)) { + // redundant: follows from the vranges + return; + } + if (const LE* new_le = new_ineq.as()) { + // A heuristic: check if the new inequality is a consequence of one + // of its future neighbors (in this case don't add it) or if a future neighbor is + // a consequence of the new ineq (in which case remove the neighbor) + auto it_neighbor = new_current.lower_bound(new_ineq); + if (it_neighbor != new_current.begin()) { + const LE* le = std::prev(it_neighbor)->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + new_current.erase(std::prev(it_neighbor)); + } + } + // Check the other neighbor + if (it_neighbor != new_current.end()) { + const LE* le = it_neighbor->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + it_neighbor = new_current.erase(it_neighbor); + } + } + + new_current.insert(it_neighbor, new_ineq); + } else { + new_current.insert(new_ineq); + } + }; + + // Simplify each inequality into the form `expr <= 0` and add to new_current formulas + for (const Expr& ineq : inequalities) { + add_to_new_current(NormalizeComparisons(SuperSimplify(ineq, vranges))); + } + + std::swap(current, new_current); + + for (const Var& v : variables) { + CHECK(!res.bounds.count(v.get())) << + "Variable " << v << " appears several times in the `variables` which might be a bug"; + + new_current.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (vranges.count(v)) { + const Range& range = vranges[v]; + Expr range_lbound = SuperSimplify(range->min, vranges); + Expr range_ubound = SuperSimplify(range->min + range->extent - 1, vranges); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + // Take formulas from `current` and classify them according to polarity wrt v + for (const Expr& ineq : current) { + if (const LE* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + coef_pos.push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQ* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 
0) { + // Equalities may be considered as pairs of two inequalities + coef_pos.push_back({coef0, coef[1]}); + coef_neg.push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos.push_back({-coef0, -coef[1]}); + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest.push_back(ineq); + } + + // Combine each positive inequality with each negative one (by adding them together) + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = gcd(pos.first, -neg.first); + Expr c_pos = make_const(v.type(), neg.first/first_gcd); + Expr c_neg = make_const(v.type(), pos.first/first_gcd); + Expr new_lhs = c_neg*neg.second - c_pos*pos.second; + Expr new_ineq = LE::make(new_lhs, make_zero(pos.second.type())); + new_ineq = NormalizeComparisons(SuperSimplify(new_ineq, vranges)); + add_to_new_current(new_ineq); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = lcm(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = lcm(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds stored in sorted vectors + std::vector upper_bounds; + std::vector lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + Expr bound = make_const(v.type(), -coef_lcm/pos.first)*pos.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + upper_bounds.erase( + std::remove_if(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); }), + upper_bounds.end()); + // Add + upper_bounds.push_back(bound); + } + for (const auto& neg : coef_neg) { + Expr bound = make_const(v.type(), -coef_lcm/neg.first)*neg.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + lower_bounds.erase( + std::remove_if(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); }), + lower_bounds.end()); + // Add + lower_bounds.push_back(bound); + } + + // Sort the vectors and remove duplicates + for (std::vector* bounds : {&upper_bounds, &lower_bounds}) { + std::sort(bounds->begin(), bounds->end(), ExprLess()); + bounds->erase(std::unique(bounds->begin(), bounds->end(), ExprEq()), bounds->end()); + } + + // Bounds which are both lower and upper should go to equal... + std::vector equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + std::set_intersection(upper_bounds.begin(), upper_bounds.end(), + lower_bounds.begin(), lower_bounds.end(), + std::back_inserter(equal), ExprLess()); + + // ...and be removed from upper bounds... 
+ std::vector new_upper; + new_upper.reserve(upper_bounds.size() - equal.size()); + std::set_difference(upper_bounds.begin(), upper_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_upper), ExprLess()); + + // ...and from lower bounds. + std::vector new_lower; + new_lower.reserve(lower_bounds.size() - equal.size()); + std::set_difference(lower_bounds.begin(), lower_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_lower), ExprLess()); + + // Write it to the result. + auto& bnds = res.bounds[v.get()]; + bnds.coef = make_const(v.type(), coef_lcm); + bnds.equal = equal; + bnds.lower = new_lower; + bnds.upper = new_upper; + + std::swap(current, new_current); + } + + // Everything that is left goes to res.other_conditions + for (const Expr& e : current) { + Expr e_simp = SuperSimplify(e, vranges); + if (is_const_int(e_simp, 0)) { + // contradiction detected + res.other_conditions = {const_false()}; + return res; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + res.other_conditions.push_back(e_simp); + } + } + + for (const Expr& e : rest) + res.other_conditions.push_back(e); + + return res; +} + + +// Simplify an iteration domain. +DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod) { + if (eliminate_div_mod) { + auto elim_res = EliminateDivMod(cond, vranges); + + Map new_vranges = elim_res.ranges; + Array new_axis = Concat(axis, elim_res.new_variables); + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + auto res = SimplifyDomain(new_cond, new_axis, new_vranges, false); + + Map new_old_to_new; + for (const Var& v : axis) { + new_old_to_new.Set(v, res.old_to_new[v]); + } + + Map new_new_to_old; + for (const auto& pair : res.new_to_old) { + new_new_to_old.Set(pair.first, Substitute(pair.second, elim_res.substitution)); + } + + res.old_to_new = std::move(new_old_to_new); + res.new_to_old = std::move(new_new_to_old); + + return res; + } + + auto factoratomic_res = FactorOutAtomicFormulas(cond); + std::vector& atomic_formulas = factoratomic_res.atomic_formulas; + Expr rest_of_cond = factoratomic_res.rest; + + // Put rest_of_cond into the vector of atomic formulas so that we don't forget about it. + // Although rest_of_cond is not atomic, the subsequent functions won't complain about it. + atomic_formulas.push_back(rest_of_cond); + + // vars are variables from axis followed by all the other variables from vranges + Array vars = axis; + for (const auto& pair : vranges) { + bool already = false; + for (const Var& v : vars) { + already = already || v.same_as(pair.first); + } + if (!already) { + vars.push_back(pair.first); + } + } + + auto solved_system = SolveSystemOfInequalities(atomic_formulas, vars, vranges); + + DomainSimplificationResult res; + std::unordered_map new_var_intsets; + + // Initialize new_var_intsets with the old var intsets + for (const auto& pair : vranges) { + new_var_intsets[pair.first.get()] = IntSet::range(pair.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = axis.rbegin(); it != axis.rend(); ++it) { + const Var& var = *it; + auto& bnd = solved_system.bounds[var.get()]; + // Note that we replace old vars with new ones + bnd = bnd.substitute(res.old_to_new); + if (is_one(bnd.coef) && !bnd.equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. 
+ // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + res.old_to_new.Set(var, bnd.equal[0]); + } else { + Array lowers = Concat(bnd.equal, bnd.lower); + Array uppers = Concat(bnd.equal, bnd.upper); + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v*coef, not for v (because we don't want complex expressions + // involving division). + + // The lower bound of the best pair so far + Expr best_lower = vranges[var]->min * bnd.coef; + // The difference between the upper and the lower of the best pair so far + Expr best_diff = (vranges[var]->extent - 1) * bnd.coef; + // The overapproximation of the best difference + Expr best_diff_over = best_diff; + + for (const Expr& low : lowers) { + for (const Expr& upp : uppers) { + Expr diff = SuperSimplify(upp - low, vranges); + // Since diff may depend on some other variables, we compute its overapproximation + Expr diff_over = EvalSet(diff, new_var_intsets).max(); + + if (diff_over.same_as(HalideIR::Internal::Interval::pos_inf)) { + continue; + } + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (CanProve(diff_over - best_diff_over < 0, vranges)) { + best_lower = low; + best_diff = diff; + best_diff_over = diff_over; + } + } + } + + if (is_const_int(best_diff, 0)) { + // In this case coef*iv = best_lower + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, SuperSimplify(best_lower / bnd.coef, vranges)); + // To assure correctness, we have to add a condition that best_lower can be divided by coef + res.conditions.push_back(SuperSimplify(best_lower % bnd.coef == 0, vranges)); + } else { + std::string suffix = Equal(best_lower, vranges[var]->min * bnd.coef) ? "" : ".shifted"; + Var new_var = var.copy_with_suffix(suffix); + + // We will replace our iv with new_var + shift. + // We use rounding-up division to compute shift. Since we want to use a single formula + // without selects in as many cases as possible, we try to prove conditions manually. 
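+        // (An illustrative check: for coef = 4 and best_lower = -5 the first branch
+        // gives -5/4 == -1 == ceil(-5/4), since C division truncates toward zero;
+        // for best_lower = 7 the second branch gives (7 + 3)/4 == 2 == ceil(7/4).)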
+        Expr shift;
+        if (CanProve(best_lower <= 0, vranges)) {
+          shift = best_lower / bnd.coef;
+        } else if (CanProve(best_lower > -bnd.coef, vranges)) {
+          shift = (best_lower + bnd.coef - 1)/bnd.coef;
+        } else {
+          shift = Select::make(best_lower <= -bnd.coef,
+                               best_lower / bnd.coef,
+                               (best_lower + bnd.coef - 1)/bnd.coef);
+        }
+        shift = SuperSimplify(shift, vranges);
+
+        Expr diff = SuperSimplify(best_diff_over / bnd.coef, vranges);
+
+        if (is_const_int(diff, 0)) {
+          // Don't create an itervar, just replace it everywhere with its min
+          res.old_to_new.Set(var, shift);
+        } else {
+          res.old_to_new.Set(var, new_var + shift);
+          // Note that we are substituting old with new, so best_lower contains new vars;
+          // that is why we have to substitute new with old in best_lower here.
+          res.new_to_old.Set(new_var,
+                             SuperSimplify(var - Substitute(shift, res.new_to_old), vranges));
+
+          new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.type()), diff);
+
+          // Add the new var to the resulting axis
+          auto range = Range(make_zero(new_var.type()), SuperSimplify(diff + 1, vranges));
+          res.axis.push_back(new_var);
+          res.ranges.Set(new_var, range);
+          vranges.Set(new_var, range);
+        }
+      }
+    }
+  }
+
+  // Add the original conditions (with variables substituted) to the resulting conditions
+  for (const Expr& old_cond : solved_system.as_conditions()) {
+    res.conditions.push_back(SuperSimplify(Substitute(old_cond, res.old_to_new), vranges));
+  }
+
+  // Reverse the axis so that it matches the order of the original variables
+  res.axis = Array<Var>(res.axis.rbegin(), res.axis.rend());
+
+  return res;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+Expr SimplifyReductionDomain(const Expr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const Reduce* red = expr.as<Reduce>()) {
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    auto res = SimplifyDomain(red->condition, IterVarsToVars(red->axis), vranges);
+
+    Array<Expr> new_source;
+    for (const Expr& src : red->source) {
+      new_source.push_back(Substitute(src, res.old_to_new));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res.axis, res.ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
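+    // (An empty reduction, e.g. one whose simplified condition is provably false,
+    // can be folded into an ordinary expression by Simplify.)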
+    return Simplify(Reduce::make(red->combiner, new_source, new_axis,
+                                 All(res.conditions), red->value_index));
+  } else {
+    return expr;
+  }
+}
+
+// Extract the given expr under the given condition as a separate tensor if the volume of the
+// extracted tensor will be less than the volume of the outer_axis
+Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond,
+                          const Array<Var>& outer_axis,
+                          const Map<Var, Range>& vranges) {
+  // TODO(sgrechanik-h): We don't use divmod elimination here because of some
+  // performance problems
+  auto res = SimplifyDomain(cond, outer_axis, vranges, false);
+
+  Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), Merge(vranges, res.ranges));
+  // This is done mostly to simplify if_then_else, which is not known to the Halide simplifier
+  new_expr = RemoveRedundantInequalities(new_expr, res.conditions);
+
+  // Keep only those variables of the new axis which are used in the new_expr
+  {
+    Array<Var> used_res_axis;
+    for (const Var& var : res.axis) {
+      if (ExprUseVar(new_expr, var)) {
+        used_res_axis.push_back(var);
+      }
+    }
+    res.axis = std::move(used_res_axis);
+  }
+
+  // If the expression does not use vars then it is probably better to keep it inlined
+  if (res.axis.empty()) {
+    // We can return new_expr here instead of the old e because it doesn't use variables;
+    // otherwise we would need to replace the new vars or create a let-expression.
+    return new_expr;
+  }
+
+  // If it's already a call to a tensor then extracting it will probably be useless
+  if (const Call* call = new_expr.as<Call>()) {
+    if (call->call_type == Call::CallType::Halide) {
+      return e;
+    }
+  }
+
+  // Compute volumes before and after
+  Expr old_volume = make_const(Int(64), 1);
+  for (const Var& var : outer_axis) {
+    old_volume = old_volume * vranges[var]->extent;
+  }
+
+  Expr new_volume = make_const(Int(64), 1);
+  for (const Var& var : res.axis) {
+    new_volume = new_volume * res.ranges[var]->extent;
+  }
+
+  // If we can prove that the old volume is not greater than the new volume, then
+  // prefer the old expression.
+  if (CanProve(old_volume <= new_volume, vranges)) {
+    return e;
+  }
+
+  Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges),
+                                     "extracted_tensor");
+
+  Array<Expr> args;
+  for (const Var& var : res.axis) {
+    args.push_back(res.new_to_old[var]);
+  }
+
+  return Call::make(e.type(), tensor->op->name, args,
+                    Call::CallType::Halide, tensor->op, tensor->value_index);
+}
+
+
+// Extract from cond an implication of cond not containing vars
+std::pair<Expr, Expr> ImplicationNotContainingVars(
+    const Expr& cond, const std::unordered_set<const Variable*>& vars) {
+  CHECK(cond.type().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): Not is not supported yet
+  if (const And* op = cond.as<And>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first && pair_b.first,
+            pair_a.second && pair_b.second};
+  } else if (const Or* op = cond.as<Or>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first || pair_b.first,
+            (pair_a.first || pair_b.second) &&
+            (pair_b.first || pair_a.second) &&
+            (pair_a.second || pair_b.second)};
+  } else if (!ExprUseVar(cond, vars)) {
+    return {cond, const_true()};
+  } else {
+    return {const_true(), cond};
+  }
+}
+
+// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
+// (in)equalities which do not depend on the reduction variables.
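+// For example, in sum((i == j)*(k < 3)*A[i, k], k) the factor (i == j) does not mention
+// the reduction variable k, so it can be moved out of the reduction, letting the caller
+// produce Select(i == j, sum((k < 3)*A[i, k], k), 0).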
+std::pair<Expr, Expr> LiftConditionsThroughReduction(const Expr& cond,
+                                                     const Array<IterVar>& red_axis,
+                                                     const Array<IterVar>& outer_axis) {
+  // Factor out atomics so that we can consider this as a system of inequalities
+  auto factoratomic_res = FactorOutAtomicFormulas(cond);
+  Array<Expr> atomics = factoratomic_res.atomic_formulas;
+  const Expr& rest = factoratomic_res.rest;
+
+  Array<Var> allvars;
+  for (const IterVar& v : red_axis) {
+    allvars.push_back(v->var);
+  }
+  for (const IterVar& v : outer_axis) {
+    allvars.push_back(v->var);
+  }
+
+  auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
+  // Start from reduction vars, so that input vars don't depend on them
+  atomics = SolveSystemOfInequalities(atomics, allvars, vranges).as_conditions();
+
+  // Append the rest part
+  Expr rewritten_cond = All(atomics) && rest;
+
+  std::unordered_set<const Variable*> vset;
+  for (const IterVar& v : red_axis) {
+    vset.insert(v->var.get());
+  }
+
+  // The outer (first) condition does not contain reduction vars,
+  // the inner (second) condition is everything else
+  return ImplicationNotContainingVars(rewritten_cond, vset);
+}
+
+class ExtractReductionsMutator : public IRMutator {
+ public:
+  explicit ExtractReductionsMutator(const Array<Var>& outer_axis,
+                                    Map<Var, Range> vranges,
+                                    std::string name = "extracted_reduction")
+    : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {}
+
+  Expr Mutate_(const Reduce* op, const Expr& e) {
+    ExtractReductionsMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
+                                         Merge(vranges_, IterVarsToMap(op->axis)),
+                                         name_);
+
+    Array<Expr> new_source;
+    for (const Expr& src : op->source) {
+      new_source.push_back(new_mutator.Mutate(src));
+    }
+
+    Expr new_reduce =
+      Reduce::make(op->combiner, new_source, op->axis, op->condition, op->value_index);
+
+    ExprFreeVarsVisitor fv_visitor;
+    fv_visitor.Visit(new_reduce);
+
+    // Vars of the tensor we are going to create for this reduction
+    Array<Var> vars;
+    for (const Var& v : outer_axis_) {
+      // We take variables from the outer_axis_ which are also present in the new reduction
+      if (fv_visitor.free.count(v.get())) {
+        vars.push_back(v);
+      }
+    }
+
+    auto newaxis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_));
+    Array<IterVar> new_axis = newaxis_vmap_pair.first;
+    new_reduce = SuperSimplify(Substitute(new_reduce, newaxis_vmap_pair.second),
+                               IterVarsToMap(new_axis));
+
+    Tensor tensor = op::TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_);
+
+    Array<Expr> args;
+    for (const Var& v : vars) {
+      args.push_back(v);
+    }
+
+    return Call::make(e.type(), tensor->op->name, args,
+                      Call::CallType::Halide, tensor->op, tensor->value_index);
+  }
+
+ private:
+  Array<Var> outer_axis_;
+  Map<Var, Range> vranges_;
+  std::string name_;
+  std::string tag_;
+  Map<std::string, NodeRef> attrs_;
+};
+
+// Extract reductions as separate tensors.
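+// For example, sum(f(i, k) + xor(g(k, l), l), k) contains the nested reduction
+// xor(g(k, l), l); it is extracted into a new tensor B(k) = xor(g(k, l), l), and the
+// expression becomes sum(f(i, k) + B(k), k) (cf. test_extract_reductions below).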
+Expr ExtractReductions(const Expr& expr,
+                       const Array<Var>& outer_axis,
+                       const Map<Var, Range>& vranges) {
+  return ExtractReductionsMutator(outer_axis, vranges).Mutate(expr);
+}
+
+Expr ExtractNonTopReductions(const Expr& expr,
+                             const Array<Var>& outer_axis,
+                             const Map<Var, Range>& vranges) {
+  if (const Reduce* red = expr.as<Reduce>()) {
+    Array<Var> new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis);
+    Map<Var, Range> new_vranges = Merge(vranges, IterVarsToMap(red->axis));
+    Array<Expr> new_source;
+    for (const Expr& src : red->source) {
+      new_source.push_back(ExtractReductions(src, new_outer_axis, new_vranges));
+    }
+    Expr new_condition = ExtractReductions(red->condition, new_outer_axis, new_vranges);
+
+    return Reduce::make(red->combiner, new_source, red->axis,
+                        new_condition, red->value_index);
+  } else {
+    return ExtractReductions(expr, outer_axis, vranges);
+  }
+}
+
+Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr,
+                                              const Array<IterVar>& axis,
+                                              const Map<Var, Range>& vranges) {
+  Expr result;
+  Map<Var, Range> combined_vranges = Merge(vranges, IterVarsToMap(axis));
+
+  if (const Reduce* red = expr.as<Reduce>()) {
+    // TODO(sgrechanik-h): There are some other operations which behave like sum
+    bool is_sum = IsSumCombiner(red->combiner, vranges);
+    if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
+      Expr new_red = expr;
+
+      // Here we simplify the reduction
+      {
+        Expr cond = red->condition;
+        Array<Expr> source = red->source;
+
+        // If it is a summation then we can lift nonzeroness conditions from the source
+        // and add them to the reduction conditions
+        if (is_sum) {
+          auto nz = NonzeronessCondition(red->source[red->value_index]);
+          cond = nz.cond && cond;
+          source.Set(0, nz.value);
+        }
+
+        new_red = Reduce::make(red->combiner, source, red->axis, cond, red->value_index);
+        new_red = SimplifyReductionDomain(new_red, combined_vranges);
+        red = new_red.as<Reduce>();
+
+        // If the reduction disappears completely then transform the result as a non-reduction
+        if (!red) {
+          return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis, vranges);
+        }
+      }
+
+      Expr new_outer_cond, new_reduce_cond;
+      Array<Expr> new_source = red->source;
+
+      // Partially lift conditions from the reduce condition
+      std::tie(new_outer_cond, new_reduce_cond) =
+        LiftConditionsThroughReduction(red->condition, red->axis, axis);
+
+      // If it's not a sum then we haven't yet lifted the nonzeroness cond from the source
+      if (!is_sum) {
+        Expr outer_nz_cond, nz_cond, nz_source;
+        auto nz = NonzeronessCondition(red->source[red->value_index]);
+        // Append conditions from the reduction
+        nz_cond = new_reduce_cond && nz.cond;
+        nz_source = nz.value;
+        std::tie(outer_nz_cond, nz_cond) =
+          LiftConditionsThroughReduction(nz_cond, red->axis, axis);
+        new_outer_cond = new_outer_cond && outer_nz_cond;
+        new_source.Set(red->value_index, SelectElseZero(nz_cond, nz_source));
+      }
+
+      Expr new_reduce = Reduce::make(red->combiner, new_source, red->axis,
+                                     new_reduce_cond, red->value_index);
+      new_reduce = ExtractAsTensorMaybe(new_reduce, new_outer_cond,
+                                        IterVarsToVars(axis),
+                                        combined_vranges);
+      result = SelectElseZero(new_outer_cond, new_reduce);
+    } else {
+      return SimplifyReductionDomain(expr, combined_vranges);
+    }
+  } else {
+    auto nz = NonzeronessCondition(expr);
+    Expr new_expr = ExtractAsTensorMaybe(nz.value, nz.cond,
+                                         IterVarsToVars(axis),
+                                         combined_vranges);
+    result = SelectElseZero(nz.cond, new_expr);
+  }
+
+  // Note that RemoveRedundantInequalities can sometimes propagate equalities which
+  // other simplifiers cannot, like (i % 3) == 0.
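+  // (Given the known condition (i % 3) == 0, it can rewrite occurrences of i % 3
+  // inside `result` to 0.)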
+  Array<Expr> axis_conds = IterVarsToInequalities(axis);
+  result = RemoveRedundantInequalities(result, axis_conds);
+
+  // Sometimes ExtractAsTensorMaybe doesn't perform extraction, so there may be some
+  // non-top reductions left; take care of them.
+  return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), combined_vranges),
+                       combined_vranges);
+}
+
+Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor,
+                                            const Map<Var, Range>& vranges) {
+  auto transform_func = [&vranges](const Expr& expr, const Array<IterVar>& axis) {
+    return OptimizeAndLiftNonzeronessConditionsImpl(expr, axis, vranges);
+  };
+  return op::TransformBody(tensor, transform_func);
+}
+
+TVM_REGISTER_API("ir_pass.IsSumCombiner")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() >= 2) {
+      *ret = IsSumCombiner(args[0], args[1]);
+    } else {
+      *ret = IsSumCombiner(args[0]);
+    }
+  });
+
+TVM_REGISTER_API("ir_pass.CanFactorZeroFromCombiner")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() >= 3) {
+      *ret = CanFactorZeroFromCombiner(args[0], args[1], args[2]);
+    } else {
+      *ret = CanFactorZeroFromCombiner(args[0], args[1]);
+    }
+  });
+
+TVM_REGISTER_API("ir_pass.LiftNonzeronessCondition")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = LiftNonzeronessCondition(args[0]);
+  });
+
+TVM_REGISTER_API("ir_pass.InlineTailCall")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = InlineTailCall(args[0]);
+  });
+
+TVM_REGISTER_API("ir_pass.InlineTensors")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsNodeType<Expr>()) {
+      Expr e = args[0];
+      if (args.size() == 1) {
+        *ret = InlineTensors(e);
+      } else if (args.size() == 2) {
+        *ret = InlineTensors(e, args[1]);
+      } else if (args.size() >= 3) {
+        *ret = InlineTensors(e, args[1], args[2]);
+      }
+    } else if (args[0].IsNodeType<Tensor>()) {
+      Tensor t = args[0];
+      if (args.size() == 1) {
+        *ret = InlineTensors(t);
+      } else if (args.size() == 2) {
+        *ret = InlineTensors(t, args[1]);
+      } else if (args.size() >= 3) {
+        *ret = InlineTensors(t, args[1], args[2]);
+      }
+    }
+  });
+
+TVM_REGISTER_API("ir_pass.SolveSystemOfInequalities")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = SolveSystemOfInequalities(args[0], args[1], args[2]).as_conditions();
+  });
+
+TVM_REGISTER_API("ir_pass.SimplifyDomain")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    auto res = SimplifyDomain(args[0], args[1], args[2]);
+    Array<IterVar> axis = IterVarsFromMap(res.axis, res.ranges);
+    *ret = Array<NodeRef>({All(res.conditions), axis, res.old_to_new, res.new_to_old});
+  });
+
+TVM_REGISTER_API("ir_pass.SimplifyReductionDomain")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = SimplifyReductionDomain(args[0], args[1]);
+  });
+
+TVM_REGISTER_API("ir_pass.ExtractAsTensorMaybe")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = ExtractAsTensorMaybe(args[0], args[1], args[2], args[3]);
+  });
+
+TVM_REGISTER_API("ir_pass.ExtractReductions")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = ExtractReductions(args[0], args[1], args[2]);
+  });
+
+TVM_REGISTER_API("ir_pass.ExtractNonTopReductions")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = ExtractNonTopReductions(args[0], args[1], args[2]);
+  });
+
+TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() >= 2) {
+      *ret = OptimizeAndLiftNonzeronessConditions(args[0], args[1]);
+    } else {
+      *ret = OptimizeAndLiftNonzeronessConditions(args[0]);
+    }
+  });
+
+}  // namespace ir
+}  // namespace tvm
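The TVM_REGISTER_API wrappers above make the passes callable from Python. A minimal, hypothetical sketch of how they are driven (it uses only the wrappers registered above and the same tvm API that the unit tests at the end of this diff rely on):

    import tvm
    from tvm.ir_pass import SimplifyDomain, LiftNonzeronessCondition

    i = tvm.var("i")
    j = tvm.var("j")
    vranges = {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}

    # SimplifyDomain's wrapper packs its result into
    # (new condition, new axis, old_to_new, new_to_old); see the registration above.
    new_cond, new_axis, old_to_new, new_to_old = SimplifyDomain(j <= i, [j], vranges)

    # LiftNonzeronessCondition rewrites an expression into Select(cond, value, 0)
    # with the nonzeroness condition at the top level.
    A = tvm.placeholder((10,), name="A")
    lifted = LiftNonzeronessCondition(tvm.expr.Select(j % 2 == 0, A[j], 0.0))

Here [j] is the domain axis being simplified, while i stays a free (outer) variable constrained only through vranges.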
diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h
new file mode 100644
index 000000000000..600b3cb4162f
--- /dev/null
+++ b/src/pass/zero_elimination.h
@@ -0,0 +1,248 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file zero_elimination.h
+ * \brief Transform tensors in such a way as to eliminate summation over zeros.
+ */
+#ifndef TVM_PASS_ZERO_ELIMINATION_H_
+#define TVM_PASS_ZERO_ELIMINATION_H_
+
+#include <tvm/ir.h>
+#include <tvm/tensor.h>
+
+#include <string>
+#include <unordered_map>
+
+namespace tvm {
+namespace ir {
+
+/*!
+ * \brief Clone the reduction by cloning its iteration variables.
+ */
+Expr CloneReduction(const Expr& expr);
+
+/*!
+ * \brief Check if the given combiner represents summation.
+ */
+EXPORT bool IsSumCombiner(const CommReducer& combiner,
+                          const Map<Var, Range>& vranges = Map<Var, Range>());
+
+/*!
+ * \brief Check if zero may be factored out of a reduction with this combiner when it is in
+ *  the \p value_index position.
+ *
+ *  For example, if the combiner works on tuples of two elements and `value_index = 1`,
+ *  check that `(a, 0) combine (b, 0) = (c, 0)` for any a, b and some c.
+ *  Note that all combiners generated by autodiff have this property.
+ */
+EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                                      const Map<Var, Range>& vranges = Map<Var, Range>());
+
+/*!
+ * \brief Transform the expression into `c ? e : 0`, that is, lift the condition of being
+ *  possibly non-zero to the top level.
+ */
+EXPORT Expr LiftNonzeronessCondition(const Expr& expr);
+
+/*!
+ * \brief If the body of the tensor consists of a single tensor call (indexing) expression,
+ *  inline it.
+ */
+EXPORT Tensor InlineTailCall(const Tensor& tensor);
+
+/*!
+ * \brief Inline tensors recursively.
+ *
+ *  This function will inline tensors recursively until it reaches a tensor which is
+ *  impossible to inline (a reduction if \p inline_reductions is false, a non-compute tensor,
+ *  a tensor which is not from \p inlineable). It won't descend into non-inlinable tensors'
+ *  bodies.
+ *
+ * \param expr The expression to transform.
+ * \param inlineable A list of tensors which are allowed to be inlined. If empty, try
+ *  to inline all tensors.
+ * \param inline_reductions Whether to inline reductions (this may result in top-level
+ *  reduction nodes).
+ */
+EXPORT Expr InlineTensors(const Expr& expr,
+                          const Array<Tensor>& inlineable = Array<Tensor>(),
+                          bool inline_reductions = false);
+
+/*!
+ * \brief Inline tensors recursively.
+ *
+ *  This function will inline tensors recursively until it reaches a tensor which is
+ *  impossible to inline (a reduction if \p inline_reductions is false, a non-compute tensor,
+ *  a tensor which is not from \p inlineable). It won't descend into non-inlinable tensors'
+ *  bodies.
+ *
+ * \param tensor The tensor whose body to transform.
+ * \param inlineable A list of tensors which are allowed to be inlined. If empty, try
+ *  to inline all tensors.
+ * \param inline_reductions Whether to inline reductions (this may result in top-level
+ *  reduction nodes).
+ */
+EXPORT Tensor InlineTensors(const Tensor& tensor,
+                            const Array<Tensor>& inlineable = Array<Tensor>(),
+                            bool inline_reductions = false);
+
+
+/*!
+ * \brief A struct representing a set of inequalities describing bounds of a variable.
+ *
+ *  Given a variable x, this struct represents the following (in)equalities:
+ *  - `coef*x >= low` for each `low` in `lower`
+ *  - `coef*x == eq` for each `eq` in `equal`
+ *  - `coef*x <= upp` for each `upp` in `upper`
+ *
+ *  Note that every array is supposed to be sorted in the order of increasing expression
+ *  complexity.
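+ *
+ *  For example, the system {x >= 1, 2*x <= 7, 2*x == 6} is stored (after multiplying
+ *  everything by the common coefficient 2) as coef = 2, lower = [2], equal = [6],
+ *  upper = [7].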
+ */
+struct VarBounds {
+  Expr coef;
+  Array<Expr> lower;
+  Array<Expr> equal;
+  Array<Expr> upper;
+
+  /*!
+   * \brief Perform substitution on all components of the struct.
+   */
+  VarBounds substitute(const Map<Var, Expr>& subst) const;
+};
+
+/*!
+ * \brief A struct representing a system of inequalities resulting from Fourier-Motzkin
+ *  elimination.
+ */
+struct SolveSystemOfInequalitiesResult {
+  Array<Var> variables;
+  std::unordered_map<const Variable*, VarBounds> bounds;
+  Array<Expr> other_conditions;
+
+  /*!
+   * \brief Combine the information into an array of (in)equalities.
+   */
+  Array<Expr> as_conditions() const;
+};
+
+/*!
+ * \brief Rewrite the system of inequalities using Fourier-Motzkin elimination.
+ *
+ *  This function takes an array of (in)equalities and an array of variables, and essentially
+ *  rewrites the (in)equalities into an array of (in)equalities of the following form:
+ *
+ *  x0 >= f0(x1, x2, ..., xn)
+ *  x0 <= g0(x1, x2, ..., xn)
+ *  x1 >= f1(x2, ..., xn)
+ *  x1 <= g1(x2, ..., xn)
+ *  ...
+ *  xn >= fn()  // just a constant
+ *  xn <= gn()  // just a constant
+ *
+ *  This array is represented in a more structured way using
+ *  SolveSystemOfInequalitiesResult.
+ *
+ *  Note that the algorithm is extremely slow: its complexity is super-exponential, so
+ *  please provide variable ranges to aid the removal of redundant inequalities.
+ *
+ * \param inequalities The original (in)equalities.
+ * \param variables The variables x0, ..., xn
+ * \param vranges A map from variables to the corresponding value ranges. Extremely
+ *  important for efficiency.
+ */
+EXPORT SolveSystemOfInequalitiesResult SolveSystemOfInequalities(
+    const Array<Expr>& inequalities, const Array<Var>& variables,
+    const Map<Var, Range>& vranges);
+
+/*!
+ * \brief A struct representing a result of domain simplification. It is basically
+ *  a new array of variables, the information about their ranges, and a new condition,
+ *  together with substitutions from the old variables to the new ones and from the new
+ *  ones to the old ones.
+ */
+struct DomainSimplificationResult {
+  Array<Expr> conditions;
+  Array<Var> axis;
+  Map<Var, Range> ranges;
+  Map<Var, Expr> old_to_new;
+  Map<Var, Expr> new_to_old;
+};
+
+/*!
+ * \brief Simplify an iteration domain.
+ *
+ *  An iteration domain is basically an array of variables and a condition. The function
+ *  will do the following:
+ *  - Replace div and mod operations with new variables (optional).
+ *  - Extract (in)equalities from the condition.
+ *  - Perform Fourier-Motzkin elimination.
+ *  - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with
+ *    `y + d` where `d` is a new variable such that `0 <= d <= 2`).
+ *  - Remove redundant variables.
+ *  - Infer new variable ranges (hopefully more precise).
+ *
+ * \param cond The condition of the original domain.
+ * \param axis The variables of the original domain.
+ * \param vranges A map from variables (both domain and outer) to their value ranges.
+ * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables.
+ */
+EXPORT DomainSimplificationResult SimplifyDomain(const Expr& cond,
+                                                 const Array<Var>& axis,
+                                                 Map<Var, Range> vranges,
+                                                 bool eliminate_div_mod = true);
+
+
+/*!
+ * \brief Simplify the iteration domain of a reduction expression using SimplifyDomain.
+ */
+EXPORT Expr SimplifyReductionDomain(const Expr& expr, const Map<Var, Range>& outer_vranges);
+
+/*!
+ * \brief Extract the given expression under the given condition as a separate tensor if the
+ *  volume of the extracted tensor will be less than the volume of the \p outer_axis.
+ *
+ * \param expr The expression to extract.
+ * \param cond A condition which is assumed to be true.
+ * \param outer_axis Some variables, usually input variables of the enclosing tensor.
+ * \param vranges Information about the ranges of variables.
+ * \return Either a call to an extracted tensor or the original expression.
+ */
+EXPORT Expr ExtractAsTensorMaybe(const Expr& expr, const Expr& cond,
+                                 const Array<Var>& outer_axis,
+                                 const Map<Var, Range>& vranges);
+
+/*!
+ * \brief Extract reductions as separate tensors. This may be needed when non-top-level
+ *  reductions are created.
+ *
+ * \param expr The expression from which to extract reductions.
+ * \param outer_axis Input variables of the enclosing tensor.
+ * \param vranges Information about the ranges of variables.
+ * \return An expression without non-top-level reductions.
+ */
+EXPORT Expr ExtractReductions(const Expr& expr,
+                              const Array<Var>& outer_axis,
+                              const Map<Var, Range>& vranges);
+
+/*!
+ * \brief Extract reductions as separate tensors, but if the expr itself is a reduction,
+ *  leave it intact.
+ *
+ * \param expr The expression from which to extract reductions.
+ * \param outer_axis Input variables of the enclosing tensor.
+ * \param vranges Information about the ranges of variables.
+ * \return An expression without non-top-level reductions.
+ */
+EXPORT Expr ExtractNonTopReductions(const Expr& expr,
+                                    const Array<Var>& outer_axis,
+                                    const Map<Var, Range>& vranges);
+
+/*!
+ * \brief Perform lifting of conditions of being possibly non-zero together with
+ *  applying some transformations like simplifying the reduction domain. Works only with
+ *  this particular tensor's body, i.e. doesn't perform inlining.
+ *
+ * \param tensor The original tensor.
+ * \param vranges An optional map from free variables to their value ranges.
+ * \return An optimized tensor.
+ */
+EXPORT Tensor OptimizeAndLiftNonzeronessConditions(
+    const Tensor& tensor,
+    const Map<Var, Range>& vranges = Map<Var, Range>());
+
+}  // namespace ir
+}  // namespace tvm
+#endif  // TVM_PASS_ZERO_ELIMINATION_H_
diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py
new file mode 100644
index 000000000000..cba13746315e
--- /dev/null
+++ b/tests/python/unittest/test_pass_zero_elimination.py
@@ -0,0 +1,528 @@
+import random
+import sys
+import numpy as np
+import tvm
+from tvm import comm_reducer
+from tvm.testing import estimate_performance
+from tvm.ir_pass import Simplify, Equal, LiftNonzeronessCondition, IsSumCombiner, \
+    CanFactorZeroFromCombiner, InlineTailCall, InlineTensors, SolveSystemOfInequalities, \
+    SimplifyDomain, SimplifyReductionDomain, ExtractAsTensorMaybe, ExtractReductions, \
+    ExtractNonTopReductions, OptimizeAndLiftNonzeronessConditions
+
+def get_shape(tensor):
+    return [s.value for s in tensor.shape]
+
+def check_eq(t1, t2, args):
+    s1 = tvm.create_schedule(t1.op)
+    m1 = tvm.build(s1, [t1] + args)
+
+    s2 = tvm.create_schedule(t2.op)
+    m2 = tvm.build(s2, [t2] + args)
+
+    for _ in range(5):
+        arg_vals = [tvm.ndarray.array(np.random.uniform(-10, 10, size=get_shape(a))
+                                      .astype(a.dtype))
+                    for a in [t1] + args]
+        m1(*arg_vals)
+        res1 = arg_vals[0].asnumpy()
+        m2(*arg_vals)
+        res2 = arg_vals[0].asnumpy()
+
+        np.testing.assert_allclose(res1, res2, atol=1e-3, rtol=1e-2)
+
+def check_symeq(expr1, expr2):
+    expr1 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1))
+    expr2 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr2))
+
+    if tvm.ir_pass.Equal(expr1, expr2):
+        return
+
+    diff = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1 - expr2))
+    if not Equal(diff, tvm.const(0, expr1.dtype)):
+        raise AssertionError("Expressions {} and {} are not equal, their diff is {}"
+                             .format(expr1, expr2, diff))
+
+def compute(shape, fcompute):
+    """Like tvm.compute but automatically extracts reductions."""
+    return tvm.compute(shape,
+                       lambda *vs: ExtractNonTopReductions(
+                           fcompute(*vs), vs, {v: tvm.Range(0, s) for v, s in zip(vs, shape)}))
+
+def check_tensor_symeq(A, B):
+    if not isinstance(B, tvm.tensor.Tensor):
+        B = compute(A.shape, B)
+    vmap = {a.var: b.var for a, b in zip(A.op.axis, B.op.axis)}
+    expr_a = tvm.ir_pass.Substitute(A.op.body[A.value_index], vmap)
+    expr_b = B.op.body[B.value_index]
+    expr_a = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_a, [], True))
+    expr_b = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_b, [], True))
+    if not Equal(expr_a, expr_b):
+        print(expr_a)
+        print(expr_b)
+        raise AssertionError("The expressions are not equal")
+
+def check_eq_bruteforce(expr1, expr2, vranges):
+    def _compute_body(*us):
+        vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
+        return tvm.ir_pass.Substitute(expr1 == expr2, vmap)
+
+    A = compute([r.extent.value for v, r in vranges.items()], _compute_body)
+    args = [tvm.ndarray.empty(A.shape, A.dtype)]
+    sch = tvm.create_schedule(A.op)
+    mod = tvm.build(sch, [A])
+    mod(*args)
+    res = args[0].asnumpy()
+    if not np.all(res):
+        indices = list(np.argwhere(res == 0)[0])
+        counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
+        counterex = ", ".join([v + " = " + str(i) for v, i in sorted(counterex)])
+        raise AssertionError("Expressions {}\nand {}\nare not equal on {}\n"
+                             "Counterexample: {}"
+                             .format(expr1, expr2, vranges, counterex))
+
+prod_combiner = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
+sum_combiner = comm_reducer(lambda x, y: x + y, lambda t0: tvm.const(0, t0))
+sum2_combiner = comm_reducer(lambda x, y: y + x, lambda t0: tvm.const(0, t0))
+sum_derivative_combiner = comm_reducer(lambda x, y: (x[0] + y[0], y[1] + x[1]),
+                                       lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1)))
+prod_derivative_combiner = comm_reducer(lambda x, y: (x[0]*y[0], x[0]*y[1] + x[1]*y[0]),
+                                        lambda t0, t1: (tvm.const(1, t0), tvm.const(0, t1)))
+sum_both_combiner = comm_reducer(lambda x, y: (x[0] + y[0], x[0] + y[0] + x[1] + y[1]),
+                                 lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1)))
+xor_combiner = comm_reducer(lambda x, y: x ^ y, lambda t0: tvm.const(0, t0))
+
+m_param = tvm.var('m_param')
+sum_or_prod_combiner = comm_reducer(lambda x, y: tvm.expr.Select(m_param < 0, x + y, x*y),
+                                    lambda t0: tvm.expr.Select(m_param < 0,
+                                                               tvm.const(0, t0),
+                                                               tvm.const(1, t0)))
+shifted_sum_combiner = comm_reducer(lambda x, y: x + y - m_param,
+                                    lambda t0: m_param)
+
+def test_is_sum_combiner():
+    k = tvm.reduce_axis((0, 10), name="k")
+    i = tvm.const(0, "int32")
+    f = tvm.const(0.0, "float32")
+    assert IsSumCombiner(sum_combiner(i, k).combiner)
+    assert IsSumCombiner(sum_combiner(f, k).combiner)
+    assert IsSumCombiner(sum2_combiner(i, k).combiner)
+    assert IsSumCombiner(sum2_combiner(f, k).combiner)
+    assert not IsSumCombiner(sum_derivative_combiner((f, f), k)[0].combiner)
+    assert not IsSumCombiner(prod_combiner(f, k).combiner)
+    assert not IsSumCombiner(prod_derivative_combiner((f, f), k)[1].combiner)
+    assert not IsSumCombiner(sum_or_prod_combiner(f, k).combiner)
+    assert not IsSumCombiner(sum_or_prod_combiner(f, k).combiner, {m_param: tvm.Range(-5, 1)})
+    assert IsSumCombiner(sum_or_prod_combiner(f, k).combiner, {m_param: tvm.Range(-5, -1)})
+    assert not IsSumCombiner(shifted_sum_combiner(i, k).combiner)
+    assert IsSumCombiner(shifted_sum_combiner(i, k).combiner, {m_param: tvm.Range(0, 1)})
+
+def test_can_factor_zero_from_combiner():
+    k = tvm.reduce_axis((0, 10), name="k")
+    i = tvm.const(0, "int32")
+    f = tvm.const(0.0, "float32")
+    assert CanFactorZeroFromCombiner(sum_combiner(i, k).combiner, 0)
+    assert CanFactorZeroFromCombiner(sum2_combiner(f, k).combiner, 0)
+    assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 0)
+    assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 1)
+    assert not CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 0)
+    assert CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 1)
+    assert CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 0)
+    assert not CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 1)
+    assert not CanFactorZeroFromCombiner(sum_or_prod_combiner(f, k).combiner, 0,
+                                         {m_param: tvm.Range(-5, 1)})
+    assert CanFactorZeroFromCombiner(sum_or_prod_combiner(f, k).combiner, 0,
+                                     {m_param: tvm.Range(-5, -1)})
+    assert not CanFactorZeroFromCombiner(shifted_sum_combiner(i, k).combiner, 0)
+    assert CanFactorZeroFromCombiner(shifted_sum_combiner(i, k).combiner, 0,
+                                     {m_param: tvm.Range(0, 1)})
+
+def test_lift_nonzeroness_condition():
+    k = tvm.reduce_axis((0, 5), name="k")
+    l = tvm.reduce_axis((0, 5), name="l")
+    n = tvm.reduce_axis((0, 5), name="n")
+    A = tvm.placeholder((10,), name='A')
+
+    def _check(shape, fun, A=A):
+        T1 = tvm.compute(shape, fun)
+        T2 = tvm.compute(shape, lambda *args: LiftNonzeronessCondition(fun(*args)))
+        check_eq(T1, T2, [A])
+        assert isinstance(T2.op.body[0], tvm.expr.Select)
+
+    _check((10,), lambda i: A[i])
+    _check((10,), lambda i: A[i] + (i % 2 == 0))
+    _check((10,), lambda i: A[i]*(i % 2 == 0) + (i % 2 == 0))
+    _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0))
+    _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0) + (i % 2 == 0))
+    _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), 0.0, A[i]) + (i % 2 == 0))
+
+    def e1(i): return tvm.expr.Select((i % 2 == 1), 0.0, A[i])
+    def e2(i): return tvm.expr.Select((i % 2 == 0), A[(i + 1) % 10], 0.0)
+    def e3(i): return tvm.expr.Select((i % 2 == 1), A[i], 0.0)
+    _check((10,), lambda i: e1(i) + e2(i) + e3(i) + e1(i)*e2(i))
+    _check((10,), lambda i: e1(i)*e3(i))
+    _check((10,), lambda i: e1(i)*e2(i))
+    _check((10, 10), lambda i, j: A[i]*(i == j) + A[j]*(i == 2*j) + A[j]*(j == i))
+    _check((10, 10), lambda i, j: tvm.min(A[i]*(i == j), A[j]*(i == 2*j)))
+    _check((10, 10), lambda i, j: tvm.max(A[i]*(i == j), A[j]*(i == 2*j)))
+    _check((10, 10), lambda i, j: A[i]*(i == j) - A[j]*(i == 2*j))
+    _check((10, 10), lambda i, j: A[i]*(i == j) / (1 + tvm.abs(A[j]*(i == 2*j))))
+    _check((10, 10), lambda i, j: i*(i < j) + j*(i > j))
+    _check((10, 10), lambda i, j: i*(i < j) % (1 + j*(i > j)))
+
+    def _check_symeq(expr1, expr2):
+        expr1 = LiftNonzeronessCondition(expr1)
+        expr2 = LiftNonzeronessCondition(expr2)
+        print(expr1)
+        print(expr2)
+        print()
+        check_symeq(expr1, expr2)
+
+    _check_symeq(tvm.expr.Select(tvm.expr.EQ(k, l), 0.0, tvm.expr.Cast('float32', (k < n))),
+                 tvm.expr.Select(tvm.expr.And((k < n), tvm.expr.NE(k, l)), 1.0, 0.0))
+    _check_symeq(tvm.min(tvm.expr.Cast('int32', k < n)*l, tvm.expr.Select(k >= n, 0, 1)),
+                 tvm.expr.Select(k < n, tvm.min(l, 1), 0))
+
+    expr1 = tvm.if_then_else(k < n,
+                             tvm.expr.Select(tvm.expr.EQ(k, l), A[k], 0.0),
+                             tvm.expr.Select(l < n, A[l], 0.0))
+    expr2 = tvm.expr.Select(tvm.any(tvm.all(k < n, tvm.expr.EQ(k, l)),
+                                    tvm.all(k >= n, l < n)),
+                            tvm.if_then_else(k < n, A[k], A[l]),
+                            0.0)
+    check_symeq(LiftNonzeronessCondition(expr1), expr2)
+
+def test_inline_tail_call():
+    A = tvm.compute((10, 10), lambda i, j: i + j*j)
+    B = tvm.compute((5, 6), lambda k, l: A[k + l, k + 1])
+    C = InlineTailCall(B)
+    resbody = lambda k, l: k + l + (k + 1)*(k + 1)
+    check_symeq(C.op.body[0], resbody(*[iv.var for iv in C.op.axis]))
+
+def test_inline_tensors():
+    A = tvm.compute((10, 10), lambda i, j: i + j)
+    B = tvm.compute((10, 10), lambda i, j: i * j)
+    C = tvm.compute((10, 10), lambda i, j: A[i, j] + B[i, j])
+    k = tvm.reduce_axis((0, 5), name="k")
+    D = tvm.compute((10, 10), lambda i, j: tvm.sum(A[i, k], k))
+    E = tvm.compute((10, 10), lambda i, j: A[2, j] + C[i, 2] + D[i, j])
+    F = tvm.compute((10, 10), lambda i, j: tvm.exp(A[i, j]) + B[i, A[i, j]])
+
+    R = InlineTensors(E)
+    resbody = lambda i, j: 2 + j + i + 2 + i*2 + D[i, j]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(E, [A])
+    resbody = lambda i, j: 2 + j + C[i, 2] + D[i, j]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(E, [A, C])
+    resbody = lambda i, j: 2 + j + ((i + 2) + B[i, 2]) + D[i, j]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(E, [B, C])
+    resbody = lambda i, j: A[2, j] + (A[i, 2] + i*2) + D[i, j]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(F)
+    resbody = lambda i, j: tvm.exp(i + j) + i * (i + j)
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(F, [A])
+    resbody = lambda i, j: tvm.exp(i + j) + B[i, (i + j)]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+    R = InlineTensors(F, [B])
+    resbody = lambda i, j: tvm.exp(A[i, j]) + i * A[i, j]
+    check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis]))
+
+def test_solve_system_of_inequalities():
+    seed = random.randrange(sys.maxsize)
+    print("\nseed: {}\n".format(seed))
+    random.seed(seed)
+
+    def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):
+        vs = [tvm.var("x" + str(i)) for i in range(variables)]
+
+        fs = []
+        for i in range(formulas):
+            s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs])
+            s1 += random.randint(coef[0], coef[1])
+            s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs])
+            s2 += random.randint(coef[0], coef[1])
+            op = random.choice([tvm.expr.EQ, tvm.expr.LE, tvm.expr.LT,
+                                tvm.expr.GE, tvm.expr.GT])
+            fs.append(op(s1, s2))
+
+        vranges = {v: tvm.Range(bounds[0], bounds[1] + 1) for v in vs}
+
+        before = tvm.all(*fs)
+        print(before)
+        after = tvm.all(*SolveSystemOfInequalities(fs, vs, vranges))
+        print(after)
+        print()
+
+        check_eq_bruteforce(before, after, vranges)
+
+    for i in range(3):
+        _check(1, 1)
+    for i in range(3):
+        _check(1, 2)
+
+    for i in range(3):
+        _check(2, 1)
+    for i in range(3):
+        _check(2, 2)
+    for i in range(3):
+        _check(2, 3)
+
+    # Somewhere here coefficients in the results become too large, leading to overflow,
+    # so we use smaller initial coefficients
+    for i in range(5):
+        _check(3, 3, coef=(-2, 2))
+    for i in range(5):
+        _check(3, 4, coef=(-2, 2))
+
+    for i in range(5):
+        _check(4, 3, coef=(-1, 1))
+
+    for i in range(5):
+        _check(10, 2, coef=(-1, 1), bounds=(0, 4))
+    for i in range(5):
+        _check(10, 3, coef=(0, 1), bounds=(0, 4))
+
+def test_simplify_domain():
+    # Note that here we test both SimplifyDomain and SimplifyReductionDomain.
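+    # In the checks below, `volume` is the expected number of points in the simplified
+    # domain, i.e. the product of the extents of the new axis computed by _check.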
+    def _check(cond, axis, volume, vranges={}):
+        vranges_with_axis = dict(vranges)
+        vranges_with_axis.update({iv.var: iv.dom for iv in axis})
+        variables = [iv.var for iv in axis]
+        new_cond, new_axis, old_to_new, new_to_old = SimplifyDomain(cond, variables,
+                                                                    vranges_with_axis)
+
+        print("old", axis, cond)
+        print("new", new_axis, new_cond)
+        print("old_to_new", old_to_new)
+        print("new_to_old", new_to_old)
+        print()
+
+        cond_subst = tvm.ir_pass.Substitute(cond, old_to_new)
+        new_vranges = vranges.copy()
+        new_vranges.update({v.var: v.dom for v in new_axis})
+        # If new_cond is true in the new domain, then cond_subst must also be true in the
+        # new domain, but the reverse is not necessarily true
+        check_eq_bruteforce(tvm.all(new_cond, cond_subst), new_cond, new_vranges)
+
+        new_cond_subst = tvm.ir_pass.Substitute(new_cond, new_to_old)
+        old_vranges = vranges.copy()
+        old_vranges.update({v.var: v.dom for v in axis})
+        check_eq_bruteforce(cond, tvm.all(cond, new_cond_subst), old_vranges)
+
+        # Also check SimplifyReductionDomain
+        reduction = xor_combiner(sum([v*(i + 1) for i, v in enumerate(axis)]), axis)
+        new_reduction = SimplifyReductionDomain(reduction, vranges)
+        check_eq_bruteforce(reduction, new_reduction, vranges)
+
+        vol = np.prod([iv.dom.extent.value for iv in new_axis])
+        if vol != volume:
+            raise AssertionError("New volume is {} != {}\n"
+                                 "Old domain {} where {}\nNew domain {} where {}"
+                                 .format(vol, volume, axis, cond, new_axis, new_cond))
+
+    k = tvm.reduce_axis((0, 5), name="k")
+    l = tvm.reduce_axis((0, 5), name="l")
+    n = tvm.reduce_axis((0, 5), name="n")
+
+    _check((k <= l), [k, l, n], 125)
+    _check((k < l), [k, l, n], 80)
+    _check(tvm.expr.EQ(k, l), [k, l, n], 25)
+    _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16)
+    _check(tvm.expr.EQ(2*l, k), [k, l, n], 15)
+    # TODO: the result depends on the order of variables because we don't have a proper
+    # solver for systems of linear equations yet
+    _check(tvm.expr.EQ(2*l, k), [n, l, k], 25)
+    _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 15)
+    _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50)
+
+    some_var = tvm.var('some_var')
+    _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)})
+    _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)})
+
+
+    k = tvm.reduce_axis((-3, 2), name="k")
+    l = tvm.reduce_axis((-3, 2), name="l")
+    n = tvm.reduce_axis((-3, 2), name="n")
+
+    _check((k < l), [k, l, n], 80)
+    _check(tvm.expr.EQ(k, l), [k, l, n], 25)
+    _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16)
+    # Now there are only two possible values for l: {l = -1, k = -2} and {l = 0, k = 0}
+    _check(tvm.expr.EQ(2*l, k), [k, l, n], 10)
+    # TODO: the result depends on the order of variables because we don't have a proper
+    # solver for systems of linear equations
+    _check(tvm.expr.EQ(2*l, k), [n, l, k], 25)
+    _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 10)
+    _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50)
+
+    some_var = tvm.var('some_var')
+    _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)})
+    _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)})
+
+
+    k = tvm.reduce_axis((0, 6), name="k")
+    l = tvm.reduce_axis((0, 5), name="l")
+    n = tvm.reduce_axis((0, 30), name="n")
+
+    _check(tvm.all(k + l*6 == n), [k, l, n], 30)
+    _check(tvm.all(k + l*6 == n), [n, k, l], 30)
+    _check(tvm.all(k + l*6 == n), [n, l, k], 30)
+
+    _check(tvm.all(n / 5 == k, n % 5 == l), [l, k, n], 30)
+    # TODO: Same thing with the order
+    _check(tvm.all(n / 5 == k, n % 5 == l), [n, l, k], 30)
+
+    k = tvm.reduce_axis((0, 10), name="k")
+    l = tvm.reduce_axis((0, 10), name="l")
+    _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 48)
+
+def test_extract_as_tensor_maybe():
+    def _check(shape, fcompute, volume=None, vranges={}):
+        def fcompute_extracted(*variables):
+            vranges_updated = dict(vranges)
+            vranges_updated.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)})
+            expr = fcompute(*variables)
+            if isinstance(expr, tvm.expr.Select):
+                new_true_value = ExtractAsTensorMaybe(expr.true_value,
+                                                      expr.condition,
+                                                      variables,
+                                                      vranges_updated)
+                expr = tvm.expr.Select(expr.condition,
+                                       new_true_value,
+                                       expr.false_value)
+                if volume is not None:
+                    assert isinstance(new_true_value, tvm.expr.Call)
+                    vol = np.prod([iv.dom.extent.value for iv in new_true_value.func.axis])
+                    if vol != volume:
+                        raise AssertionError("New volume is {} != {}"
+                                             .format(vol, volume))
+            return expr
+
+        A = tvm.compute(shape, fcompute)
+        B = tvm.compute(shape, fcompute_extracted)
+        check_eq(A, B, [])
+
+    _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i + j, 0), volume=30)
+    _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, j, 0), volume=10)
+    _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i, 0), volume=3)
+    _check((10, 10), lambda i, j: tvm.expr.Select(tvm.all(i < j, j < 5), i + j, 0), volume=16)
+    # This one doesn't get extracted
+    _check((10, 10), lambda i, j: tvm.expr.Select(i <= j, i + j, 0))
+
+def test_extract_reductions():
+    k = tvm.reduce_axis((0, 10), name="k")
+    l = tvm.reduce_axis((0, 10), name="l")
+    n = tvm.reduce_axis((0, 10), name="n")
+
+    A = tvm.compute((10, 10),
+                    lambda i, j:
+                    ExtractReductions(sum_combiner(i + k + xor_combiner(j*k + l, l), k),
+                                      [i, j],
+                                      {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}))
+    B = tvm.compute((10, 10), lambda j, k: xor_combiner(j*k + l, l))
+    C = tvm.compute((10, 10), lambda i, j: sum_combiner(i + k + B[j, k], k))
+    check_eq(C, A, [])
+
+    fcompute = lambda i, j: \
+        ExtractReductions(sum_both_combiner((prod_derivative_combiner((i*n + 2*k, j + k),
+                                                                      k)[1],
+                                             xor_combiner(j*n + l, l)), n)[1],
+                          [i, j],
+                          {i: tvm.Range(0, 10), j: tvm.Range(0, 10)})
+    A = tvm.compute((10, 10), fcompute)
+    _, B = tvm.compute((10, 10, 10),
+                       lambda i, j, n: prod_derivative_combiner((i*n + 2*k, j + k), k))
+    C = tvm.compute((10, 10), lambda j, n: xor_combiner(j*n + l, l))
+    _, D = tvm.compute((10, 10), lambda i, j: sum_both_combiner((B[i, j, n], C[j, n]), n))
+    check_eq(A, D, [])
+
+def test_optimize_and_lift_nonzeroness():
+    k = tvm.reduce_axis((0, 10), name="k")
+    l = tvm.reduce_axis((0, 10), name="l")
+    n = tvm.reduce_axis((0, 10), name="n")
+    A = tvm.placeholder((10, 10), name="A")
+
+    zero = tvm.const(0, 'float32')
+
+    B = compute((10, 10), lambda i, j: tvm.sum((i == j)*A[i, k] + A[k, j]*(i == j), k))
+    B = OptimizeAndLiftNonzeronessConditions(B)
+    R = lambda i, j: tvm.expr.Select(i == j,
+                                     tvm.sum(A[j, k] + A[k, j], k),
+                                     zero)
+    check_tensor_symeq(B, R)
+
+    # TODO: This test is unstable: sometimes the resulting condition looks like
+    # (i == j)*(j == i) instead of (i == j)
+    # B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] +
+    #                                            (i == j)*A[k, j]*(i == k), k))
+    # B = OptimizeAndLiftNonzeronessConditions(B)
+    # R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero)
+    # check_tensor_symeq(B, R)
+
+    B = compute((10, 10), lambda i, j: tvm.sum((i < j)*(j < k)*A[j, k], k))
+    B = OptimizeAndLiftNonzeronessConditions(B)
+    k1 = tvm.reduce_axis((2, 10), name="k1")
+    R = compute((10, 10), lambda i, j:
+                tvm.expr.Select(tvm.all(i < j, j < 10),
+                                tvm.sum(tvm.expr.Select(j < k1, A[j, k1], zero), k1),
+                                zero))
+    check_eq(B, R, [A])
+    assert estimate_performance(B) <= estimate_performance(R)
+
+    # TODO: This one needs the equation solver
+    # B = compute((10, 10), lambda i, j: tvm.sum((i <= j)*(j <= k)*A[j, k], k, where=(i >= k)))
+    # B = OptimizeAndLiftNonzeronessConditions(B)
+    # R = compute((10, 10), lambda i, j: tvm.expr.Select((i == j), A[i, i], zero))
+    # check_eq(B, R, [A])
+    # assert estimate_performance(B) <= estimate_performance(R)
+
+    B = compute((10, 10),
+                lambda i, j: prod_derivative_combiner((A[j, k], (i <= j)*(j < k)*A[i, k]),
+                                                      k)[1])
+    B = OptimizeAndLiftNonzeronessConditions(B)
+    R = compute((10, 10), lambda i, j:
+                tvm.expr.Select(tvm.all(i <= j, j < 10),
+                                prod_derivative_combiner((A[j, k], (j < k)*A[i, k]), k)[1],
+                                zero))
+    check_eq(B, R, [A])
+    assert estimate_performance(B) <= estimate_performance(R)
+
+    B = compute((10,), lambda i:
+                tvm.sum(A[i, k]*tvm.any(tvm.all(i < 5, k < 6), tvm.all(i > 5, k > 4)), k))
+    B = OptimizeAndLiftNonzeronessConditions(B)
+    R = compute((10,), lambda i:
+                tvm.expr.Select(tvm.any(i < 5, i > 5),
+                                tvm.sum(A[i, k], k, where=tvm.all(tvm.any(i < 5, k > 4),
+                                                                  tvm.any(i > 5, k < 6))),
+                                zero))
+    check_eq(B, R, [A])
+    assert estimate_performance(B) <= estimate_performance(R)
+
+    # Specifying ranges of parameters
+    B = compute((10, 10),
+                lambda i, j: sum_or_prod_combiner((i == j)*A[i, k] + A[k, j]*(i == j), k))
+    B = OptimizeAndLiftNonzeronessConditions(B, {m_param: tvm.Range(-5, -3)})
+    R = lambda i, j: tvm.expr.Select(i == j,
+                                     tvm.sum(A[j, k] + A[k, j], k),
+                                     zero)
+    check_tensor_symeq(B, R)
+
+    B = compute((10, 10), lambda i, j: tvm.sum(((i - k) <= m_param) * A[i, k], k))
+    B = OptimizeAndLiftNonzeronessConditions(B, {m_param: tvm.Range(11, 20)})
+    R = lambda i, j: tvm.sum(A[i, k], k)
+    check_tensor_symeq(B, R)
+
+if __name__ == "__main__":
+    test_is_sum_combiner()
+    test_can_factor_zero_from_combiner()
+    test_lift_nonzeroness_condition()
+    test_inline_tail_call()
+    test_inline_tensors()
+    test_solve_system_of_inequalities()
+    test_simplify_domain()
+    test_extract_as_tensor_maybe()
+    test_extract_reductions()
+    test_optimize_and_lift_nonzeroness()