Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,40 @@ void ComputeOpNode::GatherBound(
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
// Bounds we get from the declaration of i
Range r_dom = this->axis[i]->dom;
// Bounds we get from the uses of the tensor
Range r_from_uses = arith::Union(tdom.data.at(i)).cover_range(r_dom);
// The result
Range r;

if (can_prove(r_from_uses->extent <= r_dom->extent)) {
// Bounds from the uses are provably tighter, use them
if (can_prove(r_from_uses->extent == r_dom->extent)) {
// If the extents are equal, prefer using r_dom, as it probably has the simpler min
r = r_dom;
} else {
r = r_from_uses;
}
} else if (can_prove(r_dom->extent <= r_from_uses->extent)) {
// The declared bounds are better. This may mean one of the following two things:
// either we have an out-of-bounds error in the input user code, or the simplifier
// did a poor job simplifying call arguments before evaluating ranges.
// Use the declared bounds but issue a warning.
LOG(WARNING) << "GatherBound: the declared bounds " << r_dom
<< " are tighter than the bounds from uses " << r_from_uses
<< " for the variable " << this->axis[i]->var << " of the tensor " << self->name
<< ". Either out-of-bounds or poor simplification.";
r = r_dom;
} else {
// We can prove neither. Issue a warning and use r_from_uses since it was the old behaviour
// and it leads to fewer problems.
LOG(WARNING) << "GatherBound: cannot prove either the declared bounds " << r_dom
<< " or the bounds from uses " << r_from_uses
<< " to be tighter than the other. Will use the bounds from uses.";
r = r_from_uses;
}

CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,26 @@ def test_gemm_bound():
assert(bounds[CC.op.axis[0]].extent.value == 8)
assert(bounds[CC.op.axis[1]].extent.value == 8)

def test_bound_simplification_failure():
# Check that the bounds are not expanded
A = tvm.compute((2,), lambda j: j, "A")

def _check(B, A=A):
s = tvm.create_schedule(B.op)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.lower(s, [B, A], simple_mode=True)
if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2

# These are hard to simplify, moreover we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i]))

if __name__ == "__main__":
test_bound_nest_thread()
Expand All @@ -273,3 +293,4 @@ def test_gemm_bound():
test_bound2()
test_gemm_bound()
test_bound_warp()
test_bound_simplification_failure()
14 changes: 0 additions & 14 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,19 +264,6 @@ def _compute(*indice):
stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))


def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
Expand Down Expand Up @@ -420,7 +407,6 @@ def test_schedule_tensor_compute3():
test_schedule1()
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()