diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index a144e7fc40d1..3cea560318d8 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -477,9 +477,14 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !iv->dom.same_as(dom)) { Expr value = ComputeExpr(value_map.at(iv), iv->dom->min); - Expr vmax = EvalSet(value, iset_dmap).max(); + IntSet s = EvalSet(value, iset_dmap); + Expr vmin = s.min(); + Expr vmax = s.max(); + if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) { + preds.emplace_back(value >= 0); + } if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { - preds.emplace_back(value < iv->dom->extent); + preds.emplace_back(value < (iv->dom->extent - iv->dom->min)); } } } diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index c6ed19d65b69..a44f3899c282 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -82,6 +82,7 @@ def test_copy_pad_split(): Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt) def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 03b8dbf48c8c..dd57c35c39b6 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -249,6 +249,18 @@ def test_schedule_cache_relayout3(): stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_bound_condition(): + A = tvm.placeholder((64,), name='A', dtype="float32") + Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad') + Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2') + s = tvm.create_schedule(Apad2.op) + AL1 = s.cache_read(A,"local",[Apad]) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.Simplify(stmt) + assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) + if __name__ == "__main__": test_schedule_middle_cache() test_inline_multi_reduce() @@ -265,3 +277,4 @@ def test_schedule_cache_relayout3(): test_schedule1() test_schedule2() test_schedule_cache() + test_schedule_bound_condition()