diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index 622e0b698902..6c185d6f8637 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -491,11 +491,12 @@ std::vector MakeBoundCheck( IntSet s = EvalSet(value, iset_dmap); Expr vmin = s.min(); Expr vmax = s.max(); - if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) { + // The range of `value` resides in [vmin, vmax] + if (vmin.type() != value.type() || !can_prove(vmin >= 0)) { preds.emplace_back(value >= 0); } if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { - preds.emplace_back(value < (iv->dom->extent - iv->dom->min)); + preds.emplace_back(value < iv->dom->extent); } } } diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 8774514cfa17..e60073fe9f5c 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -12,6 +12,7 @@ def test_schedule0(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_schedule1(): m = tvm.var('m') l = tvm.var('l') @@ -53,10 +54,13 @@ def test_schedule_scan(): assert tuple(res.shape) == (m, n) s = tvm.create_schedule(res.op) s = s.normalize() + ir = tvm.lower(s, [s_state], simple_mode=True) + assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition") bounds = tvm.schedule.InferBound(s) assert(bounds[res.op.scan_axis].min.value == 1) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_inline_multi_reduce(): def argmax_comp(x, y): idx = tvm.select((x[1] >= y[1]), x[0], y[0]) @@ -80,7 +84,6 @@ def argmax_init(idx_typ, val_typ): stmt = tvm.schedule.ScheduleOps(s, bounds) - def test_auto_inline(): m = tvm.var('m') n = tvm.var('n') @@ -96,6 +99,7 @@ def test_auto_inline(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_schedule_const_bound(): n = 128 A = tvm.placeholder((n,), name='A') @@ -146,6 +150,7 @@ def test_scan_inline1(): s[s_x1].compute_inline() stmt = tvm.lower(s, [x, res1, res2]) + def test_scan_inline2(): m = tvm.var("m") n = tvm.var("n") @@ -183,6 +188,7 @@ def test_schedule_cache(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_schedule_middle_cache(): m = tvm.var('m') n = tvm.var('n') @@ -202,7 +208,6 @@ def test_schedule_middle_cache(): stmt = tvm.schedule.ScheduleOps(s, bounds) - def test_schedule_cache_relayout1(): m = tvm.var('m') n = tvm.var('n') @@ -249,6 +254,7 @@ def test_schedule_cache_relayout3(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_schedule_cache_relayout4(): def _compute(*indice): return A(*indice) + 1, B(*indice) / 2