diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index b2a1bea0dd5b..1d669c8b39f1 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - Expr cond = (body_begin - min >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); - // stop recursing on this interval if we can't prove it has non-negative length - pre_stmt_recurse = false; - } - if (!partition_thread_scope) { - Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); + if (!analyzer_.CanProve(body_begin == min)) { + Expr cond = (body_begin - min >= 0); + if (!analyzer_.CanProve(cond)) { + LOG(WARNING) << "Cannot prove: " << cond + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + // stop recursing on this interval if we can't prove it has non-negative length + pre_stmt_recurse = false; + } + if (!partition_thread_scope) { + Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); + pre_stmt = MakeFor(node, body_begin - min, pre_body); + } } } else { body_begin = min; @@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - // require the extent to be non-negative - Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max+1); - // stop recursing on this interval if we can't prove it has non-negative length - post_stmt_recurse = false; - } - if 
def test_multilevel_splitting_with_indivisble_factors():
    """Check loop partitioning on a two-level split with non-dividing factors.

    A 130-element ReLU is split by a factor of 8 and the resulting outer axis
    is split again by a factor of 16; the innermost axis is unrolled.  The
    schedule is lowered with ``partition_const_loop=True`` and the test asserts
    that the lowered body contains exactly 10 ``Max`` expression nodes.
    """
    import topi

    data = tvm.placeholder((130,), dtype="float32")
    activated = topi.nn.relu(data)
    sched = tvm.create_schedule(activated.op)
    stage = sched[activated]

    # Two nested splits: 130 -> (outer, 8) -> (outer_outer, 16, 8).
    (axis,) = stage.op.axis
    outer, inner = stage.split(axis, factor=8)
    outer_outer, outer_inner = stage.split(outer, factor=16)
    stage.reorder(outer_outer, outer_inner, inner)
    stage.unroll(inner)

    with tvm.build_config(partition_const_loop=True):
        lowered_body = tvm.lower(sched, [data, activated]).body
        # One boolean per visited node: True where the node is a Max expression.
        # NOTE(review): the expected count of 10 is taken as-is from the
        # original test; presumably it pins the partitioner's output so that
        # no redundant Max bounds are emitted -- confirm against the pass.
        max_flags = collect_visit(
            lowered_body, lambda node: isinstance(node, tvm.expr.Max))
        assert max_flags.count(True) == 10