diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index dea497e4a9d7..e8901c0f19fa 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -32,6 +32,8 @@ namespace tvm { namespace runtime { namespace relax_vm { +namespace { + struct CUDAGraphCaptureKey { // The unique index of the capture function within the module int64_t index; @@ -67,6 +69,18 @@ struct CUDAGraphCaptureKeyEqual { /*! \brief The captured state of a CUDA graph */ struct CUDAGraphCapturedState { + CUDAGraphCapturedState() {} + + CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete; + CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); } + + CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete; + CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) { + std::swap(states, other.states); + std::swap(exec, other.exec); + return *this; + } + ~CUDAGraphCapturedState() { if (exec) { CUDA_CALL(cudaGraphExecDestroy(exec)); @@ -82,6 +96,43 @@ struct CUDAGraphCapturedState { cudaGraphExec_t exec = nullptr; }; +class ScopedCUDAStream { + public: + ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); } + ~ScopedCUDAStream() { cudaStreamDestroy(stream_); } + ScopedCUDAStream(const ScopedCUDAStream&) = delete; + ScopedCUDAStream(ScopedCUDAStream&&) = delete; + ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete; + ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete; + + operator cudaStream_t() const { return stream_; } + + private: + cudaStream_t stream_; +}; + +class CUDACaptureStream { + public: + explicit CUDACaptureStream(cudaGraph_t* graph) + : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) { + CUDAThreadEntry::ThreadLocal()->stream = capture_stream_; + + CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); + } + 
~CUDACaptureStream() { + cudaStreamEndCapture(capture_stream_, output_graph_); + CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_; + } + + private: + cudaStream_t prev_default_stream_; + ScopedCUDAStream capture_stream_; + + cudaGraph_t* output_graph_; +}; + +} // namespace + /*! \brief The VM extension of CUDA graph. */ class CUDAGraphExtensionNode : public VMExtensionNode { public: @@ -107,10 +158,6 @@ class CUDAGraphExtensionNode : public VMExtensionNode { return states; } - cudaStream_t capture_stream; - CUDA_CALL(cudaStreamCreate(&capture_stream)); - CUDAGraphCapturedState entry; - // Set up arguments for the graph execution Array tuple_args = Downcast>(args); int nargs = static_cast(tuple_args.size()); @@ -130,21 +177,23 @@ class CUDAGraphExtensionNode : public VMExtensionNode { // Run the graph in capture mode cudaGraph_t graph; - std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream); - CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream, - cudaStreamCaptureModeGlobal)); - vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs), - &capture_func_rv); - entry.states = capture_func_rv; - CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph)); - std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream); + { + CUDACaptureStream capture_stream(&graph); + vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs), + &capture_func_rv); + } - capture_cache_[entry_key] = entry; - CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0)); - CUDA_CALL(cudaStreamDestroy(capture_stream)); + CUDAGraphCapturedState entry; + entry.states = capture_func_rv; + CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0)); CUDA_CALL(cudaGraphDestroy(graph)); - return entry.states; + + ObjectRef states = entry.states; + + capture_cache_[entry_key] = std::move(entry); + + return states; } /*! 
diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 6a20b6b1f892..49ebcc1d05b2 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -16,10 +16,13 @@ # under the License. import tvm -from tvm.script import tir as T, relax as R, ir as I -from tvm import relax import tvm.testing + +from tvm import relax +from tvm.script import tir as T, relax as R, ir as I + import numpy as np +import pytest # fmt: off @@ -104,5 +107,75 @@ def test_vm_run(): tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5) +@tvm.testing.requires_cudagraph +def test_capture_error_is_recoverable(): + """Function calls while capturing a cudaGraph may throw exceptions + + Calls to PackedFuncs may occur within a captured cudaGraph. If a + call to that PackedFunc raises an exception while capturing the + cudaGraph, throwing the exception should cleanly unwind the stack, and + the exception may be caught in the calling scope. + + This is a regression test. In previous implementations, an + exception thrown while capturing a cudaGraph would skip the call + to `cudaStreamEndCapture`, causing additional exceptions to be + thrown while freeing memory in TVM destructors. Since C++ does + not support stack unwinding from multiple simultaneous exceptions, + this would result in immediate `std::terminate`, making it + difficult to debug the original error. + + """ + + target = tvm.target.Target("cuda") + dev = tvm.cuda() + + @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True) + def invalid_impl_for_cudagraph(arg_tensor): + # Memory allocation/deallocation may not be performed while + # capturing a cudaGraph. This succeeds during the warm-up run + # performed by "vm.builtin.cuda_graph.run_or_capture", but + # throws an exception when the cudaGraph is being captured.
+ _dummy_workspace = tvm.nd.empty([16], "float16", dev) + return arg_tensor + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.add(A, A) + C = R.call_pure_packed( + "test_vm_cuda_graph.invalid_impl_for_cudagraph", + B, + sinfo_args=R.Tensor([16], "float16"), + ) + D = R.add(C, C) + return D + + with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}): + Module = tvm.ir.transform.Sequential( + [ + tvm.relax.transform.LegalizeOps(), + tvm.tir.transform.DefaultGPUSchedule(), + tvm.relax.transform.RemovePurityChecking(), + tvm.relax.transform.CallTIRRewrite(), + tvm.relax.transform.StaticPlanBlockMemory(), + tvm.relax.transform.RewriteCUDAGraph(), + ] + )(Module) + + assert "cuda_graph_alloc" in Module, ( + "Validity of unit test requires the call to `invalid_impl_for_cudagraph` " + "to have been captured by RewriteCUDAGraph." + ) + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, dev) + + arg = tvm.nd.array(np.arange(16).astype("float16"), dev) + + with pytest.raises(tvm.TVMError): + vm["main"](arg) + + if __name__ == "__main__": tvm.testing.main()