From 3d4f5ae2c9e9db0399f3d9cdf9fbd10cfa8164ac Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 24 Jun 2024 09:00:20 -0500
Subject: [PATCH 1/3] [CudaGraph] Handle exceptions thrown while capturing cuda graph

Prior to this commit, an exception thrown during the capture of a
CUDA graph would result in `std::terminate` being called.  This
commit updates the implementation of
`"vm.builtin.cuda_graph.run_or_capture"` such that a thrown exception
can be recovered from, and does not cause any changes to the state of
TVM's CUDA graph cache.

- The call to `cudaStreamDestroy` was previously skipped when an
  exception was thrown; it is now performed by the RAII-style
  destructor of a `ScopedCUDAStream` class.

- The call to `cudaStreamEndCapture` was previously skipped; the end
  of CUDA graph capture is now performed as part of the RAII-style
  destructor of the `CUDACaptureStream` class.

- Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was
  previously skipped; the previous stream is now restored as part of
  the RAII-style destructor of the `CUDACaptureStream` class.

- Previously, an error raised from `cudaGraphInstantiate` would leave
  the `capture_cache_` in an ill-formed state.  Now, the
  `capture_cache_` is only updated after a valid
  `CUDAGraphCapturedState` has been fully constructed.
---
 .../relax_vm/cuda/cuda_graph_builtin.cc  | 77 +++++++++++++++----
 tests/python/relax/test_vm_cuda_graph.py | 77 ++++++++++++++++++-
 2 files changed, 137 insertions(+), 17 deletions(-)

diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index dea497e4a9d7..4782291d1ad2 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -32,6 +32,8 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+namespace {
+
 struct CUDAGraphCaptureKey {
   // The unique index of the capture function within the module
   int64_t index;
@@ -67,6 +69,18 @@ struct CUDAGraphCaptureKeyEqual {
 /*! \brief The captured state of a CUDA graph */
 struct CUDAGraphCapturedState {
+  CUDAGraphCapturedState() {}
+
+  CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }
+
+  CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
+    std::swap(states, other.states);
+    std::swap(exec, other.exec);
+    return *this;
+  }
+
   ~CUDAGraphCapturedState() {
     if (exec) {
       CUDA_CALL(cudaGraphExecDestroy(exec));
     }
@@ -82,6 +96,43 @@ struct CUDAGraphCapturedState {
   cudaGraphExec_t exec = nullptr;
 };
 
+class ScopedCUDAStream {
+ public:
+  ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
+  ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
+  ScopedCUDAStream(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream(ScopedCUDAStream&&) = delete;
+  ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;
+
+  operator cudaStream_t() const { return stream_; }
+
+ private:
+  cudaStream_t stream_;
+};
+
+class CUDACaptureStream {
+ public:
+  CUDACaptureStream(cudaGraph_t* graph)
+      : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
+    CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
+
+    CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+  }
+  ~CUDACaptureStream() {
+    cudaStreamEndCapture(capture_stream_, output_graph_);
+    CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
+  }
+
+ private:
+  cudaStream_t prev_default_stream_;
+  ScopedCUDAStream capture_stream_;
+
+  cudaGraph_t* output_graph_;
+};
+
+}  // namespace
+
 /*! \brief The VM extension of CUDA graph. */
 class CUDAGraphExtensionNode : public VMExtensionNode {
  public:
@@ -107,10 +158,6 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
       return states;
     }
 
-    cudaStream_t capture_stream;
-    CUDA_CALL(cudaStreamCreate(&capture_stream));
-    CUDAGraphCapturedState entry;
-
     // Set up arguments for the graph execution
     Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
     int nargs = static_cast<int>(tuple_args.size());
@@ -130,20 +177,20 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
 
     // Run the graph in capture mode
     cudaGraph_t graph;
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
-    CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
-                                     cudaStreamCaptureModeGlobal));
-    vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
-                            &capture_func_rv);
-    entry.states = capture_func_rv;
-    CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
+    {
+      CUDACaptureStream capture_stream(&graph);
+      vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
+                              &capture_func_rv);
+    }
 
 
-    capture_cache_[entry_key] = entry;
-    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
-    CUDA_CALL(cudaStreamDestroy(capture_stream));
+    CUDAGraphCapturedState entry;
+    entry.states = capture_func_rv;
+    CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaGraphDestroy(graph));
+
+    capture_cache_[entry_key] = std::move(entry);
+
     return entry.states;
   }
 
diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py
index 6a20b6b1f892..49ebcc1d05b2 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -16,10 +16,13 @@
 # under the License.
 
 import tvm
-from tvm.script import tir as T, relax as R, ir as I
-from tvm import relax
 import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+
 import numpy as np
+import pytest
 
 
 # fmt: off
@@ -104,5 +107,75 @@ def test_vm_run():
     tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.requires_cudagraph
+def test_capture_error_is_recoverable():
+    """Function calls while capturing a cudaGraph may throw exceptions
+
+    Calls to PackedFuncs may occur within a captured cudaGraph.  If a
+    call to that PackedFunc raises an exception while capturing the
+    cudaGraph, throwing an exception should cleanly unwind the stack,
+    and the exception may be caught in the calling scope.
+
+    This is a regression test.  In previous implementations, an
+    exception thrown while capturing a cudaGraph would skip the call
+    to `cudaStreamEndCapture`, causing additional exceptions to be
+    thrown while freeing memory in TVM destructors.  Since C++ does
+    not support stack unwinding from multiple simultaneous exceptions,
+    this would result in immediate `std::terminate`, making it
+    difficult to debug the original error.
+
+    """
+
+    target = tvm.target.Target("cuda")
+    dev = tvm.cuda()
+
+    @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
+    def invalid_impl_for_cudagraph(arg_tensor):
+        # Memory allocation/deallocation may not be performed while
+        # capturing a cudaGraph.  This passes the warm-up run
+        # performed by "vm.builtin.cuda_graph.run_or_capture", but
+        # throws an exception when the cudaGraph is being captured.
+        _dummy_workspace = tvm.nd.empty([16], "float16", dev)
+        return arg_tensor
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.add(A, A)
+            C = R.call_pure_packed(
+                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
+                B,
+                sinfo_args=R.Tensor([16], "float16"),
+            )
+            D = R.add(C, C)
+            return D
+
+    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        Module = tvm.ir.transform.Sequential(
+            [
+                tvm.relax.transform.LegalizeOps(),
+                tvm.tir.transform.DefaultGPUSchedule(),
+                tvm.relax.transform.RemovePurityChecking(),
+                tvm.relax.transform.CallTIRRewrite(),
+                tvm.relax.transform.StaticPlanBlockMemory(),
+                tvm.relax.transform.RewriteCUDAGraph(),
+            ]
+        )(Module)
+
+    assert "cuda_graph_alloc" in Module, (
+        "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
+        "to have been captured by RewriteCUDAGraph."
+    )
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    arg = tvm.nd.array(np.arange(16).astype("float16"), dev)
+
+    with pytest.raises(tvm.TVMError):
+        vm["main"](arg)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 96f841b5f4010f75f8180c223b686b78797933d6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 24 Jun 2024 11:02:42 -0500
Subject: [PATCH 2/3] lint fix

---
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 4782291d1ad2..ae7d3a7067e3 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -113,7 +113,7 @@ class ScopedCUDAStream {
 
 class CUDACaptureStream {
  public:
-  CUDACaptureStream(cudaGraph_t* graph)
+  explicit CUDACaptureStream(cudaGraph_t* graph)
       : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
     CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
 

From 03ce9d1255078675345e95a2ced691ab095c235d Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 24 Jun 2024 19:45:46 -0500
Subject: [PATCH 3/3] Unit test fix

---
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index ae7d3a7067e3..e8901c0f19fa 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -189,9 +189,11 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
     CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaGraphDestroy(graph));
 
+    ObjectRef states = entry.states;
+
     capture_cache_[entry_key] = std::move(entry);
 
-    return entry.states;
+    return states;
   }
 
   /*!
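
The patches above rely on a standard C++ guarantee: when the exception propagates out of `vm->InvokeClosurePacked`, the destructors of the local `CUDACaptureStream` and `ScopedCUDAStream` objects run during stack unwinding, so ending the capture, restoring `CUDAThreadEntry::ThreadLocal()->stream`, and destroying the capture stream are no longer skipped. The stand-alone sketch below illustrates that pattern in isolation; `BeginCapture`, `EndCapture`, and `CaptureGuard` are hypothetical stand-ins for illustration only and are not part of these patches or of the CUDA API.

// Minimal sketch of the RAII guard pattern, with hypothetical
// BeginCapture/EndCapture stand-ins in place of the real CUDA calls.
#include <iostream>
#include <stdexcept>

void BeginCapture() { std::cout << "begin capture\n"; }
void EndCapture() { std::cout << "end capture\n"; }

class CaptureGuard {
 public:
  CaptureGuard() { BeginCapture(); }
  // The destructor runs during stack unwinding, so EndCapture() executes
  // even when the code inside the guarded scope throws.
  ~CaptureGuard() { EndCapture(); }
  CaptureGuard(const CaptureGuard&) = delete;
  CaptureGuard& operator=(const CaptureGuard&) = delete;
};

int main() {
  try {
    CaptureGuard guard;
    throw std::runtime_error("error raised while capturing");
  } catch (const std::exception& err) {
    // "end capture" has already been printed by the time we get here.
    std::cout << "recovered: " << err.what() << "\n";
  }
  return 0;
}

Running the sketch prints "begin capture", then "end capture", then the recovery message, mirroring how the patched builtin closes the capture before the exception reaches the `pytest.raises` block in the new regression test.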