Skip to content

TOPI LSTM recipe fails to compile due to inserting sync is disallowed inside unnecessary IfThenElse #2088

@junrushao

Description

@junrushao

TL;DR. TVM doesn't allow syncing threads inside an IfThenElse, because it is very prone to deadlock or infinite wait once the conditional diverges. However, some conditionals should be eliminated before doing this check, e.g. if (0 <= xxxx.idx) { // certainly true thing }.

Details and causes of the error

The LSTM recipe in the current master branch fails to compile, producing errors like the following:

>>> python $TVM_HOME/topi/recipe/rnn/lstm.py
Traceback (most recent call last):
  File "topi/recipe/rnn/lstm.py", line 182, in <module>
    lstm()
  File "topi/recipe/rnn/lstm.py", line 179, in lstm
    check_device("cuda")
  File "topi/recipe/rnn/lstm.py", line 150, in check_device
    target)
  File "/home/junrus/Projects/tvm/python/tvm/build_module.py", line 585, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/home/junrus/Projects/tvm/python/tvm/build_module.py", line 417, in _build_for_device
    func = ir_pass.ThreadSync(func, "global")
  File "/home/junrus/Projects/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/home/junrus/Projects/tvm/python/tvm/_ffi/base.py", line 68, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [18:29:14] /home/junrus/Projects/tvm/src/pass/storage_sync.cc:76: Check failed: condition_counter() == 0 (1 vs. 0) Cannot insert syncs inside condition

Stack trace returned 10 entries:
[bt] (0) /home/junrus/Projects/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11]()+0x1a9) [0x7f40217e5989]
[bt] (1) /home/junrus/Projects/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x18) [0x7f40217e6878]
[bt] (2) /home/junrus/Projects/tvm/build/libtvm.so(tvm::ir::ThreadSyncPlanner::Summarize(std::vector<tvm::ir::StorageAccessVisitor::StmtEntry, std::allocator<tvm::ir::StorageAccessVisitor::StmtEntry> >, HalideIR::Internal::For const*)+0x8f3) [0x7f4021a4dce3]
[bt] (3) /home/junrus/Projects/tvm/build/libtvm.so(tvm::ir::StorageAccessVisitor::Visit_(HalideIR::Internal::IfThenElse const*)+0x25d) [0x7f4021a29f1d]
[bt] (4) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::IfThenElse>(std::function<void (HalideIR::Internal::IfThenElse const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b932b]
[bt] (5) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]
[bt] (6) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::ProducerConsumer>(std::function<void (HalideIR::Internal::ProducerConsumer const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b9c8b]
[bt] (7) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]
[bt] (8) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::Block>(std::function<void (HalideIR::Internal::Block const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b9d7b]
[bt] (9) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]

I dug into the problem and found that the then-branch of the conditional `if ((0 <= lstm_scan.idx))` raised the error, in the lowered code below:

produce lstm_scan {
  // attr [iter_var(blockIdx.x, Range(min=0, extent=24), blockIdx.x)] thread_extent = 24
  // attr [Wh2h.local] storage_scope = "local"
  allocate Wh2h.local[float32 * 4 * 1 * 72]
  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 1 * 1 * 24]
  // attr [s_h2h] storage_scope = "local"
  allocate s_h2h[float32 * 1 * 1 * 4 * 1]
  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 1 * 1 * 576]
  // attr [s_h2h.rf] storage_scope = "local"
  allocate s_h2h.rf[float32 * 1 * 1 * 1 * 1 * 1]
  // attr [reduce_temp0] storage_scope = "local"
  allocate reduce_temp0[float32 * 1]
  // attr [next_c] storage_scope = "local"
  allocate next_c[float32 * 1 * 1 * 1]
  // attr [next_h] storage_scope = "local"
  allocate next_h[float32 * 1 * 1 * 1]
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=24), threadIdx.x)] thread_extent = 24
  produce Wh2h.local {
    unrolled (ax0, 0, 4) {
      unrolled (ax2, 0, 72) {
        Wh2h.local[((ax0*72) + ax2)] = Wh2h[((((((blockIdx.x*192) + threadIdx.y) + (threadIdx.x*8)) + (ax0*4608))*72) + ax2)]
      }
    }
  }
  lstm_scan.v0[((blockIdx.x*24) + threadIdx.x)] = 0.000000f
  lstm_scan.v1[((blockIdx.x*24) + threadIdx.x)] = 0.000000f
  for (lstm_scan.idx, 0, (num_step + -1)) {
    produce placeholder.shared {
      unrolled (ax2, 0, 24) {
        placeholder.shared[ax2] = lstm_scan.v1[(((blockIdx.x + (lstm_scan.idx*24))*24) + ax2)]
      }
    }
    if ((0 <= lstm_scan.idx)) {
      produce s_h2h {
        for (ax2, 0, 4) {
          produce s_h2h.rf {
            produce placeholder.shared {
              unrolled (ax2.outer, 0, 3) {
                placeholder.shared[(((threadIdx.y*24) + threadIdx.x) + (ax2.outer*192))] = lstm_scan.v0[((((threadIdx.y*24) + threadIdx.x) + (lstm_scan.idx*576)) + (ax2.outer*192))]
              }
            }
            s_h2h.rf[0] = 0.000000f
            unrolled (ki2h.inner, 0, 72) {
              s_h2h.rf[0] = (s_h2h.rf[0] + (placeholder.shared[((threadIdx.y*72) + ki2h.inner)]*Wh2h.local[((ax2*72) + ki2h.inner)]))
            }
          }
          // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0.000000f])] reduce_scope = reinterpret((uint64)0)
          tvm_thread_allreduce((uint32)1, s_h2h.rf[0], (uint1)1, reduce_temp0, threadIdx.y)
          s_h2h[ax2] = reduce_temp0[0]
        }
      }
      produce next_c {
        if ((threadIdx.y == 0)) {
          next_c[0] = ((sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 3456)] + s_h2h[2]))*placeholder.shared[threadIdx.x]) + (sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) +
(lstm_scan.idx*2304)) + 2304)] + s_h2h[0]))*tanh((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 2880)] + s_h2h[1]))))
        }
      }
      produce next_h {
        if ((threadIdx.y == 0)) {
          next_h[0] = (sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 4032)] + s_h2h[3]))*tanh(next_c[0]))
        }
      }
      if ((threadIdx.y == 0)) {
        lstm_scan.v0[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*576)) + 576)] = next_h[0]
        lstm_scan.v1[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*576)) + 576)] = next_c[0]
      }
    }
  }
}

I think this conditional should be eliminated before the ThreadSyncPlanner performs this check.

Possible Solution

I am not quite familiar with the passes in TVM, but in which pass could we eliminate such always-true conditionals? I think that pass would best be arranged to run before the ThreadSync pass.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions