Description
TL;DR: TVM doesn't allow syncing threads inside an IfThenElse, because a sync there is prone to deadlock or infinite wait once the conditional diverges across threads. However, some conditionals should be eliminated before this check runs, e.g. if (0 <= xxxx.idx) { // certainly true thing }.
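To make the "certainly true" part concrete, here is a minimal sketch of the bound reasoning I mean, in plain Python; the function is mine for illustration and is not TVM's actual simplifier API:

# Sketch only: inside for (i, loop_min, extent), i >= loop_min always holds,
# so a guard of the form const <= i can be decided statically.
def always_true_le(const_c, loop_min):
    """True iff const_c <= i holds for every iteration of a loop variable
    i declared as for (i, loop_min, extent)."""
    return const_c <= loop_min

# lstm_scan.idx comes from for (lstm_scan.idx, 0, (num_step + -1)), so
# loop_min = 0 and the guard (0 <= lstm_scan.idx) is always true; the
# IfThenElse can then be replaced by its then-branch.
assert always_true_le(0, loop_min=0)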
Details and cause of the error
The LSTM recipe on the current master (link) fails to compile, complaining with an error like:
>>> python $TVM_HOME/topi/recipe/rnn/lstm.py
Traceback (most recent call last):
  File "topi/recipe/rnn/lstm.py", line 182, in <module>
    lstm()
  File "topi/recipe/rnn/lstm.py", line 179, in lstm
    check_device("cuda")
  File "topi/recipe/rnn/lstm.py", line 150, in check_device
    target)
  File "/home/junrus/Projects/tvm/python/tvm/build_module.py", line 585, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/home/junrus/Projects/tvm/python/tvm/build_module.py", line 417, in _build_for_device
    func = ir_pass.ThreadSync(func, "global")
  File "/home/junrus/Projects/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/home/junrus/Projects/tvm/python/tvm/_ffi/base.py", line 68, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [18:29:14] /home/junrus/Projects/tvm/src/pass/storage_sync.cc:76: Check failed: condition_counter() == 0 (1 vs. 0) Cannot insert syncs inside condition
Stack trace returned 10 entries:
[bt] (0) /home/junrus/Projects/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11]()+0x1a9) [0x7f40217e5989]
[bt] (1) /home/junrus/Projects/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x18) [0x7f40217e6878]
[bt] (2) /home/junrus/Projects/tvm/build/libtvm.so(tvm::ir::ThreadSyncPlanner::Summarize(std::vector<tvm::ir::StorageAccessVisitor::StmtEntry, std::allocator<tvm::ir::StorageAccessVisitor::StmtEntry> >, HalideIR::Internal::For const*)+0x8f3) [0x7f4021a4dce3]
[bt] (3) /home/junrus/Projects/tvm/build/libtvm.so(tvm::ir::StorageAccessVisitor::Visit_(HalideIR::Internal::IfThenElse const*)+0x25d) [0x7f4021a29f1d]
[bt] (4) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::IfThenElse>(std::function<void (HalideIR::Internal::IfThenElse const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b932b]
[bt] (5) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]
[bt] (6) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::ProducerConsumer>(std::function<void (HalideIR::Internal::ProducerConsumer const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b9c8b]
[bt] (7) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]
[bt] (8) /home/junrus/Projects/tvm/build/libtvm.so(std::_Function_handler<void (tvm::NodeRef const&, tvm::ir::IRVisitor*), tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>& tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::set_dispatch<HalideIR::Internal::Block>(std::function<void (HalideIR::Internal::Block const*, tvm::ir::IRVisitor*)>)::{lambda(tvm::NodeRef const&, tvm::ir::IRVisitor*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::ir::IRVisitor*&&)+0x3b) [0x7f40219b9d7b]
[bt] (9) /home/junrus/Projects/tvm/build/libtvm.so(tvm::IRFunctor<void (tvm::NodeRef const&, tvm::ir::IRVisitor*)>::operator()(tvm::NodeRef const&, tvm::ir::IRVisitor*) const+0x3f6) [0x7f402183d076]
I dug into the problem and found that it is the then branch of this line that raises the error: if ((0 <= lstm_scan.idx)), in the lowered code below:
produce lstm_scan {
  // attr [iter_var(blockIdx.x, Range(min=0, extent=24), blockIdx.x)] thread_extent = 24
  // attr [Wh2h.local] storage_scope = "local"
  allocate Wh2h.local[float32 * 4 * 1 * 72]
  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 1 * 1 * 24]
  // attr [s_h2h] storage_scope = "local"
  allocate s_h2h[float32 * 1 * 1 * 4 * 1]
  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 1 * 1 * 576]
  // attr [s_h2h.rf] storage_scope = "local"
  allocate s_h2h.rf[float32 * 1 * 1 * 1 * 1 * 1]
  // attr [reduce_temp0] storage_scope = "local"
  allocate reduce_temp0[float32 * 1]
  // attr [next_c] storage_scope = "local"
  allocate next_c[float32 * 1 * 1 * 1]
  // attr [next_h] storage_scope = "local"
  allocate next_h[float32 * 1 * 1 * 1]
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=24), threadIdx.x)] thread_extent = 24
  produce Wh2h.local {
    unrolled (ax0, 0, 4) {
      unrolled (ax2, 0, 72) {
        Wh2h.local[((ax0*72) + ax2)] = Wh2h[((((((blockIdx.x*192) + threadIdx.y) + (threadIdx.x*8)) + (ax0*4608))*72) + ax2)]
      }
    }
  }
  lstm_scan.v0[((blockIdx.x*24) + threadIdx.x)] = 0.000000f
  lstm_scan.v1[((blockIdx.x*24) + threadIdx.x)] = 0.000000f
  for (lstm_scan.idx, 0, (num_step + -1)) {
    produce placeholder.shared {
      unrolled (ax2, 0, 24) {
        placeholder.shared[ax2] = lstm_scan.v1[(((blockIdx.x + (lstm_scan.idx*24))*24) + ax2)]
      }
    }
    if ((0 <= lstm_scan.idx)) {
      produce s_h2h {
        for (ax2, 0, 4) {
          produce s_h2h.rf {
            produce placeholder.shared {
              unrolled (ax2.outer, 0, 3) {
                placeholder.shared[(((threadIdx.y*24) + threadIdx.x) + (ax2.outer*192))] = lstm_scan.v0[((((threadIdx.y*24) + threadIdx.x) + (lstm_scan.idx*576)) + (ax2.outer*192))]
              }
            }
            s_h2h.rf[0] = 0.000000f
            unrolled (ki2h.inner, 0, 72) {
              s_h2h.rf[0] = (s_h2h.rf[0] + (placeholder.shared[((threadIdx.y*72) + ki2h.inner)]*Wh2h.local[((ax2*72) + ki2h.inner)]))
            }
          }
          // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0.000000f])] reduce_scope = reinterpret((uint64)0)
          tvm_thread_allreduce((uint32)1, s_h2h.rf[0], (uint1)1, reduce_temp0, threadIdx.y)
          s_h2h[ax2] = reduce_temp0[0]
        }
      }
      produce next_c {
        if ((threadIdx.y == 0)) {
          next_c[0] = ((sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 3456)] + s_h2h[2]))*placeholder.shared[threadIdx.x]) + (sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 2304)] + s_h2h[0]))*tanh((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 2880)] + s_h2h[1]))))
        }
      }
      produce next_h {
        if ((threadIdx.y == 0)) {
          next_h[0] = (sigmoid((Xi2h[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*2304)) + 4032)] + s_h2h[3]))*tanh(next_c[0]))
        }
      }
      if ((threadIdx.y == 0)) {
        lstm_scan.v0[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*576)) + 576)] = next_h[0]
        lstm_scan.v1[((((blockIdx.x*24) + threadIdx.x) + (lstm_scan.idx*576)) + 576)] = next_c[0]
      }
    }
  }
}
I think this conditional should be eliminated before ThreadSyncPlanner performs the check at this line.
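For what it's worth, here is my rough mental model of the failing check as a Python sketch; the class and method names are mine, and only the error message matches storage_sync.cc:

class SyncPlannerModel:
    """Toy model: track how deep the visitor is inside IfThenElse nodes."""
    def __init__(self):
        self.condition_counter = 0

    def enter_condition(self):
        self.condition_counter += 1

    def exit_condition(self):
        self.condition_counter -= 1

    def plan_sync(self):
        # Mirrors: Check failed: condition_counter() == 0 (1 vs. 0)
        if self.condition_counter != 0:
            raise RuntimeError("Cannot insert syncs inside condition")

planner = SyncPlannerModel()
planner.enter_condition()  # visiting if ((0 <= lstm_scan.idx))
planner.plan_sync()        # shared accesses need a barrier here -> raises

So even though this particular guard cannot actually diverge, the planner has no way to know that and bails out; hence the suggestion to simplify it away first.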
Possible Solution
I am not very familiar with the passes in TVM, so in which pass could we eliminate such conditionals? I think it would be best scheduled before the ThreadSync pass.
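To illustrate the kind of pre-pass I have in mind, here is a toy sketch over an invented AST; the classes and the helper are made up for illustration and do not match TVM's node types or mutator API:

from dataclasses import dataclass

@dataclass
class LE:
    """Condition of the form const <= var."""
    const: int
    var: str

@dataclass
class IfThenElse:
    cond: LE
    then_branch: list
    else_branch: list

def eliminate_trivial_ifs(stmts, var_mins):
    """Drop IfThenElse nodes whose condition is provably true, given a map
    from loop-variable name to its proven minimum value."""
    out = []
    for s in stmts:
        if isinstance(s, IfThenElse):
            lo = var_mins.get(s.cond.var, float("-inf"))
            if s.cond.const <= lo:  # condition certainly true: splice body in
                out.extend(eliminate_trivial_ifs(s.then_branch, var_mins))
                continue
        out.append(s)
    return out

# The guard around the LSTM body: if (0 <= lstm_scan.idx) { body }
body = ["body with thread syncs"]
guarded = [IfThenElse(LE(0, "lstm_scan.idx"), body, [])]
print(eliminate_trivial_ifs(guarded, {"lstm_scan.idx": 0}))  # -> body only

If something like this (or an existing simplification strengthened to use loop-variable ranges) ran right before ThreadSync, the guard in the dump above would disappear and the sync placement should go through.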