Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,16 @@ class ThreadPartitionInserter : public IRMutator {
// Try to do partition at the candidate IRs
class LoopPartitioner : public IRMutator {
public:
explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
: candidates_(candidates) {}
explicit LoopPartitioner(bool split_const_loop)
: selector(CandidateSelector(split_const_loop)) {}

Stmt VisitAndMutate(const Stmt& stmt) {
selector.Visit(stmt);
return Mutate(stmt);
}

Stmt Mutate_(const For* op, const Stmt& stmt) {
if (candidates_.count(op)) {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, op->loop_var,
op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
Expand All @@ -266,7 +271,7 @@ class LoopPartitioner : public IRMutator {
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (candidates_.count(op)) {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}
Expand Down Expand Up @@ -295,9 +300,9 @@ class LoopPartitioner : public IRMutator {
inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);

/* Candidate IRs that may be partitioned potentially */
std::unordered_set<const Node*> candidates_;
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
CandidateSelector selector;
};

Stmt LoopPartitioner::TryPartition(const Node* node,
Expand All @@ -322,7 +327,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr body_begin;
Stmt pre_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
body_begin = true_itrv.min();
body_begin = ir::Simplify(true_itrv.min());
if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!can_prove(cond)) {
Expand All @@ -343,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr post_doubt_begin;
Stmt post_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
post_doubt_begin = true_itrv.max() + 1;
post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
if (!can_prove(true_itrv.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
Expand All @@ -354,8 +359,17 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
}
// [post_doubt_begin, max]
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
Stmt post_body;
// If the loop is going from 0 to 1, replace the loop var with min value
if (as_const_int(max) && as_const_int(post_doubt_begin)) {
if (*as_const_int(max) == *as_const_int(post_doubt_begin)) {
post_body = Substitute(body, {{Var{var}, post_doubt_begin}});
post_stmt = post_body;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just trying to understand: why does this need to be handled separately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An example use case is as follows. This is what the statement looks like while the CLP is running:

for (i = 0, 1)
  for(j = 0, 8)
    if(i + j < 4)
      Do_something

If the loop bounds go from 0 to 1, then it's better to replace i with 0, because that simplifies the conditions and results in better partitioning.

} else {
post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
}
} else {
Expand All @@ -368,8 +382,15 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
s = MakeFor(node, post_doubt_begin - body_begin, new_body);
if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
if (post_stmt.defined()) s = Block::make(s, post_stmt);

if (!(pre_stmt.defined() && post_stmt.defined())) s = VisitAndMutate(s);
if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
if (post_stmt.defined()) {
if (as_const_int(max) && as_const_int(post_doubt_begin)) {
post_stmt = VisitAndMutate(post_stmt);
}
s = Block::make(s, post_stmt);
}
} else {
Expr cond = const_true();
if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
Expand Down Expand Up @@ -402,9 +423,7 @@ class RemoveLikelyTags : public IRMutator {
};

// Pass entry point: partitions loops in `stmt`, then runs RemoveLikelyTags
// over the result and returns the rewritten statement.
// `split_const_loop` is forwarded to the candidate selection step.
// NOTE(review): this is a rendered diff hunk (-402,9 +423,7) shown without
// +/- markers, so removed and added lines appear side by side. Presumably the
// three CandidateSelector/LoopPartitioner(selector.candidates) lines are the
// removed ones and the LoopPartitioner(split_const_loop) line is the single
// addition -- confirm against the raw diff.
Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
CandidateSelector selector(split_const_loop);
selector.Visit(stmt);
stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
// VisitAndMutate runs the partitioner's own selector over stmt before mutating.
stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt);
stmt = RemoveLikelyTags().Mutate(stmt);
return stmt;
}
Expand Down
158 changes: 158 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,157 @@ def test_everything_during_deduction():
stmt = tvm.ir_pass.Simplify(stmt)
assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))

def test_single_likely():
    """One split axis: partitioning must leave no IfThenElse behind.

    A 60-element elementwise add is split by a factor of 16 (60 is not a
    multiple of 16). The schedule is lowered, run through LoopPartition with
    split_const_loop=True and then Simplify; the test asserts that
    collect_visit reports no IfThenElse node in the result.
    """
    extent = 60
    lhs = tvm.placeholder((extent, ), name='A')
    rhs = tvm.placeholder((extent, ), name='B')

    T = tvm.compute((extent, ), lambda i: lhs[i]+rhs[i])
    sched = tvm.create_schedule(T.op)
    axis = T.op.axis[0]
    # The outer/inner iteration variables are not needed afterwards.
    sched[T].split(axis, factor=16)

    stmt = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    stmt = tvm.ir_pass.Simplify(tvm.ir_pass.LoopPartition(stmt, True))

    is_branch = collect_visit(stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(is_branch)

def test_multi_likely():
    """Two split axes: partitioning must leave no IfThenElse behind.

    A 94x62 elementwise add has both axes split by a factor of 16 (neither
    extent is a multiple of 16) and reordered to (xo, yo, xi, yi). The
    schedule is lowered, run through LoopPartition with split_const_loop=True
    and then Simplify; the test asserts that collect_visit reports no
    IfThenElse node in the result.
    """
    n = 94
    m = 62
    A = tvm.placeholder((n, m), name='A')
    B = tvm.placeholder((n, m), name='B')

    T = tvm.compute((n, m), lambda i, j: A[i, j]+B[i, j])
    s = tvm.create_schedule(T.op)
    # The schedule used to be lowered here as well, before the splits; that
    # result was overwritten below without ever being read, so it is dropped.
    x, y = T.op.axis
    xo, xi = s[T].split(x, factor=16)
    yo, yi = s[T].split(y, factor=16)
    s[T].reorder(xo, yo, xi, yi)

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_oneD_pool():
    """Nested likely-guards over an (ow, kw) loop nest must be partitioned away.

    Builds three 16x3 loop nests with the IR builder. Each nest guards the
    same max-update with a different pair of nested likely conditions:
    interior (0 < ow < 15), left edge (ow < 1 and kw > 0) and right edge
    (ow > 14 and kw < 2). After LoopPartition with split_const_loop=True and
    Simplify, the test asserts that collect_visit reports no IfThenElse node.
    """
    ib = tvm.ir_builder.create()
    # NOTE(review): both pointers are created with the name "A"; presumably
    # harmless for this test, but confirm it is intentional.
    data = ib.pointer("float32", name="A")
    out = ib.pointer("float32", name="A")
    # Interior positions.
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])
    # Left edge: only ow == 0, skipping kw == 0.
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])
    # Right edge: only ow == 15, skipping kw == 2.
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])

    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_cce_loop_1():
    """A likely-guard on a flattened index (i*160 + j) must be partitioned away.

    Builds an 11x160 loop nest whose body is guarded by
    likely(i*160 + j < 1600); the guard holds for i in 0..9 and fails for the
    whole i == 10 iteration. After LoopPartition with split_const_loop=True
    and Simplify, the test asserts that collect_visit reports no IfThenElse
    node.
    """
    ib = tvm.ir_builder.create()
    dtype = 'float16'
    n = 514
    m = 514
    # Only the declared buffers are used; the placeholders that used to be
    # created alongside them were never referenced and have been removed.
    Ab = tvm.decl_buffer((n*m,), dtype, name="A")
    A = ib.buffer_ptr(Ab)
    Bb = tvm.decl_buffer((n*m,), dtype, name="B")
    B = ib.buffer_ptr(Bb)
    with ib.for_range(0, 11, name="i") as i:
        with ib.for_range(0, 160, name="j") as j:
            with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
                A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_cce_loop_2():
    """A likely if/else selecting the tile tail must be partitioned away.

    Iterates over ceil(112 / 32) == 4 tiles. The likely condition
    head + tile > length is true only for the last tile (head == 96), where
    the tail is clamped to the total length; every other tile uses
    head + tile. After LoopPartition with split_const_loop=True and Simplify,
    the test asserts that collect_visit reports no IfThenElse node.
    """
    ib = tvm.ir_builder.create()
    # Named `length` rather than `len` so the builtin is not shadowed.
    length = 112
    tile = 32
    loop = (length + tile - 1) // tile
    with ib.for_range(0, loop, 'i') as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > length)):
            tail = length
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
        with ib.else_scope():
            tail = head + tile
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))

    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))


