diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 2c854e1a269..92f0468d4f8 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1972,20 +1972,6 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { debug() << "Inserting block sync before position " << position << std::endl; } - { - // TODO: This is a temporary HACK to work around - // https://github.com/NVIDIA/Fuser/issues/2000 - // Instead, we should only insert these wait statements when we detect - // that there are corresponding unsynced async operations involving the - // buffers in question. We should also update dispatch(Expr*) to check - // not only hasBlockSync but also check if there already exist AsyncWait - // expressions in the interval (or in some cases before the interval but - // after the last write?). - auto new_async_wait = - IrBuilder::create(AsyncOpType::CpAsync); - registerInsertBefore(expr, new_async_wait); - } - auto new_sync = IrBuilder::create(); inserted_syncs_.insert(new_sync); registerInsertBefore(expr, new_sync); diff --git a/csrc/device_lower/pass/double_buffer.cpp b/csrc/device_lower/pass/double_buffer.cpp index d8d2100fe46..18828272201 100644 --- a/csrc/device_lower/pass/double_buffer.cpp +++ b/csrc/device_lower/pass/double_buffer.cpp @@ -573,6 +573,16 @@ class DoubleBufferInserter : private kir::ExprMutator { // is more conceptual at the moment, aka low priority. if (has_cpasync) { insertCpAsyncCommitWaitInMainLoop(main_loop, loads); + + // The main loop will generate some async loads from invalid regions. + // These populate the current cp.async group and they fill the smem with + // zero. Subsequent code might assume an empty cp.async group (for example + // an unparallelized batch matmul), or might re-use memory (WAW + // hazard, see https://github.com/NVIDIA/Fuser/issues/2000). For safety, + // we drain the group after the loops by waiting on these transfers. 
+ auto cp_async_wait_all = IrBuilder::create<kir::AsyncWait>(AsyncOpType::CpAsync, 0); + registerInsertAfter(double_buffer_loop, cp_async_wait_all); } if (requireEpilogue(loads)) { diff --git a/tests/cpp/test_loop_rotation.cpp b/tests/cpp/test_loop_rotation.cpp index 40a8cb66feb..739fab1b0a6 100644 --- a/tests/cpp/test_loop_rotation.cpp +++ b/tests/cpp/test_loop_rotation.cpp @@ -658,6 +658,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1[0LL] = T4[(3LL * ((1LL + i8) % 5LL))]; } + asm volatile("cp.async.wait_all;\n"); } )"; assertCUDAKernel(&fusion, expected_kernel);