Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ class OperationNode : public FunctionBaseNode {
/*!
* \brief Propagate the bounds to inputs
* \param self The reference to self.
* \param analyzer The analyzer to be used in the function.
* \param dom_map the domain map of Variables(corresponds to root_iter_vars)
* \param out_dom_map The output domain.
* The function is only asked to fill the bounds for Tensors that
* is already in the out_dom_map
*/
virtual void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
Expand Down Expand Up @@ -170,6 +172,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -247,6 +250,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
Expand Down Expand Up @@ -299,6 +303,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
Expand Down Expand Up @@ -373,6 +378,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -439,6 +445,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -506,6 +513,7 @@ class HybridOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down
29 changes: 27 additions & 2 deletions src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"

namespace tvm {

Expand Down Expand Up @@ -210,17 +211,41 @@ Operation ComputeOpNode::ReplaceInputs(

void ComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
// We assume that the value of the argument cannot be out of bounds (otherwise it is
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
Expr shape_i_min_value = make_zero(t->shape[i].type());
Expr shape_i_max_value = t->shape[i] - 1;
Expr min_value = arg_interval->min_value;
Expr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
} else {
dom.data[i].push_back(arg_intset);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/op/extern_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Operation ExternOpNode::ReplaceInputs(

void ExternOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
Expand Down
1 change: 1 addition & 0 deletions src/op/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Operation HybridOpNode::ReplaceInputs(

void HybridOpNode::PropBoundToInputs(
const Operation &self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
Expand Down
1 change: 1 addition & 0 deletions src/op/placeholder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Operation PlaceholderOpNode::ReplaceInputs(

void PlaceholderOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
}
Expand Down
1 change: 1 addition & 0 deletions src/op/scan_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ Operation ScanOpNode::ReplaceInputs(

void ScanOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
Expand Down
1 change: 1 addition & 0 deletions src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs(

void TensorComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
Expand Down
7 changes: 5 additions & 2 deletions src/op/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,20 @@ size_t InferTensorizeRegion(
// Get domains if inputs
std::unordered_map<Tensor, TensorDom> in_dom;
std::unordered_map<const Variable*, IntSet> temp_dmap;
arith::Analyzer analyzer;
Array<Tensor> inputs = self->InputTensors();
for (Tensor t : inputs) {
in_dom.emplace(t, TensorDom(t.ndim()));
}
for (IterVar iv : self->root_iter_vars()) {
IntSet iset = up_state.at(iv);
(*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
Range iv_range = iset.cover_range(dom_map.at(iv));
(*out_dom)[iv] = iv_range;
analyzer.Bind(iv->var, iv_range);
temp_dmap[iv->var.get()] = iset;
}
// Input domains
self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
Range none;
for (const auto& kv : in_dom) {
Array<Range> vec;
Expand Down
4 changes: 3 additions & 1 deletion src/schedule/bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ void InferRootBound(const Stage& stage,
PassUpDomain(op_stage, *rmap, &up_state);
// Relax if needed.
std::unordered_map<const Variable*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
Expand All @@ -203,8 +204,9 @@ void InferRootBound(const Stage& stage,
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
analyzer.Bind(iv->var, r);
}
op->PropBoundToInputs(op, dom_map, &tmap);
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
stage->op->GatherBound(stage->op, tmap, rmap);
}
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,27 @@ def _body():
assert isinstance(bounds, tvm.container.Map)
assert(bounds[B.op.axis[0]].extent.value == 10)

def test_bound_simplification_failure():
# Check that the bounds are not expanded
A = tvm.compute((2,), lambda j: j, "A")

def _check(B, A=A):
s = tvm.create_schedule(B.op)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.lower(s, [B, A], simple_mode=True)
if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2

# These are hard to simplify, moreover we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i]))

if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
Expand All @@ -320,3 +341,4 @@ def _body():
test_gemm_bound()
test_bound_warp()
test_bound_tensor_compute_op()
test_bound_simplification_failure()
15 changes: 0 additions & 15 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,20 +286,6 @@ def _compute(*indice):
stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.if_then_else(
tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))


def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
Expand Down Expand Up @@ -514,7 +500,6 @@ def _compute(*index) :
test_schedule1()
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()
Expand Down