diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5b6fa861895a..a3f758b44d8e 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -62,6 +62,13 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; + auto minimum_or_later = [actx](PrimExpr a, PrimExpr b) { + if (actx->CanProve(a < b)) { + return actx->Simplify(a); + } + return actx->Simplify(b); + }; + auto& state = *p_state; // forwar iteration on relations for (IterVarRelation rel : stage->relations) { @@ -74,15 +81,17 @@ void PassDownDomain(const Stage& stage, const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent(0, r->factor), actx); + Range::make_by_min_extent( + 0, minimum_or_later(range_parent->extent, r->factor)), actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); - Update(p_state, r->inner, + Update(p_state, r->outer, Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->nparts)), actx); + 0, minimum_or_later(range_parent->extent, r->nparts)), actx); + Update(p_state, r->inner, + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx); } } else if (const FuseNode* r = rel.as()) { if (!state.count(r->outer) || !state.count(r->inner)) { diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 9c3d1df17f2b..e37769a7f42d 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -81,6 +81,19 @@ def test_bound_split_divisible(): assert bounds[xo].extent == m assert bounds[xi].extent.value == 8 +def test_bound_split_ext_less_than_factor(): + m = 8 + I = tvm.placeholder((m,), name='I') + EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") + E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") + s = tvm.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], factor = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xi].extent.value == m + def test_bound_tile_divisible(): m = tvm.var('m') l = tvm.var('l') @@ -423,4 +436,5 @@ def _check(B, A=A): test_bound_fusesplit1() test_bound_fusesplit2() test_bound_split_divisible() + test_bound_split_ext_less_than_factor() test_bound_tile_divisible() diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index cd9e062dc07b..b6c8b90e2432 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -339,8 +339,6 @@ def test_tensor_core_batch_conv(): ty, yo = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) to, ti = s[AS].split(t, factor=warp_size) - s[AS].bind(tx, thread_y) - s[AS].bind(ty, thread_z) s[AS].bind(ti, thread_x) kh, kw, ic, o, ii, oo = WS.op.axis @@ -348,8 +346,6 @@ def test_tensor_core_batch_conv(): ty, yo = s[WS].split(xo, nparts=block_col_warps) t = s[WS].fuse(ii, oo) to, ti = s[WS].split(t, nparts=warp_size) - s[WS].bind(tx, thread_y) - s[WS].bind(ty, thread_z) s[WS].bind(to, thread_x) s[WS].vectorize(ti) diff --git a/topi/python/topi/cuda/conv2d_direct.py b/topi/python/topi/cuda/conv2d_direct.py index b7df88579f49..58d4bcc29d0e 100644 --- a/topi/python/topi/cuda/conv2d_direct.py +++ b/topi/python/topi/cuda/conv2d_direct.py @@ -106,7 +106,6 @@ def schedule_direct_cuda(cfg, s, conv): tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, tvm.thread_axis("threadIdx.z")) s[load].bind(ty, tvm.thread_axis("threadIdx.y")) - s[load].bind(tx, tvm.thread_axis("threadIdx.x")) # unroll s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)