HoistIfThenElse is a pass currently not enabled in TVM. I tried to enable it in #5553, but there are too many bugs in this pass. Let's fix them first.
BUG 1: HoistIfThenElse transforms
for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
into the following, where the condition on n.inner ends up above the loop that defines n.inner:
if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
Possible cause:
The pass only checks whether if_stmt has a preferred position, but that position is not guaranteed to be the current position. Changing the check to
if (if_position_map.count(if_stmt.get()) &&
    if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {
may solve the problem.
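The rule that the output above violates can be sketched in a few lines of Python. This is a toy model, not TVM code: `hoist_depth` and the dictionary of conditions are hypothetical names introduced only for illustration, and each condition is reduced to the set of variables it reads. A condition may move above a loop only if it does not read that loop's variable.

```python
# Toy model, not TVM code: a loop nest is a list of loop variable names from
# outermost to innermost, and a condition is the set of variables it reads.

def hoist_depth(loop_vars, cond_vars):
    """Return how many of the enclosing loops the condition must stay inside.

    0 means the condition can move above every loop; len(loop_vars) means it
    has to stay in the innermost loop body.
    """
    depth = 0
    for index, var in enumerate(loop_vars):
        if var in cond_vars:
            # The condition reads this loop variable, so it cannot be placed
            # above this loop (or above any loop that encloses it less deeply).
            depth = index + 1
    return depth


loops = ["n.inner", "o.inner"]
conds = {
    "threadIdx.y*2 + n.inner < 2": {"threadIdx.y", "n.inner"},
    "threadIdx.z*2 + o.inner < 4": {"threadIdx.z", "o.inner"},
    "threadIdx.y < 1": {"threadIdx.y"},
    "threadIdx.z < 2": {"threadIdx.z"},
}
depths = {text: hoist_depth(loops, used) for text, used in conds.items()}
for text, depth in depths.items():
    print(f"{text}: must stay inside {depth} loop(s)")
```

Under this model only the two conditions on threadIdx.y and threadIdx.z alone can leave the nest entirely; the condition on n.inner has to stay inside the n.inner loop, which is exactly what the buggy output gets wrong.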
BUG 2: src/tir/transforms/split_host_device.cc expects the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" and the "else" branches, we have to rename the loop variables in the "else" branch so that they differ from those in the "then" branch. I have already written some code for this, see #5553.
BUG 3: IfThenElse nodes containing thread indices should not be hoisted over the definition of those indices. This can happen when the Attr node for thread_extent is scheduled into the body of a For node by a compute_at command. I have already written some code for this, see #5553.
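The renaming idea can be illustrated with a minimal Python sketch. It is not the code from #5553: the tuple-based IR, the `.else` suffix, and the helpers `rename_loops` and `loop_vars` are all assumptions made up for this example.

```python
# Toy IR, not TVM code:
#   ("for", var, body)   a loop that defines `var`
#   ("stmt", tokens)     a leaf statement, stored as a tuple of tokens

def rename_loops(node, suffix, mapping=None):
    """Return a copy of `node` in which every loop variable gets `suffix`."""
    mapping = dict(mapping or {})
    if node[0] == "for":
        _, var, body = node
        new_var = var + suffix
        mapping[var] = new_var  # uses inside the body must follow the rename
        return ("for", new_var, rename_loops(body, suffix, mapping))
    # Leaf statement: substitute every renamed variable.
    return ("stmt", tuple(mapping.get(tok, tok) for tok in node[1]))


def loop_vars(node):
    """List the loop variables defined in `node`, outermost first."""
    if node[0] == "for":
        return [node[1]] + loop_vars(node[2])
    return []


# Hoisting an if/else over this loop copies the loop into both branches.
loop = ("for", "k", ("stmt", ("A", "[", "k", "]", "=", "0")))
then_branch = loop
else_branch = rename_loops(loop, ".else")

defined = loop_vars(then_branch) + loop_vars(else_branch)
print(defined)
```

After the rename, the list of defined loop variables has no duplicates, which is the property the SSA check needs.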
BUG 4:
Look at this line. if_stmt may already have been updated by the time this line runs. Look at the example below.
for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
After hoisting j >= 3, it becomes
for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
Now, when we hoist i >= 3, we need to find and remove
if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}
But j >= 3 is already gone from that subtree, so RemoveIf fails. We have to track updates to IfThenElse nodes, just as we already do for For nodes.
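Once the bookkeeping is fixed, hoisting both conditions should presumably yield a nest that computes the same result as the original. A minimal Python sketch of that check follows, with plain loops standing in for the TIR. The fully hoisted form in `hoisted` and the zero-initialized buffer are assumptions about the intended result, not output produced by TVM.

```python
def original():
    """The loop nest from the example, with both conditions innermost."""
    data = [0.0] * 1000  # assumed initial contents
    for i in range(10):
        for j in range(10):
            for k in range(10):
                if i >= 3:
                    if j >= 3:
                        data[i * 100 + j * 10 + k] += 0.5
    return data


def hoisted():
    """Assumed target: each condition sits right below the loop it depends on."""
    data = [0.0] * 1000
    for i in range(10):
        if i >= 3:
            for j in range(10):
                if j >= 3:
                    for k in range(10):
                        data[i * 100 + j * 10 + k] += 0.5
    return data


same = original() == hoisted()
print("equivalent:", same)
```

A comparison like this could serve as a semantic check in the tests, next to the structural checks on the resulting IR.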
BUG 5: This one is in the tests.
Why do we expect a ('For', 'j') inside itself? To avoid a potential problem, maybe we should rename the variables so that there are not two i's and two j's.
These are all the bugs I found.
Besides, I suggest changing every for (size_t i = 0; i < xxx.size(); i++) into for (size_t i = 0, n = xxx.size(); i < n; i++), since the C++ compiler may not be able to prove that xxx.size() is loop-invariant.
@kevinthesun Maybe you can have a look.