From 4c9acaced9b1a115adc063fe41fd2c68211620ac Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Thu, 13 Jun 2019 16:19:34 +0300 Subject: [PATCH 1/2] [TVM] Fix bound inference to avoid allocating too much --- src/op/compute_op.cc | 25 ++++++++++++++++++- .../unittest/test_schedule_bound_inference.py | 22 ++++++++++++++++ .../unittest/test_schedule_schedule_ops.py | 15 ----------- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index bb91ed8d4a9f..533641a8a4ee 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -34,6 +34,7 @@ #include "op_util.h" #include "../schedule/message_passing.h" #include "../arithmetic/compute_expr.h" +#include "../arithmetic/int_set.h" namespace tvm { @@ -220,7 +221,29 @@ void ComputeOpNode::PropBoundToInputs( if (t->op.defined() && out_dom_map->count(t)) { TensorDom& dom = out_dom_map->at(t); for (size_t i = 0; i < t.ndim(); ++i) { - dom.data[i].push_back(EvalSet(call->args[i], dom_map)); + // We assume that the value of the argument cannot be out of bounds (otherwise it is + // undefined behaviour), so we can intersect the estimated set of the argument with the + // range expected by the tensor. However, intersection may result in overly complex + // expressions, so we perform a more relaxed form of intersection. + IntSet arg_intset = EvalSet(call->args[i], dom_map); + const arith::IntervalSetNode* arg_interval = arg_intset.as(); + if (arg_interval) { + Expr shape_i_min_value = make_zero(t->shape[i].type()); + Expr shape_i_max_value = t->shape[i] - 1; + Expr min_value = arg_interval->min_value; + Expr max_value = arg_interval->max_value; + // Prefer the shape bounds only when we can prove they are tighter. + arith::Analyzer an; + if (arith::is_neg_inf(min_value) || an.CanProve(shape_i_min_value >= min_value)) { + min_value = shape_i_min_value; + } + if (arith::is_pos_inf(max_value) || an.CanProve(shape_i_max_value <= max_value)) { + max_value = shape_i_max_value; + } + dom.data[i].push_back(IntSet::interval(min_value, max_value)); + } else { + dom.data[i].push_back(arg_intset); + } } } } diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index f16305779a43..2bfd187f9b4c 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -306,6 +306,27 @@ def _body(): assert isinstance(bounds, tvm.container.Map) assert(bounds[B.op.axis[0]].extent.value == 10) +def test_bound_simplification_failure(): + # Check that the bounds are not expanded + A = tvm.compute((2,), lambda j: j, "A") + + def _check(B, A=A): + s = tvm.create_schedule(B.op) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.lower(s, [B, A], simple_mode=True) + if not bounds[A.op.axis[0]].extent.value <= 2: + print(stmt) + assert bounds[A.op.axis[0]].extent.value <= 2 + + # These are hard to simplify, moreover we don't simplify them + _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)])) + _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)])) + _check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)])) + _check(tvm.compute((10,), lambda i: A[i + (0 - i)])) + # This would cause out of bounds, but we nevertheless include it + _check(tvm.compute((10,), lambda i: A[i])) + if __name__ == "__main__": test_bound_nest_thread() test_bound1() @@ -320,3 +341,4 @@ def _body(): test_gemm_bound() test_bound_warp() test_bound_tensor_compute_op() + test_bound_simplification_failure() diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 864d6aea2799..5275aec4db90 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -286,20 +286,6 @@ def _compute(*indice): stmt = tvm.schedule.ScheduleOps(s, bounds) -def test_schedule_bound_condition(): - A = tvm.placeholder((64,), name='A', dtype="float32") - Apad = tvm.compute((66,), lambda i: tvm.if_then_else( - tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad') - Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2') - s = tvm.create_schedule(Apad2.op) - AL1 = s.cache_read(A,"local",[Apad]) - s = s.normalize() - bounds = tvm.schedule.InferBound(s) - stmt = tvm.schedule.ScheduleOps(s, bounds) - stmt = tvm.ir_pass.Simplify(stmt) - assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) - - def intrin_gemv(m, n): w = tvm.placeholder((m, n), name='w') x = tvm.placeholder((n,), name='x') @@ -514,7 +500,6 @@ def _compute(*index) : test_schedule1() test_schedule2() test_schedule_cache() - test_schedule_bound_condition() test_schedule_tensor_compute1() test_schedule_tensor_compute2() test_schedule_tensor_compute3() From 99fae32486c0dc8ccdd5cd9756e56030ea108f4e Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Fri, 12 Jul 2019 12:49:20 +0300 Subject: [PATCH 2/2] [ARITH][BOUND] Pass analyzer to PropBoundToInputs --- include/tvm/operation.h | 8 ++++++++ src/op/compute_op.cc | 10 ++++++---- src/op/extern_op.cc | 1 + src/op/hybrid_op.cc | 1 + src/op/placeholder_op.cc | 1 + src/op/scan_op.cc | 1 + src/op/tensor_compute_op.cc | 1 + src/op/tensorize.cc | 7 +++++-- src/schedule/bound.cc | 4 +++- 9 files changed, 27 insertions(+), 7 deletions(-) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 38dc39bbe7a7..c119b95d0a2a 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -100,6 +100,7 @@ class OperationNode : public FunctionBaseNode { /*! * \brief Propagate the bounds to inputs * \param self The reference to self. + * \param analyzer The analyzer to be used in the function. * \param dom_map the domain map of Variables(corresponds to root_iter_vars) * \param out_dom_map The output domain. * The function is only asked to fill the bounds for Tensors that @@ -107,6 +108,7 @@ class OperationNode : public FunctionBaseNode { */ virtual void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const = 0; /*! @@ -170,6 +172,7 @@ class PlaceholderOpNode : public OperationNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( @@ -247,6 +250,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( @@ -299,6 +303,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( @@ -373,6 +378,7 @@ class ScanOpNode : public OperationNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( @@ -439,6 +445,7 @@ class ExternOpNode : public OperationNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( @@ -506,6 +513,7 @@ class HybridOpNode : public OperationNode { const std::unordered_map& rmap) const final; void PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 533641a8a4ee..dabf3016e292 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -211,10 +211,11 @@ Operation ComputeOpNode::ReplaceInputs( void ComputeOpNode::PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); - auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) { + auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Operation(call->func.node_).output(call->value_index); @@ -233,11 +234,12 @@ void ComputeOpNode::PropBoundToInputs( Expr min_value = arg_interval->min_value; Expr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - arith::Analyzer an; - if (arith::is_neg_inf(min_value) || an.CanProve(shape_i_min_value >= min_value)) { + if (arith::is_neg_inf(min_value) || + analyzer->CanProve(shape_i_min_value >= min_value)) { min_value = shape_i_min_value; } - if (arith::is_pos_inf(max_value) || an.CanProve(shape_i_max_value <= max_value)) { + if (arith::is_pos_inf(max_value) || + analyzer->CanProve(shape_i_max_value <= max_value)) { max_value = shape_i_max_value; } dom.data[i].push_back(IntSet::interval(min_value, max_value)); diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index 7023aebe17ad..0f66c6c2be1f 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -112,6 +112,7 @@ Operation ExternOpNode::ReplaceInputs( void ExternOpNode::PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 48773c644749..c93257f8c601 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -110,6 +110,7 @@ Operation HybridOpNode::ReplaceInputs( void HybridOpNode::PropBoundToInputs( const Operation &self, + arith::Analyzer* analyzer, const std::unordered_map &dom_map, std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index f94b7d072e26..97d01ca063f1 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -78,6 +78,7 @@ Operation PlaceholderOpNode::ReplaceInputs( void PlaceholderOpNode::PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { } diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 78f8c82d97db..4c1c57db9ee4 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -176,6 +176,7 @@ Operation ScanOpNode::ReplaceInputs( void ScanOpNode::PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 09e8af7d5cba..d333461c14b5 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -110,6 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs( void TensorComputeOpNode::PropBoundToInputs( const Operation& self, + arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 00181aa37bb8..eb2d05455fca 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -85,17 +85,20 @@ size_t InferTensorizeRegion( // Get domains if inputs std::unordered_map in_dom; std::unordered_map temp_dmap; + arith::Analyzer analyzer; Array inputs = self->InputTensors(); for (Tensor t : inputs) { in_dom.emplace(t, TensorDom(t.ndim())); } for (IterVar iv : self->root_iter_vars()) { IntSet iset = up_state.at(iv); - (*out_dom)[iv] = iset.cover_range(dom_map.at(iv)); + Range iv_range = iset.cover_range(dom_map.at(iv)); + (*out_dom)[iv] = iv_range; + analyzer.Bind(iv->var, iv_range); temp_dmap[iv->var.get()] = iset; } // Input domains - self->PropBoundToInputs(stage->op, temp_dmap, &in_dom); + self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom); Range none; for (const auto& kv : in_dom) { Array vec; diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 2cd51d6e6d12..87c12a852248 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -191,6 +191,7 @@ void InferRootBound(const Stage& stage, PassUpDomain(op_stage, *rmap, &up_state); // Relax if needed. std::unordered_map dom_map; + arith::Analyzer analyzer; for (auto iv : op->root_iter_vars()) { Range r; if (up_state.count(iv)) { @@ -203,8 +204,9 @@ void InferRootBound(const Stage& stage, } else { dom_map[iv->var.get()] = IntSet::range(r); } + analyzer.Bind(iv->var, r); } - op->PropBoundToInputs(op, dom_map, &tmap); + op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); } stage->op->GatherBound(stage->op, tmap, rmap); }