From 7a36ed63000e977769d30728b67f46c83a5c01c3 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 24 Feb 2020 12:07:36 -0800 Subject: [PATCH 1/3] [SCHEDULE] Improve bound inference for split --- src/te/schedule/message_passing.cc | 17 ++++++--- .../unittest/test_schedule_bound_inference.py | 37 +++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5b6fa861895a..a8942cef7847 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -55,7 +55,7 @@ void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, bool allow_missing) { - auto ceil_div = [actx](PrimExpr a, PrimExpr b) { + auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } @@ -63,7 +63,7 @@ void PassDownDomain(const Stage& stage, }; auto& state = *p_state; - // forwar iteration on relations + // forward iteration on relations for (IterVarRelation rel : stage->relations) { if (const SplitNode* r = rel.as()) { if (!state.count(r->parent)) { @@ -73,11 +73,16 @@ void PassDownDomain(const Stage& stage, CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { - Update(p_state, r->inner, - Range::make_by_min_extent(0, r->factor), actx); + PrimExpr outer_extent = ceil_div(range_parent->extent, r->factor); + if (is_one(outer_extent)) { + Update(p_state, r->inner, + Range::make_by_min_extent(0, range_parent->extent), actx); + } else { + Update(p_state, r->inner, + Range::make_by_min_extent(0, r->factor), actx); + } Update(p_state, r->outer, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->factor)), actx); + Range::make_by_min_extent(0,outer_extent) , actx); } else { Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); Update(p_state, r->inner, diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 9c3d1df17f2b..2a5f70dde81f 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -148,6 +148,42 @@ def test_bound_fusesplit2(): assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3) +def test_bound_split_compute_at(): + N = M = K = 1024 + + A = tvm.placeholder((N, K), name='A') + B = tvm.placeholder((M, K), name='B') + C = tvm.compute((N, M), lambda i, j: A[i][j] * 2.0, name='C') + D = tvm.compute((N, M), lambda i, j: C[i][j] + 2.0, name='D') + + s = tvm.create_schedule([D.op]) + + def multiple_split(X): + i, j = s[X].op.axis + i2, i3 = s[X].split(i, 4) + i1, i2 = s[X].split(i2, 4) + i0, i1 = s[X].split(i1, 4) + + j2, j3 = s[X].split(j, 4) + j1, j2 = s[X].split(j2, 4) + j0, j1 = s[X].split(j1, 4) + + iters = i0, j0, i1, j1, i2, j2, i3, j3 + s[X].reorder(*iters) + return iters + + c_iters = multiple_split(C) + d_iters = multiple_split(D) + + s[C].compute_at(s[D], d_iters[-3]) + + bounds = tvm.schedule.InferBound(s) + for i in range(0, 6): + assert bounds[c_iters[i]].extent.value == 1 + for i in range(6, 8): + assert bounds[c_iters[i]].extent.value == 4 + + def test_bound_warp(): m = tvm.var('m') l = tvm.var('l') @@ -422,5 +458,6 @@ def _check(B, A=A): test_bound_simplification_failure() test_bound_fusesplit1() test_bound_fusesplit2() + test_bound_split_compute_at() test_bound_split_divisible() test_bound_tile_divisible() From 9bbcef9e3b77a1488f06e49987d2eb22d4aefe79 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 24 Feb 2020 12:16:00 -0800 Subject: [PATCH 2/3] fix lint --- src/te/schedule/message_passing.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index a8942cef7847..332697519b37 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -82,7 +82,7 @@ void PassDownDomain(const Stage& stage, Range::make_by_min_extent(0, r->factor), actx); } Update(p_state, r->outer, - Range::make_by_min_extent(0,outer_extent) , actx); + Range::make_by_min_extent(0, outer_extent) , actx); } else { Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); Update(p_state, r->inner, From 6dac28806a24641da8ce8458c2b0cf3664dbea26 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 24 Feb 2020 12:33:09 -0800 Subject: [PATCH 3/3] remove useless testcase --- .../unittest/test_pass_loop_partition.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index e9df98e43d79..d8722dcdf597 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -184,23 +184,6 @@ def test_condition_EQ(): stmt = tvm.ir_pass.Simplify(stmt) assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select)))) -def test_thread_axis2(): - n = tvm.convert(4096) - m = tvm.size_var('m') - A = tvm.placeholder((n,), name='A') - B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') - s = tvm.create_schedule(C.op) - num_thread = 32 - bx, x = s[C].split(C.op.axis[0], factor=32) - tx, x = s[C].split(x, nparts=num_thread) - _, x = s[C].split(x, factor=m) - s[C].bind(bx, tvm.thread_axis("blockIdx.x")) - s[C].bind(tx, tvm.thread_axis("threadIdx.x")) - stmt = lower(s, [A, B]) - for_body = stmt.body.body.body.body.body[0] - assert('threadIdx' not in str(for_body.extent)) - def test_everything_during_deduction(): m = tvm.size_var('m') n = tvm.size_var('n') @@ -455,7 +438,6 @@ def test_simple_rfactor(): test_vectorize() test_condition() test_condition_EQ() - test_thread_axis2() test_everything_during_deduction() test_single_likely() test_multi_likely()