Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ namespace tvm {
namespace runtime {
namespace relax_vm {

namespace {

struct CUDAGraphCaptureKey {
// The unique index of the capture function within the module
int64_t index;
Expand Down Expand Up @@ -67,6 +69,18 @@ struct CUDAGraphCaptureKeyEqual {

/*! \brief The captured state of a CUDA graph */
struct CUDAGraphCapturedState {
CUDAGraphCapturedState() {}

CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }

CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
std::swap(states, other.states);
std::swap(exec, other.exec);
return *this;
}

~CUDAGraphCapturedState() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
Expand All @@ -82,6 +96,43 @@ struct CUDAGraphCapturedState {
cudaGraphExec_t exec = nullptr;
};

class ScopedCUDAStream {
public:
ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
ScopedCUDAStream(const ScopedCUDAStream&) = delete;
ScopedCUDAStream(ScopedCUDAStream&&) = delete;
ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;

operator cudaStream_t() const { return stream_; }

private:
cudaStream_t stream_;
};

class CUDACaptureStream {
public:
explicit CUDACaptureStream(cudaGraph_t* graph)
: prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;

CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
}
~CUDACaptureStream() {
cudaStreamEndCapture(capture_stream_, output_graph_);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be enclosed with CUDA_CALL to check return code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it with CUDA_CALL to verify, and using the CUDA_CALL would re-introduce the same bug that this change is intended to fix. If the stack is unwinding due to a thrown exception, then throwing another exception would result in std::terminate being called. To avoid this, destructors shouldn't throw exceptions (stackoverflow link).

There are some ways to use std::uncaught_exceptions to determine whether an exception is being unwound, and to conditionally throw an exception if it isn't already the case. However, those tend to be pretty context-dependent, and probably aren't worth using in this case.

CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
}

private:
cudaStream_t prev_default_stream_;
ScopedCUDAStream capture_stream_;

cudaGraph_t* output_graph_;
};

} // namespace

/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
public:
Expand All @@ -107,10 +158,6 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
return states;
}

cudaStream_t capture_stream;
CUDA_CALL(cudaStreamCreate(&capture_stream));
CUDAGraphCapturedState entry;

// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
int nargs = static_cast<int>(tuple_args.size());
Expand All @@ -130,21 +177,23 @@ class CUDAGraphExtensionNode : public VMExtensionNode {

// Run the graph in capture mode
cudaGraph_t graph;
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
cudaStreamCaptureModeGlobal));

vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
&capture_func_rv);
entry.states = capture_func_rv;
CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
{
CUDACaptureStream capture_stream(&graph);
vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
&capture_func_rv);
}

capture_cache_[entry_key] = entry;
CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaStreamDestroy(capture_stream));
CUDAGraphCapturedState entry;
entry.states = capture_func_rv;
CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaGraphDestroy(graph));
return entry.states;

ObjectRef states = entry.states;

capture_cache_[entry_key] = std::move(entry);

return states;
}

/*!
Expand Down
77 changes: 75 additions & 2 deletions tests/python/relax/test_vm_cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# under the License.

import tvm
from tvm.script import tir as T, relax as R, ir as I
from tvm import relax
import tvm.testing

from tvm import relax
from tvm.script import tir as T, relax as R, ir as I

import numpy as np
import pytest


# fmt: off
Expand Down Expand Up @@ -104,5 +107,75 @@ def test_vm_run():
tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)


@tvm.testing.requires_cudagraph
def test_capture_error_is_recoverable():
    """Function calls while capturing cudagraph may throw exceptions

    Calls to PackedFuncs may occur within a captured cudaGraph.  If a
    call to that PackedFunc raises an exception while capturing the
    cudaGraph, throwing exception should cleanly unwind the stack, and
    the exception may be caught in the calling scope.

    This is a regression test.  In previous implementations, an
    exception thrown while capturing a cudaGraph would skip the call
    to `cudaStreamEndCapture`, causing additional exceptions to be
    thrown while freeing memory in TVM destructors.  Since C++ does
    not support stack unwinding from multiple simultaneous exceptions,
    this would result in immediate `std::terminate`, making it
    difficult to debug the original error.

    """

    target = tvm.target.Target("cuda")
    dev = tvm.cuda()

    @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
    def invalid_impl_for_cudagraph(arg_tensor):
        # Memory allocation/deallocation may not be performed while
        # capturing a cudaGraph.  This passes the warm-up run
        # performed by "vm.builtin.cuda_graph.run_or_capture", but
        # throws an exception when the cudaGraph is being captured.
        _dummy_workspace = tvm.nd.empty([16], "float16", dev)
        return arg_tensor

    # Identity-shaped module whose middle step calls the packed func
    # above, so that the faulting call ends up inside the captured region.
    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor([16], "float16")):
            B = R.add(A, A)
            C = R.call_pure_packed(
                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
                B,
                sinfo_args=R.Tensor([16], "float16"),
            )
            D = R.add(C, C)
            return D

    # Lower with cudaGraph rewriting enabled; RewriteCUDAGraph runs last so
    # the legalized/planned module is what gets partitioned into a capture.
    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
        Module = tvm.ir.transform.Sequential(
            [
                tvm.relax.transform.LegalizeOps(),
                tvm.tir.transform.DefaultGPUSchedule(),
                tvm.relax.transform.RemovePurityChecking(),
                tvm.relax.transform.CallTIRRewrite(),
                tvm.relax.transform.StaticPlanBlockMemory(),
                tvm.relax.transform.RewriteCUDAGraph(),
            ]
        )(Module)

    # Sanity check on the test itself: if RewriteCUDAGraph produced no
    # capture, the faulting call would never run under capture and the
    # test would pass vacuously.
    assert "cuda_graph_alloc" in Module, (
        "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
        "to have been captured by RewriteCUDAGraph."
    )

    built = tvm.relax.build(Module, target=target)
    vm = tvm.relax.VirtualMachine(built, dev)

    arg = tvm.nd.array(np.arange(16).astype("float16"), dev)

    # First invocation performs the warm-up run and then the cudaGraph
    # capture; the packed func raises during capture, and the error must
    # surface as a catchable TVMError instead of aborting the process.
    with pytest.raises(tvm.TVMError):
        vm["main"](arg)


if __name__ == "__main__":
tvm.testing.main()