[TIR] Bugs in HoistIfThenElse #5559

@roastduck

HoistIfThenElse is a pass that is currently not enabled in TVM. I tried to enable it in #5553, but the pass has too many bugs, so let's fix them first.

BUG 1: HoistIfThenElse transforms

for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

into

if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

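Note that the condition ((threadIdx.y*2) + n.inner) < 2 reads n.inner, yet it has been hoisted out of the n.inner loop, so n.inner is used outside the loop that defines it.

Possible cause: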

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295

This line only checks whether if_stmt has a preferred position at all. It does not check that the preferred position is the loop currently being processed. Changing the condition to

if (if_position_map.count(if_stmt.get()) &&
    if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {

may solve the problem.
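To make the expected placement concrete, here is a minimal C++ sketch of the legality rule that the output above violates: an if may only be hoisted out of loops whose variables its condition does not read. It is a toy model under stated assumptions, not the pass's data structures. `MinHoistDepth` is a hypothetical name, and each condition is reduced to the set of variable names it reads.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical helper (not part of TVM). `loop_vars` lists the enclosing
// loop variables, outermost first; `cond_vars` lists the variables a
// condition reads. Returns how many of those loops must stay outside the if:
// 0 means the if may be hoisted above the whole nest.
std::size_t MinHoistDepth(const std::vector<std::string>& loop_vars,
                          const std::set<std::string>& cond_vars) {
  std::size_t depth = 0;
  for (std::size_t i = 0, n = loop_vars.size(); i < n; ++i) {
    if (cond_vars.count(loop_vars[i]) != 0) {
      depth = i + 1;  // the if must remain inside loop i
    }
  }
  return depth;
}
```

For the nest above (loops n.inner, then o.inner), ((threadIdx.y*2) + n.inner) < 2 gets depth 1, so it must stay inside the n.inner loop. threadIdx.y < 1 and threadIdx.z < 2 get depth 0 and may move above both loops.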

BUG 2: src/tir/transforms/split_host_device.cc requires the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" branch and the "else" branch, we have to rename the loop variables in the "else" branch so that they differ from those in the "then" branch. I have already written some code for this; see #5553.
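A minimal sketch of the renaming, assuming a toy statement tree rather than TVM's IR. The copy placed in the "else" branch gets fresh loop-variable names (here by appending a suffix, an arbitrary choice), every read of those variables is rewritten, and a small check confirms that no loop variable is defined twice. `Stmt`, `RenameLoopVars`, `CollectDefs`, and the `.else` suffix are all hypothetical; this is not the code in #5553.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy statement tree (hypothetical, not TVM's IR).
struct Stmt {
  std::string kind;               // "for", "if", or "store"
  std::string var;                // loop variable, used only by "for"
  std::vector<std::string> uses;  // variables this node reads
  std::vector<Stmt> body;
};

// Returns a copy of `s` in which every loop variable defined inside `s`
// gets `suffix` appended, and every read of such a variable is rewritten.
// `subst` is taken by value so a renaming is visible only inside its loop.
Stmt RenameLoopVars(const Stmt& s, const std::string& suffix,
                    std::map<std::string, std::string> subst = {}) {
  Stmt out = s;
  for (std::string& u : out.uses) {
    auto it = subst.find(u);
    if (it != subst.end()) u = it->second;
  }
  if (s.kind == "for") {
    out.var = s.var + suffix;
    subst[s.var] = out.var;
  }
  for (std::size_t i = 0, n = s.body.size(); i < n; ++i) {
    out.body[i] = RenameLoopVars(s.body[i], suffix, subst);
  }
  return out;
}

// Adds the loop variables defined in `s` to `defs`; returns false as soon
// as a variable is defined a second time (i.e. the IR is not in SSA form).
bool CollectDefs(const Stmt& s, std::set<std::string>* defs) {
  if (s.kind == "for" && !defs->insert(s.var).second) return false;
  for (const Stmt& c : s.body) {
    if (!CollectDefs(c, defs)) return false;
  }
  return true;
}
```

In TVM itself variables are identified by object rather than by name, so a real fix would presumably create new variable objects; the string suffix here only stands in for that.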

BUG 3: IfThenElse nodes that use thread indices must not be hoisted above the definition of those indices. This can happen when the AttrStmt node for thread_extent is scheduled into the body of a For node with a compute_at command. I have already written some code for this; see #5553.
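One way to catch this kind of mistake in tests (and also the BUG 1 output, where n.inner is read outside its loop) is to check that every variable a node reads is bound by an enclosing scope. Below is a minimal sketch over a toy IR in which an `attr` node stands in for an AttrStmt with thread_extent. `Stmt` and `AllUsesInScope` are hypothetical names, not TVM APIs.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy IR (hypothetical). "for" binds a loop variable and "attr" models an
// AttrStmt with thread_extent that binds a thread index such as threadIdx.x.
struct Stmt {
  std::string kind;               // "for", "attr", "if", or "call"
  std::string var;                // variable bound by "for"/"attr"
  std::vector<std::string> uses;  // variables this node reads
  std::vector<Stmt> body;
};

// Returns true if every variable read inside `s` is bound either by an
// enclosing "for"/"attr" node or by `scope` (variables bound outside `s`).
bool AllUsesInScope(const Stmt& s, std::set<std::string> scope) {
  for (const std::string& u : s.uses) {
    if (scope.count(u) == 0) return false;  // read before its definition
  }
  if (s.kind == "for" || s.kind == "attr") scope.insert(s.var);
  for (const Stmt& c : s.body) {
    if (!AllUsesInScope(c, scope)) return false;
  }
  return true;
}
```

Hoisting `if (threadIdx.x < 2)` above the `attr` node that binds threadIdx.x makes this check fail, while the unhoisted nest passes.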

BUG 4:

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371

At this line, if_stmt may already have been updated by an earlier hoist. Consider the example below.

for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

After hoisting j >= 3, it becomes

for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

Now, when hoisting i >= 3, the pass has to find and remove the original statement

if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}

But j >= 3 is already gone, so RemoveIf fails. We have to track updates to IfThenElse nodes, just as we already do for For nodes.
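A minimal sketch of the tracking idea, assuming a toy immutable IR in which rewriting a node builds a new node (as TVM's IR mutators generally do). Only the removal step is modeled: removing `j >= 3` rebuilds the enclosing `i >= 3` node, so a later removal that looks for the stale `i >= 3` node finds nothing, whereas looking up its latest version through a small map succeeds. Matching by node identity is a simplification, and `Node`, `RemoveIfNode`, and `IfUpdateTracker` are hypothetical names, not the pass's code.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy immutable IR (hypothetical, not TVM's IfThenElseNode/ForNode).
struct Node;
using NodePtr = std::shared_ptr<const Node>;
struct Node {
  std::string kind;   // "for", "if", or "store"
  std::string label;  // loop variable, condition, or store text
  std::vector<NodePtr> body;
};

NodePtr Make(const std::string& kind, const std::string& label,
             std::vector<NodePtr> body = {}) {
  return std::make_shared<Node>(Node{kind, label, std::move(body)});
}

// Replaces the if-node `target` (matched by identity) with its single child,
// rebuilding every ancestor that changes. Sets *found if `target` occurs.
NodePtr RemoveIfNode(const NodePtr& stmt, const NodePtr& target, bool* found) {
  if (stmt == target) {
    *found = true;
    return target->body.front();
  }
  std::vector<NodePtr> body;
  bool changed = false;
  for (const NodePtr& child : stmt->body) {
    NodePtr c = RemoveIfNode(child, target, found);
    changed = changed || (c != child);
    body.push_back(c);
  }
  return changed ? Make(stmt->kind, stmt->label, std::move(body)) : stmt;
}

// Remembers the newer version of every rebuilt if-node.
class IfUpdateTracker {
 public:
  void Record(const NodePtr& old_if, const NodePtr& new_if) {
    latest_[old_if] = new_if;
  }
  // Follows recorded updates to the most recent version of `node`.
  NodePtr Latest(NodePtr node) const {
    for (auto it = latest_.find(node); it != latest_.end();
         it = latest_.find(node)) {
      node = it->second;
    }
    return node;
  }

 private:
  std::map<NodePtr, NodePtr> latest_;  // keys also keep old nodes alive
};
```

The usage mirrors the example above: after removing `j >= 3` and recording the rebuilt `i >= 3` node, removing the stale node fails, while removing `tracker.Latest(stale)` succeeds.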

BUG 5: This one is in the tests.

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175

Why do we expect a ('For', 'j') nested inside itself? This may point to a problem; perhaps we should rename the variables so that the test no longer has two i's and two j's.

These are all the bugs I found.

Besides, I suggest changing every for (size_t i = 0; i < xxx.size(); i++) into for (size_t i = 0, n = xxx.size(); i < n; i++), since the C++ compiler cannot always prove that xxx.size() is loop-invariant.
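A minimal illustration of the two loop forms, using a plain std::vector and hypothetical function names. The rewrite is only equivalent when the loop body does not change the container's size, so each loop needs to be checked before it is converted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Original pattern: size() is re-evaluated before every iteration.
long SumOriginal(const std::vector<long>& xs) {
  long total = 0;
  for (std::size_t i = 0; i < xs.size(); i++) {
    total += xs[i];
  }
  return total;
}

// Suggested pattern: size() is evaluated once, before the loop starts.
long SumHoisted(const std::vector<long>& xs) {
  long total = 0;
  for (std::size_t i = 0, n = xs.size(); i < n; i++) {
    total += xs[i];
  }
  return total;
}
```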

@kevinthesun Maybe you can have a look.