def test_cce_loop_3():
    """A likely-guard that fails only on the very last iteration must vanish.

    Builds a 9998x4 loop nest guarded by likely(i*4 + j < 39991). The largest
    value of i*4 + j is 9997*4 + 3 == 39991, so the guard is false only for
    the final (i, j) pair. After LoopPartition with split_const_loop=True and
    Simplify, the test asserts that collect_visit reports no IfThenElse node.
    """
    builder = tvm.ir_builder.create()
    inner_extent = 4
    outer_extent = 9998
    limit = 39991
    with builder.for_range(0, outer_extent, 'i') as i:
        with builder.for_range(0, inner_extent, 'j') as j:
            in_range = builder.likely(i*inner_extent + j < limit)
            with builder.if_scope(in_range):
                builder.emit(tvm.call_extern('float16', "cce_intrisic", i))

    stmt = tvm.ir_pass.LoopPartition(builder.get(), True)
    stmt = tvm.ir_pass.Simplify(stmt)
    is_branch = collect_visit(stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(is_branch)

def test_conv_tiling():
    """Tiling a 62x62 convolution output by 16x16 must leave no IfThenElse.

    Declares a stride-1 convolution of a 1x128x64x64 input with a 3x3 kernel
    and 64 output channels (output height/width 64 - 3 + 1 == 62), tiles the
    two spatial output axes by 16x16, lowers the schedule, and runs
    LoopPartition with split_const_loop=True followed by Simplify. The test
    asserts that collect_visit reports no IfThenElse node in the result.
    """
    HSTR = WSTR = 1
    batch_size = 1
    in_channel = 128
    out_channel = 64
    in_height = in_width = 64
    kernel_height = kernel_width = 3
    out_height = out_width = in_height - kernel_height + 1

    data_shape = (batch_size, in_channel, in_height, in_width)
    kernel_shape = (kernel_height, kernel_width, in_channel, out_channel)
    out_shape = (batch_size, out_channel, out_height, out_width)

    data = tvm.placeholder(data_shape, name='data')
    kernel = tvm.placeholder(kernel_shape, name='kernel')
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')
    conv = tvm.compute(
        out_shape,
        lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
                                      kernel[kh, kw, ic, oc],
                                      axis=[ic, kh, kw]),
        name="conv2d")
    sched = tvm.create_schedule(conv.op)

    # Only the two spatial axes are tiled; the resulting tile axes are unused.
    _, _, h_axis, w_axis = conv.op.axis
    sched[conv].tile(h_axis, w_axis, 16, 16)

    stmt = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    stmt = tvm.ir_pass.Simplify(tvm.ir_pass.LoopPartition(stmt, True))
    is_branch = collect_visit(stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(is_branch)

if __name__ == "__main__":
test_basic()
test_const_loop()
Expand All @@ -187,3 +338,10 @@ def test_everything_during_deduction():
test_select()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
test_multi_likely()
test_oneD_pool()
test_cce_loop_1()
test_cce_loop_2()
test_cce_loop_3()
test_conv_tiling()