Description
Loop partitioning currently divides the iteration space of the for loop into three segments based on the conditions on the loop variable:
- mid_stmt, where each condition is proven to be definitely either true or false;
and two doubt loops:
- pre_stmt, where we can't yet prove the condition to be definitely true or false. pre_stmt is recursed into further as long as that remains valid.
- post_stmt, handled the same way as pre_stmt.
Now, consider the following example:
import tvm

def test_loop_bug():
    ib = tvm.ir_builder.create()
    m = tvm.var("m")
    n = tvm.var("n")
    data = ib.pointer("float32", name="data")
    out = ib.pointer("float32", name="out")
    with ib.for_range(0, 16, "i") as i:
        with ib.if_scope(ib.likely(i > 4)):
            with ib.if_scope(ib.likely(i < 20)):
                out[i] = tvm.max(out[i], data[i])
            with ib.else_scope():
                out[i] = data[i]
        with ib.else_scope():
            out[i] = data[i]
    stmt = ib.get()
    print("===================")
    print(stmt)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.RemoveNoOp(stmt)
    stmt = tvm.ir_pass.RemoveNoOp(stmt)
    print("====================")
    print(stmt)

if __name__ == "__main__":
    test_loop_bug()
The example above is essentially computing:
out[i] = tvm.if_then_else((i > 4 && i < 20), tvm.max(out[i], data[i]), data[i])
The HalideIR for the example above, before loop partitioning, is as follows:
for (i, 0, 16) {
  if (likely((i > 4))) {
    if (likely((i < 20))) {
      out[i] = max(out[i], data[i])
    } else {
      out[i] = data[i]
    }
  } else {
    out[i] = data[i]
  }
}
Based on the conditions, we expect output like the following after loop partitioning:
for (i, 0, 5) {
  out[i] = data[i]  // here (i>4 && i<20) is false.
}
for (i, 0, 11) {
  out[(i + 5)] = max(out[(i + 5)], data[(i + 5)])  // here both the conditions are true.
}
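As a sanity check on the expected output, the following plain-Python sketch (not TVM; Python lists stand in for the float32 buffers, and the sample values are arbitrary) transcribes the original nested-if loop and the expected two-loop partition, then checks that they produce the same buffer.

```python
def original(out, data):
    """Nested-if loop from the HalideIR before partitioning."""
    out = list(out)  # work on a copy so the caller's buffer is untouched
    for i in range(0, 16):
        if i > 4:
            if i < 20:
                out[i] = max(out[i], data[i])
            else:
                out[i] = data[i]
        else:
            out[i] = data[i]
    return out


def expected_partition(out, data):
    """The two branch-free loops we expect after loop partitioning."""
    out = list(out)
    for i in range(0, 5):   # (i > 4 && i < 20) is false here
        out[i] = data[i]
    for i in range(0, 11):  # both conditions are true here
        out[i + 5] = max(out[i + 5], data[i + 5])
    return out


out0 = [float(k % 7) for k in range(16)]
data = [float((3 * k) % 5) for k in range(16)]
assert original(out0, data) == expected_partition(out0, data)
```

The two partitioned loops together cover all 16 iterations (5 + 11), with no condition left in either body.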
However, instead of the HalideIR above, we get the following after loop partitioning:
for (i, 0, 3) {
  out[i] = data[i]
}
out[3] = data[3]
out[4] = data[4]
for (i, 0, 10) {
  out[(i + 5)] = max(out[(i + 5)], data[(i + 5)])
}
out[15] = max(out[15], data[15])
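To make the difference concrete, here is a plain-Python transcription of the IR actually produced (again not TVM; lists stand in for the buffers and the sample data is arbitrary). On this input it computes the same values as the if_then_else form, but its main loop covers only [5, 15), and iteration 15 is peeled into a separate statement.

```python
def reference(out, data):
    """out[i] = (i > 4 && i < 20) ? max(out[i], data[i]) : data[i], for every i."""
    return [max(o, d) if 4 < i < 20 else d for i, (o, d) in enumerate(zip(out, data))]


def actual_partition(out, data):
    """Transcription of the IR LoopPartition actually produced.

    Returns the resulting buffer and the indices covered by the main loop.
    """
    out = list(out)
    main_loop = []
    for i in range(0, 3):
        out[i] = data[i]
    out[3] = data[3]
    out[4] = data[4]
    for i in range(0, 10):
        out[i + 5] = max(out[i + 5], data[i + 5])
        main_loop.append(i + 5)
    out[15] = max(out[15], data[15])  # extent-one post_stmt, peeled into a statement
    return out, main_loop


out0 = [float(k % 7) for k in range(16)]
data = [float((3 * k) % 5) for k in range(16)]
result, covered = actual_partition(out0, data)
assert result == reference(out0, data)  # same values on this sample input
assert covered == list(range(5, 15))    # main loop covers [5, 15), not [5, 16)
```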
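The difference from the expected result arises because LoopPartition generates a loop with extent one in post_stmt: it becomes the single statement out[15] = max(out[15], data[15]), and the main loop is left with an extent of 10 instead of 11.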
Say the For loop's induction variable (IV) ranges over [min, max+1), and the conditions are provably true or false in the region [body_begin, post_doubt_begin). Then the expected intervals are:
- pre_stmt will have range [min, body_begin)
- mid_stmt will have range [body_begin, post_doubt_begin)
- post_stmt will have range [post_doubt_begin, max+1)
Note that the intervals should be half-open.
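The interval bookkeeping above can be sketched in plain Python as follows. The helper names are hypothetical (they are not TVM functions), and the boundaries body_begin = 5 and post_doubt_begin = 16 for the example loop are an assumption about what the analysis should find for i > 4 && i < 20 on [0, 16).

```python
from typing import List, Tuple

# A half-open interval [begin, end) of loop-variable values.
Interval = Tuple[int, int]


def split_iteration_space(loop_min: int, loop_max: int,
                          body_begin: int, post_doubt_begin: int) -> List[Interval]:
    """Split [loop_min, loop_max + 1) into pre, mid and post half-open intervals."""
    end = loop_max + 1  # loop_max is the last iteration value, so the end is exclusive
    if not (loop_min <= body_begin <= post_doubt_begin <= end):
        raise ValueError("need min <= body_begin <= post_doubt_begin <= max + 1")
    return [(loop_min, body_begin), (body_begin, post_doubt_begin), (post_doubt_begin, end)]


def extent(iv: Interval) -> int:
    """Number of iterations in a half-open interval."""
    return iv[1] - iv[0]


# for (i, 0, 16): min = 0, max = 15, with the assumed boundaries for the example.
pre, mid, post = split_iteration_space(0, 15, body_begin=5, post_doubt_begin=16)
print(pre, mid, post)                           # (0, 5) (5, 16) (16, 16)
print([extent(iv) for iv in (pre, mid, post)])  # [5, 11, 0]
```

Because the intervals are half-open, their extents always sum to the loop extent, and an empty post_stmt (extent 0) is expressed naturally as [max+1, max+1).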
However, when LoopPartition cannot prove the extent of post_stmt to be non-negative, it changes post_doubt_begin as follows:
post_doubt_begin = Min::make(post_doubt_begin, max);
Now consider the case where this assignment yields post_doubt_begin = max. Since max is the last iteration value (the range ends at max+1), post_stmt will then always have an extent of at least one, and mid_stmt will have one iteration fewer than it should.
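The effect of the clamp can be reproduced with plain arithmetic. In this sketch, min() mirrors the Min::make line, the helper name is hypothetical, and body_begin = 5 and post_doubt_begin = 16 are assumed values for the example loop. Clamping against the exclusive bound max + 1 is included only as a hypothetical comparison, not as the change TVM actually made.

```python
def partition_extents(loop_min, loop_max, body_begin, post_doubt_begin):
    """Extents of pre, mid and post for the half-open range [loop_min, loop_max + 1)."""
    end = loop_max + 1
    return (body_begin - loop_min, post_doubt_begin - body_begin, end - post_doubt_begin)


loop_min, loop_max = 0, 15          # for (i, 0, 16)
body_begin, post_doubt_begin = 5, 16

# Clamp against the inclusive `max`, mirroring Min::make(post_doubt_begin, max).
clamped = min(post_doubt_begin, loop_max)
print(partition_extents(loop_min, loop_max, body_begin, clamped))       # (5, 10, 1)

# Hypothetical comparison: clamp against the exclusive bound max + 1 instead.
clamped_excl = min(post_doubt_begin, loop_max + 1)
print(partition_extents(loop_min, loop_max, body_begin, clamped_excl))  # (5, 11, 0)
```

The first result matches the observed IR (a main loop of extent 10 followed by the single post statement for i = 15), while clamping to max + 1 keeps the intervals half-open and yields the expected mid extent of 11.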