From 5cb04511221bfce8886926a71fccd1ad016c75de Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Mon, 10 Feb 2020 22:47:48 -0500 Subject: [PATCH 1/5] Set split node's range to minimum of ext and split factor. --- src/te/schedule/message_passing.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5b6fa861895a..190e02a73d99 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -62,6 +62,13 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; + auto minimum = [actx](PrimExpr a, PrimExpr b) { + if (actx->CanProve(a < b)) { + return actx->Simplify(a); + } + return actx->Simplify(b); + }; + auto& state = *p_state; // forwar iteration on relations for (IterVarRelation rel : stage->relations) { @@ -74,7 +81,8 @@ void PassDownDomain(const Stage& stage, const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent(0, r->factor), actx); + Range::make_by_min_extent( + 0, minimum(range_parent->extent, r->factor)), actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); From df93c30e0c31623ab618c080092072ede0ba0228 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Fri, 14 Feb 2020 10:46:47 -0500 Subject: [PATCH 2/5] Add a test function to ensure that stringent (8 instead of 32) range is inferred. --- .../unittest/test_schedule_bound_inference.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 9c3d1df17f2b..e37769a7f42d 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -81,6 +81,19 @@ def test_bound_split_divisible(): assert bounds[xo].extent == m assert bounds[xi].extent.value == 8 +def test_bound_split_ext_less_than_factor(): + m = 8 + I = tvm.placeholder((m,), name='I') + EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") + E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") + s = tvm.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], factor = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xi].extent.value == m + def test_bound_tile_divisible(): m = tvm.var('m') l = tvm.var('l') @@ -423,4 +436,5 @@ def _check(B, A=A): test_bound_fusesplit1() test_bound_fusesplit2() test_bound_split_divisible() + test_bound_split_ext_less_than_factor() test_bound_tile_divisible() From d1e101ea1eecbfccd2736f8c07fe3edb221ea171 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Fri, 14 Feb 2020 10:52:17 -0500 Subject: [PATCH 3/5] Update an existing test. Otherwise, it fails with the split node minimum range fix. Confirmed that the updated test generates expected and cleaner code. --- tests/python/unittest/test_schedule_tensor_core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index cd9e062dc07b..b6c8b90e2432 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -339,8 +339,6 @@ def test_tensor_core_batch_conv(): ty, yo = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) to, ti = s[AS].split(t, factor=warp_size) - s[AS].bind(tx, thread_y) - s[AS].bind(ty, thread_z) s[AS].bind(ti, thread_x) kh, kw, ic, o, ii, oo = WS.op.axis @@ -348,8 +346,6 @@ def test_tensor_core_batch_conv(): ty, yo = s[WS].split(xo, nparts=block_col_warps) t = s[WS].fuse(ii, oo) to, ti = s[WS].split(t, nparts=warp_size) - s[WS].bind(tx, thread_y) - s[WS].bind(ty, thread_z) s[WS].bind(to, thread_x) s[WS].vectorize(ti) From cdd6fd17bc2267ab891c47cba0a8addb45e8a79d Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Thu, 20 Feb 2020 11:26:35 -0500 Subject: [PATCH 4/5] Update a helper function name. --- src/te/schedule/message_passing.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 190e02a73d99..09fe4c88050a 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -62,7 +62,7 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum = [actx](PrimExpr a, PrimExpr b) { + auto minimum_or_later = [actx](PrimExpr a, PrimExpr b) { if (actx->CanProve(a < b)) { return actx->Simplify(a); } @@ -82,7 +82,7 @@ void PassDownDomain(const Stage& stage, if (r->factor.defined()) { Update(p_state, r->inner, Range::make_by_min_extent( - 0, minimum(range_parent->extent, r->factor)), actx); + 0, minimum_or_later(range_parent->extent, r->factor)), actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); From 8ff3b4f487bc5aaa2b5f4310445dd4daf72615f7 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Fri, 21 Feb 2020 00:07:26 -0500 Subject: [PATCH 5/5] Apply the same change to set split node's range for nparts. --- src/te/schedule/message_passing.cc | 7 ++++--- topi/python/topi/cuda/conv2d_direct.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 09fe4c88050a..a3f758b44d8e 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -87,10 +87,11 @@ void PassDownDomain(const Stage& stage, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); - Update(p_state, r->inner, + Update(p_state, r->outer, Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->nparts)), actx); + 0, minimum_or_later(range_parent->extent, r->nparts)), actx); + Update(p_state, r->inner, + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx); } } else if (const FuseNode* r = rel.as()) { if (!state.count(r->outer) || !state.count(r->inner)) { diff --git a/topi/python/topi/cuda/conv2d_direct.py b/topi/python/topi/cuda/conv2d_direct.py index b7df88579f49..58d4bcc29d0e 100644 --- a/topi/python/topi/cuda/conv2d_direct.py +++ b/topi/python/topi/cuda/conv2d_direct.py @@ -106,7 +106,6 @@ def schedule_direct_cuda(cfg, s, conv): tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, tvm.thread_axis("threadIdx.z")) s[load].bind(ty, tvm.thread_axis("threadIdx.y")) - s[load].bind(tx, tvm.thread_axis("threadIdx.x")) # unroll s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)