Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
} else {
body_begin = min;
}

// Calculating post-subrange and generating code for it.
// post-subrange = [post_doubt_begin, max]
// post-subrange = [post_doubt_begin, max+1)
Expr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
} else {
post_doubt_begin = max + 1;
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_pass_bound_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def lower(sch, args):
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.RemoveNoOp(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
Expand Down