diff --git a/CMakeLists.txt b/CMakeLists.txt
index 451b6a7ee2c2..16968ce41f70 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -35,6 +35,7 @@ tvm_option(USE_THREADS "Build with thread support" ON)
 tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
 tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
 tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
+tvm_option(USE_GRAPH_RUNTIME_CUDA_GRAPH "Build with tiny graph runtime with CUDA Graph for GPUs" OFF)
 tvm_option(USE_PROFILER "Build profiler for the VM and graph runtime" ON)
 tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF)
 tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 65859566a664..60c718c97bc1 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -99,6 +99,9 @@ set(USE_STACKVM_RUNTIME OFF)
 
 # Whether enable tiny embedded graph runtime.
 set(USE_GRAPH_RUNTIME ON)
 
+# Whether enable tiny graph runtime with CUDA Graph
+set(USE_GRAPH_RUNTIME_CUDA_GRAPH OFF)
+
 # Whether to enable the profiler for the graph runtime and vm
 set(USE_PROFILER ON)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 0ec2f1466bd1..262a4e6e7123 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -65,6 +65,17 @@ if(USE_CUDA)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
   endif(USE_THRUST)
 
+  if(USE_GRAPH_RUNTIME_CUDA_GRAPH)
+    if(NOT USE_GRAPH_RUNTIME)
+      message(FATAL_ERROR "CUDA Graph is only supported by graph runtime, please set USE_GRAPH_RUNTIME=ON")
+    endif()
+    if(CUDAToolkit_VERSION_MAJOR LESS "10")
+      message(FATAL_ERROR "CUDA Graph requires CUDA 10 or above, got=" ${CUDAToolkit_VERSION})
+    endif()
+    message(STATUS "Build with Graph runtime with CUDA Graph support...")
+    file(GLOB RUNTIME_CUDA_GRAPH_SRCS src/runtime/graph/cuda_graph/*.cc)
+    list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_GRAPH_SRCS})
+  endif()
 else(USE_CUDA)
   list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
 endif(USE_CUDA)
diff --git a/python/tvm/contrib/cuda_graph/__init__.py b/python/tvm/contrib/cuda_graph/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/python/tvm/contrib/cuda_graph/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py
new file mode 100644
index 000000000000..45ec89d37b3d
--- /dev/null
+++ b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph runtime with CUDA Graph"""
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.contrib import graph_runtime
+
+
+def create(graph_json_str, libmod, ctx):
+    """Create a runtime executor module given a graph and module.
+
+    Parameters
+    ----------
+    graph_json_str : str
+        The graph to be deployed in json format output by json graph.
+        The graph can contain operator(tvm_op) that points to the name
+        of PackedFunc in the libmod.
+
+    libmod : tvm.runtime.Module
+        The module of the corresponding function
+
+    ctx : TVMContext
+        The context to deploy the module; only CUDA GPU contexts are supported
+
+    Returns
+    -------
+    graph_module : GraphModuleCudaGraph
+        CUDA graph runtime module that can be used to execute the graph.
+
+    Note
+    ----
+    See also :py:class:`tvm.contrib.cuda_graph.cuda_graph_runtime.GraphModuleCudaGraph`
+    for examples to directly construct a GraphModuleCudaGraph from an exported
+    relay compiled library.
+    """
+    assert isinstance(graph_json_str, string_types)
+    try:
+        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
+        if num_rpc_ctx == len(ctx):
+            fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cuda_graph.create")
+        else:
+            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create")
+    except ValueError:
+        raise ValueError(
+            "To enable CUDA graph support (experimental), please set "
+            "'set(USE_GRAPH_RUNTIME_CUDA_GRAPH ON)' in config.cmake and rebuild TVM"
+        )
+
+    return GraphModuleCudaGraph(fcreate(graph_json_str, libmod, *device_type_id))
+
+
+class GraphModuleCudaGraph(graph_runtime.GraphModule):
+    """CUDA graph runtime module.
+
+    This is a CUDA graph runtime wrapper over the TVM runtime.
+    Runtime interfaces are wrapped with CUDA graph functionalities.
+
+    Parameters
+    ----------
+    module : Module
+        The internal tvm module that holds the actual graph functions.
+    """
+
+    def __init__(self, module):
+        self._start_capture = module["start_capture"]
+        self._end_capture = module["end_capture"]
+        self._run_cuda_graph = module["run_cuda_graph"]
+        self._cuda_graph_captured = False
+        graph_runtime.GraphModule.__init__(self, module)
+
+    def capture_cuda_graph(self):
+        """Capture a CUDA graph for the tvm_op graph
+
+        This should be called before run_cuda_graph() to capture and
+        instantiate a CUDA graph instance.
+ """ + self._run() # call cuModuleLoadData before cudaStream API + self._start_capture() + self._run() + self._end_capture() + self._cuda_graph_captured = True + + def run_cuda_graph(self): + """Run the CUDA graph for tvm_op graph + + Run the captured CUDA graph instance instead of the + for-loop kernel launch of default graph runtime + """ + self._run_cuda_graph() + + def run(self, **input_dict): + """A run wrapper for graph capture / launch, user can just + change default graph runtime to cuda graph runtime, and + the first call will capture a cuda graph for future launch + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be feed to + """ + if input_dict: + self.set_input(**input_dict) + if not self._cuda_graph_captured: + self.capture_cuda_graph() + else: + self._run_cuda_graph() + + def debug_get_output(self, node, out): + """Run graph up to node and get the output to out + + Parameters + ---------- + node : int / str + The node index or name + + out : NDArray + The output array container + """ + raise NotImplementedError("Please use debugger.debug_runtime as graph_runtime instead.") diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 7e49f55e8d32..99844f799d7a 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -349,6 +349,18 @@ def have_tensorcore(compute_version=None, target=None): return False +def have_cudagraph(): + """Either CUDA Graph support is provided""" + try: + cuda_path = find_cuda_path() + cuda_ver = get_cuda_version(cuda_path) + if cuda_ver < 10.0: + return False + return True + except RuntimeError: + return False + + def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not diff --git a/python/tvm/testing.py b/python/tvm/testing.py index d65ab23677b5..1cb43b29c521 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -514,6 +514,25 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_cudagraph(*args): + """Mark a test as requiring the CUDA Graph Feature + + This also marks the test as requiring cuda + + Parameters + ---------- + f : function + Function to mark + """ + _requires_cudagraph = [ + pytest.mark.skipif( + not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment" + ), + *requires_cuda(), + ] + return _compose(args, _requires_cudagraph) + + def requires_opencl(*args): """Mark a test as requiring the OpenCL runtime. diff --git a/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc new file mode 100644 index 000000000000..ee5e50a3b9d4 --- /dev/null +++ b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_cuda_graph.cc
+ */
+
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "../graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Graph runtime with CUDA Graph support.
+ *
+ * This is an extension of the GraphRuntime class that uses CUDA graph
+ * launch instead of per-kernel CUDA launch. CUDA graph launch requires
+ * CUDA 10.0 or above. There are currently two ways to construct a CUDA
+ * graph: (1) use the CUDA stream capture API to record a series of
+ * operations issued to a CUDA stream and generate the graph automatically,
+ * or (2) build the graph manually with the CUDA graph API. This
+ * implementation uses stream capture.
+ */
+class GraphRuntimeCudaGraph : public GraphRuntime {
+ public:
+  /*!
+   * \brief Begin CUDA graph capture on the capture stream; the stream enters capture mode.
+   */
+  void StartCapture() {
+    const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx;
+
+    TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_);
+    TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_);
+
+    CUDA_CALL(cudaStreamBeginCapture(static_cast<cudaStream_t>(capture_stream_),
+                                     cudaStreamCaptureModeGlobal));
+  }
+
+  /*!
+   * \brief Launch the instantiated graph on the capture stream.
+   */
+  void RunCudaGraph() {
+    cudaStream_t cuStream = static_cast<cudaStream_t>(capture_stream_);
+    CUDA_CALL(cudaGraphLaunch(cuda_graph_exec_, cuStream));
+    CUDA_CALL(cudaStreamSynchronize(cuStream));
+  }
+
+  /*!
+   * \brief End CUDA graph capture on the capture stream; a graph is created and
+   * instantiated.
+   */
+  void EndCapture() {
+    cudaGraph_t graph;
+    CUDA_CALL(cudaStreamEndCapture(static_cast<cudaStream_t>(capture_stream_), &graph));
+
+    cudaGraphNode_t* nodes = NULL;
+    size_t numNodes = 0;
+    CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes));
+    LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes;
+
+    CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec_, graph, NULL, NULL, 0));
+  }
+
+  /*!
+   * \brief Get the function based on the input name.
+   * \param name The function which needs to be invoked.
+   * \param sptr_to_self Packed function pointer.
+   */
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+ private:
+  /*! \brief The CUDA stream on which to capture a CUDA graph. */
+  TVMStreamHandle capture_stream_;
+  /*! \brief The captured CUDA graph will be instantiated into this executable graph. */
+  cudaGraphExec_t cuda_graph_exec_;
+};
+
+PackedFunc GraphRuntimeCudaGraph::GetFunction(const std::string& name,
+                                              const ObjectPtr<Object>& sptr_to_self) {
+  if (name == "run_cuda_graph") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); });
+  } else if (name == "start_capture") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); });
+  } else if (name == "end_capture") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); });
+  } else {
+    return GraphRuntime::GetFunction(name, sptr_to_self);
+  }
+}
+
+Module GraphRuntimeCudaGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m,
+                                   const std::vector<TVMContext>& ctxs,
+                                   PackedFunc lookup_linked_param_func) {
+  auto exec = make_object<GraphRuntimeCudaGraph>();
+  exec->Init(sym_json, m, ctxs, lookup_linked_param_func);
+  return Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("tvm.graph_runtime_cuda_graph.create")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      ICHECK_GE(args.num_args, 4) << "The expected number of arguments for "
+                                     "graph_runtime_cuda_graph.create is at least 4, but it has "
+                                  << args.num_args;
+      PackedFunc lookup_linked_param_func;
+      int ctx_start_arg = 2;
+      if (args[2].type_code() == kTVMPackedFuncHandle) {
+        lookup_linked_param_func = args[2];
+        ctx_start_arg++;
+      }
+
+      *rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg),
+                                        lookup_linked_param_func);
+    });
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc
index 605d6b0ce892..1682afa8464a 100644
--- a/src/runtime/graph/graph_runtime_factory.cc
+++ b/src/runtime/graph/graph_runtime_factory.cc
@@ -72,6 +72,14 @@ PackedFunc GraphRuntimeFactory::GetFunction(
       exec->Import(this->imports_[0]);
       *rv = Module(exec);
     });
+  } else if (name == "cuda_graph_create") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::vector<TVMContext> contexts;
+      for (int i = 0; i < args.num_args; ++i) {
+        contexts.emplace_back(args[i].operator TVMContext());
+      }
+      *rv = this->CudaGraphRuntimeCreate(contexts);
+    });
   } else {
     return PackedFunc();
   }
@@ -130,6 +138,31 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector<TVMContext>& ct
   return mod;
 }
 
+Module GraphRuntimeFactory::CudaGraphRuntimeCreate(const std::vector<TVMContext>& ctxs) {
+  const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_cuda_graph.create");
+  ICHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_cuda_graph.create in registry. "
" + "Did you set(USE_GRAPH_RUNTIME_CUGRAPH=ON)?"; + std::vector unpacked_ctxs; + for (const auto& ctx : ctxs) { + unpacked_ctxs.emplace_back(ctx.device_type); + unpacked_ctxs.emplace_back(ctx.device_id); + } + size_t args_size = unpacked_ctxs.size() + 2; + std::vector values(args_size); + std::vector codes(args_size); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, this->graph_json_); + setter(1, this->imports_[0]); + for (size_t i = 0; i < unpacked_ctxs.size(); ++i) { + setter(i + 2, unpacked_ctxs[i]); + } + TVMRetValue rv; + pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); + Module mod = rv.operator Module(); + SetParams(const_cast(mod.as()), this->params_); + return mod; +} + Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string graph_json; diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 98fb27c43ea2..f2f11ee66802 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -89,6 +89,14 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module DebugRuntimeCreate(const std::vector& ctxs); + /*! + * \brief Create a specific cuda graph runtime module + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created cuda graph runtime module + */ + Module CudaGraphRuntimeCreate(const std::vector& ctx); + /*! * \brief Set params. * \param graph_runtime The graph runtime we want to set the params into. diff --git a/tests/python/unittest/test_runtime_graph_cuda_graph.py b/tests/python/unittest/test_runtime_graph_cuda_graph.py new file mode 100644 index 000000000000..4a31873cb93c --- /dev/null +++ b/tests/python/unittest/test_runtime_graph_cuda_graph.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import json
+import os
+import re
+import sys
+import time
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy as np
+
+from tvm.contrib import utils, graph_runtime
+from tvm.contrib.cuda_graph import cuda_graph_runtime
+
+
+bx = te.thread_axis("blockIdx.x")
+tx = te.thread_axis("threadIdx.x")
+
+
+@tvm.testing.requires_cudagraph
+def test_graph_simple():
+    n = 32
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
+    s = te.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=8)
+    s[B].bind(xo, bx)
+    s[B].bind(xi, tx)
+
+    node0 = {"op": "null", "name": "x", "inputs": []}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
+    nodes = [node0, node1]
+    arg_nodes = [0]
+    node_row_ptr = [0, 1, 2]
+    outputs = [[1, 0, 0]]
+    shape = (n,)
+    attrs = {
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
+    }
+    graph = json.dumps(graph)
+
+    def check_verify():
+        mlib = tvm.build(s, [A, B], "cuda", name="myadd")
+        ctx = tvm.gpu(0)
+        try:
+            mod = cuda_graph_runtime.create(graph, mlib, ctx)
+        except ValueError:
+            return
+
+        for i in range(3):
+            a = np.random.uniform(size=(n,)).astype(A.dtype)
+            mod.run(x=a)  # the first run captures a CUDA graph
+            out = mod.get_output(0, tvm.nd.empty((n,)))
+            np.testing.assert_equal(out.asnumpy(), a + 1)
+
+        # capture / run CUDA graph manually
+        mod.capture_cuda_graph()
+        a = np.random.uniform(size=(n,)).astype(A.dtype)
+        mod.set_input(x=a)
+        mod.run_cuda_graph()
+        out = mod.get_output(0, tvm.nd.empty((n,)))
+        np.testing.assert_equal(out.asnumpy(), a + 1)
+
+    check_verify()
+
+
+if __name__ == "__main__":
+    test_graph_simple()
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index a34fe4a062cb..930011d4fd33 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm.contrib import graph_runtime
 from tvm.contrib.debugger import debug_runtime
+from tvm.contrib.cuda_graph import cuda_graph_runtime
 import tvm.testing
 
 
@@ -538,6 +539,35 @@ def test_debug_graph_runtime():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
 
+@tvm.testing.requires_cudagraph
+def test_cuda_graph_runtime():
+    mod, params = relay.testing.synthetic.get_workload()
+    with tvm.transform.PassContext(opt_level=3):
+        complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+    data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
+
+    ctx = tvm.gpu()
+    try:
+        gmod = complied_graph_lib["cuda_graph_create"](ctx)
+    except:
+        print("Skip because cuda_graph not enabled")
+        return
+    set_input = gmod["set_input"]
+    run = gmod["run"]
+    get_output = gmod["get_output"]
+    set_input("data", tvm.nd.array(data))
+    run()
+    out = get_output(0).asnumpy()
+    tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+    # cuda graph runtime wrapper
+    cu_gmod = cuda_graph_runtime.GraphModuleCudaGraph(gmod)
+    cu_gmod.set_input("data", data)
+    cu_gmod.run()
+    out = cu_gmod.get_output(0).asnumpy()
+    tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+
 def test_multiple_imported_modules():
     def make_func(symbol):
         n = tvm.te.size_var("n")
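
Usage note (reviewer sketch, not part of the patch): the snippet below shows one way to drive the new CUDA graph runtime through the "cuda_graph_create" factory entry and the GraphModuleCudaGraph wrapper added above, mirroring test_cuda_graph_runtime. It assumes a TVM build with USE_CUDA=ON and USE_GRAPH_RUNTIME_CUDA_GRAPH=ON plus a CUDA 10+ toolkit; the workload, the input name "data", and the input shape are illustrative placeholders, not values fixed by this patch.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import testing
    from tvm.contrib.cuda_graph import cuda_graph_runtime

    # Compile any Relay workload for CUDA; synthetic.get_workload() is just a stand-in.
    mod, params = testing.synthetic.get_workload()
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build_module.build(mod, "cuda", params=params)

    ctx = tvm.gpu(0)
    # "cuda_graph_create" is the factory function registered by this patch.
    gmod = lib["cuda_graph_create"](ctx)
    graph_mod = cuda_graph_runtime.GraphModuleCudaGraph(gmod)

    # Input name and shape must match the compiled workload (assumed here).
    data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
    graph_mod.set_input("data", data)
    graph_mod.run()   # the first call captures and instantiates the CUDA graph
    graph_mod.run()   # subsequent calls launch the captured graph
    out = graph_mod.get_output(0).asnumpy()

The per-kernel launch overhead of the default graph runtime is paid only once during capture; afterwards each run() is a single cudaGraphLaunch on the capture stream.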