Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions include/tvm/runtime/relax_vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "../memory/memory_manager.h"
Expand Down Expand Up @@ -97,6 +98,27 @@ class VMClosure : public Closure {
static PackedFunc BindLastArgs(PackedFunc func, std::vector<TVMRetValue> last_args);
};

/*!
 * \brief Represent a VM extension.
 * A VM extension allows the user to extend the VM with target specific functionalities.
 * The VM holds the reference of the extensions to ensure the extensions have the same lifetime
 * as the VM.
 *
 * This is the base class for all VM extensions and should not be used directly.
 */
class VMExtensionNode : public Object {
 protected:
  // kDynamic: the concrete type index is allocated at runtime on first use,
  // rather than being fixed at compile time.
  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
  static constexpr const char* _type_key = "runtime.VMExtension";
  TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object);
};

/*! \brief Managed reference to VM extension. */
class VMExtension : public ObjectRef {
 public:
  // Standard ObjectRef boilerplate: constructors, operator->, ContainerType alias.
  TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode);
};

/*!
* \brief The virtual machine.
*
Expand Down Expand Up @@ -156,6 +178,25 @@ class VirtualMachine : public runtime::ModuleNode {
* \param instrument The instrument function.
*/
virtual void SetInstrument(PackedFunc instrument) = 0;

/*!
 * \brief Get or create a VM extension. Once created, the extension will be stored in the VM
 * and held until the VM is destructed.
 *
 * \tparam T The type of the extension
 * \return The extension instance
 */
template <typename T, typename = std::enable_if_t<std::is_base_of<VMExtension, T>::value>>
T GetOrCreateExtension() {
  using ContainerType = typename T::ContainerType;
  const uint32_t type_index = ContainerType::RuntimeTypeIndex();
  auto it = extensions.find(type_index);
  if (it == extensions.end()) {
    // First request for this extension type: create it and cache it,
    // keyed by the runtime type index of its container type.
    it = extensions.emplace(type_index, T::Create()).first;
  }
  return Downcast<T>(it->second);
}

/*!
* \brief Create a specific instance of VM.
* \return Created VM
Expand Down Expand Up @@ -183,6 +224,9 @@ class VirtualMachine : public runtime::ModuleNode {
std::vector<Allocator*> allocators;
/*! \brief Runtime physical device list. */
std::vector<Device> devices;
/*! \brief The VM extensions. Mapping from the type index of the extension to the extension
* instance. */
std::unordered_map<uint32_t, VMExtension> extensions;
};

} // namespace relax_vm
Expand Down
60 changes: 37 additions & 23 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual {
}
};

/*! \brief The cache states of a CUDA graph. */
class CUDAGraphCache : public Object {
public:
struct CaptureResult {
~CaptureResult() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
}
/*! \brief The captured state of a CUDA graph */
struct CUDAGraphCapturedState {
~CUDAGraphCapturedState() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
}
/*!
* \brief Tuple of intermediate tensors in the capture func that will be used outside the
* capture func
*/
ObjectRef states;
/*! \brief The instantiated cuda graph */
cudaGraphExec_t exec = nullptr;
};
}

static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
/*!
* \brief Tuple of intermediate tensors in the capture func that will be used outside the
* capture func
*/
ObjectRef states;
/*! \brief The instantiated cuda graph */
cudaGraphExec_t exec = nullptr;
};

/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
public:
TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode);

/*!
* \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
Expand All @@ -107,7 +109,7 @@ class CUDAGraphCache : public Object {

cudaStream_t capture_stream;
CUDA_CALL(cudaStreamCreate(&capture_stream));
CUDAGraphCache::CaptureResult entry;
CUDAGraphCapturedState entry;

// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
Expand Down Expand Up @@ -164,12 +166,14 @@ class CUDAGraphCache : public Object {
return alloc_result;
}

static constexpr const char* _type_key = "relax_vm.CUDAGraphExtension";

private:
/*!
* \brief The cache of captured cuda graphs. The key is a unique index for the capture function.
* The value is the result of the capture.
*/
std::unordered_map<CUDAGraphCaptureKey, CaptureResult, CUDAGraphCaptureKeyHash,
std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState, CUDAGraphCaptureKeyHash,
CUDAGraphCaptureKeyEqual>
capture_cache_;
/*!
Expand All @@ -179,29 +183,39 @@ class CUDAGraphCache : public Object {
std::unordered_map<int64_t, ObjectRef> alloc_cache_;
};

/*! Managed reference to CUDAGraphExtensionNode */
class CUDAGraphExtension : public VMExtension {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode);
static CUDAGraphExtension Create() {
auto data_ = make_object<CUDAGraphExtensionNode>();
return CUDAGraphExtension(std::move(data_));
}
};

TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // Packed signature: (vm, capture_func, func_args, entry_index[, shape_expr]).
      ICHECK(args.size() == 5 || args.size() == 4);
      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
      auto cuda_graph_ext = vm->GetOrCreateExtension<CUDAGraphExtension>();
      ObjectRef capture_fn = args[1];
      ObjectRef fn_args = args[2];
      int64_t entry_idx = args[3];
      // The trailing shape tuple is optional; absent means no symbolic shapes.
      Optional<ShapeTuple> symbolic_shape = NullOpt;
      if (args.size() == 5) {
        symbolic_shape = args[4].AsObjectRef<ShapeTuple>();
      }
      *rv = cuda_graph_ext->RunOrCapture(vm, capture_fn, fn_args, entry_idx, symbolic_shape);
    });

TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // Packed signature: (vm, alloc_func, entry_index).
      ICHECK_EQ(args.size(), 3);
      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
      auto cuda_graph_ext = vm->GetOrCreateExtension<CUDAGraphExtension>();
      ObjectRef alloc_fn = args[1];
      int64_t entry_idx = args[2];
      *rv = cuda_graph_ext->GetCachedAllocation(vm, alloc_fn, entry_idx);
    });

} // namespace relax_vm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import pytest


# pylint: disable=missing-docstring,no-self-argument,invalid-name
Expand Down Expand Up @@ -64,6 +65,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")):


# pylint: enable=missing-docstring,no-self-argument,invalid-name
@pytest.mark.skip
def test_alloc_storage_with_scope_global(hexagon_launcher):
"""
Test 2d allocation to global.vtcm memory scope in a Relax Function
Expand Down