From 300d8ac69e3b4e716230297a809792c3774779fa Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 9 Mar 2021 16:54:47 +0800 Subject: [PATCH 01/17] add graph runtime cuGraph poc --- CMakeLists.txt | 9 ++ cmake/config.cmake | 3 + .../tvm/contrib/cu_graph/cugraph_runtime.py | 65 ++++++++++ .../graph/cugraph/graph_runtime_cugraph.cc | 117 ++++++++++++++++++ 4 files changed, 194 insertions(+) create mode 100644 python/tvm/contrib/cu_graph/cugraph_runtime.py create mode 100644 src/runtime/graph/cugraph/graph_runtime_cugraph.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 56170c693e3c..94054b4231ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) +tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime cuGraph launch mode" OFF) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_RTTI "Build with RTTI" ON) @@ -321,6 +322,14 @@ if(USE_GRAPH_RUNTIME) set_source_files_properties(${RUNTIME_GRAPH_SRCS} PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG") endif(USE_GRAPH_RUNTIME_DEBUG) + + if(USE_GRAPH_RUNTIME_CUGRAPH) + message(STATUS "Build with Graph runtime cuGraph support...") + file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) + set_source_files_properties(${RUNTIME_GRAPH_SRCS} + PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_CUGRAPH") + endif(USE_GRAPH_RUNTIME_CUGRAPH) endif(USE_GRAPH_RUNTIME) if(USE_VM_PROFILER) diff --git a/cmake/config.cmake b/cmake/config.cmake index 872feb918a4f..257e62bd9b7d 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -102,6 +102,9 @@ set(USE_GRAPH_RUNTIME ON) # Whether enable additional graph debug functions set(USE_GRAPH_RUNTIME_DEBUG OFF) +# Whether enable tiny graph runtime for cudaGraph Launch +set(USE_GRAPH_RUNTIME_CUGRAPH OFF) + # Whether enable additional vm profiler functions set(USE_VM_PROFILER OFF) diff --git a/python/tvm/contrib/cu_graph/cugraph_runtime.py b/python/tvm/contrib/cu_graph/cugraph_runtime.py new file mode 100644 index 000000000000..cef9bd8c5acb --- /dev/null +++ b/python/tvm/contrib/cu_graph/cugraph_runtime.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Graph runtime test cuGraph""" +import tvm._ffi + +from tvm._ffi.base import string_types +from tvm.contrib import graph_runtime + + +def create(graph_json_str, libmod, ctx): + assert isinstance(graph_json_str, string_types) + try: + ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) + if num_rpc_ctx == len(ctx): + pass + else: + fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create") + except ValueError: + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in " + "config.cmake and rebuild TVM to enable cu_graph test mode" + ) + + func_obj = fcreate(graph_json_str, libmod, *device_type_id) + return GraphModuleCuGraph(func_obj, ctx, graph_json_str) + + +class GraphModuleCuGraph(graph_runtime.GraphModule): + def __init__(self, module, ctx, graph_json_str): + + self._start_capture = module["start_capture"] + self._end_capture = module["end_capture"] + self._run_cuda_graph = module["run_cuda_graph"] + + graph_runtime.GraphModule.__init__(self, module) + + def capture_cuda_graph(self): + # call cuModuleLoadData before cudaStream API + self._run() + + print("====== Start Stream Capture ======") + self._start_capture() + print("====== Start Run Ops On Stream ======") + self._run() + print("====== End Stream Capture ======") + self._end_capture() + + + def run_cuda_graph(self): + self._run_cuda_graph() + diff --git a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc new file mode 100644 index 000000000000..5ceb9b789bcd --- /dev/null +++ b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime_cugraph.cc + */ + +#include + +#include "../../cuda/cuda_common.h" +#include "../graph_runtime.h" + +namespace tvm { +namespace runtime { + +class GraphRuntimeCuGraph : public GraphRuntime { + public: + int StartCapture() { + const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx; + + TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_); + TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_); + + CUDA_CALL(cudaStreamBeginCapture(static_cast(capture_stream_), + cudaStreamCaptureModeGlobal)); + return 0; + } + + int RunCudaGraph() { + cudaStream_t cuStream = static_cast(capture_stream_); + CUDA_CALL(cudaGraphLaunch(cu_graph_exec_, cuStream)); + CUDA_CALL(cudaStreamSynchronize(cuStream)); + return 0; + } + + int EndCapture() { + cudaGraph_t graph; + CUDA_CALL(cudaStreamEndCapture(static_cast(capture_stream_), &graph)); + + cudaGraphNode_t* nodes = NULL; + size_t numNodes = 0; + CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes)); + LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes; + + CUDA_CALL(cudaGraphInstantiate(&cu_graph_exec_, graph, NULL, NULL, 0)); + return 0; + } + + /*! + * \brief GetFunction Get the function based on input. + * \param name The function which needs to be invoked. + * \param sptr_to_self Packed function pointer. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + private: + TVMStreamHandle capture_stream_; + cudaGraphExec_t cu_graph_exec_; +}; + +PackedFunc GraphRuntimeCuGraph::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "run_cuda_graph") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->RunCudaGraph(); }); + } else if (name == "start_capture") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->StartCapture(); }); + } else if (name == "end_capture") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->EndCapture(); }); + } else { + return GraphRuntime::GetFunction(name, sptr_to_self); + } +} + +Module GraphRuntimeCuGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m, + const std::vector& ctxs, + PackedFunc lookup_linked_param_func) { + auto exec = make_object(); + exec->Init(sym_json, m, ctxs, lookup_linked_param_func); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_cugraph.create").set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + PackedFunc lookup_linked_param_func; + int ctx_start_arg = 2; + if (args[2].type_code() == kTVMPackedFuncHandle) { + lookup_linked_param_func = args[2]; + ctx_start_arg++; + } + + *rv = GraphRuntimeCuGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), + lookup_linked_param_func); +}); +} // namespace runtime +} // namespace tvm + From cc97f6b692f09e0b6621b33d4527b23df3a0f141 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 9 Mar 2021 17:11:06 +0800 Subject: [PATCH 02/17] lint format --- python/tvm/contrib/cu_graph/cugraph_runtime.py | 5 +---- src/runtime/graph/cugraph/graph_runtime_cugraph.cc | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/contrib/cu_graph/cugraph_runtime.py b/python/tvm/contrib/cu_graph/cugraph_runtime.py index cef9bd8c5acb..539f3fdf7f89 100644 --- a/python/tvm/contrib/cu_graph/cugraph_runtime.py +++ b/python/tvm/contrib/cu_graph/cugraph_runtime.py @@ -49,8 +49,7 @@ def __init__(self, module, ctx, graph_json_str): graph_runtime.GraphModule.__init__(self, module) def capture_cuda_graph(self): - # call cuModuleLoadData before cudaStream API - self._run() + self._run() # call cuModuleLoadData before cudaStream API print("====== Start Stream Capture ======") self._start_capture() @@ -59,7 +58,5 @@ def capture_cuda_graph(self): print("====== End Stream Capture ======") self._end_capture() - def run_cuda_graph(self): self._run_cuda_graph() - diff --git a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc index 5ceb9b789bcd..22def8246827 100644 --- a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc +++ b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc @@ -114,4 +114,3 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_cugraph.create").set_body([](TVMArgs args }); } // namespace runtime } // namespace tvm - From ed3b250705af61caba637fcf8aff5e8e0c7d2a20 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Thu, 11 Mar 2021 13:15:11 +0800 Subject: [PATCH 03/17] add unittest --- CMakeLists.txt | 16 ++-- .../unittest/test_runtime_graph_cugraph.py | 92 +++++++++++++++++++ 2 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 tests/python/unittest/test_runtime_graph_cugraph.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 94054b4231ea..9d0bc19f4cb0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -323,13 +323,15 @@ if(USE_GRAPH_RUNTIME) PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG") endif(USE_GRAPH_RUNTIME_DEBUG) - if(USE_GRAPH_RUNTIME_CUGRAPH) - message(STATUS "Build with Graph runtime cuGraph support...") - file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) - set_source_files_properties(${RUNTIME_GRAPH_SRCS} - PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_CUGRAPH") - endif(USE_GRAPH_RUNTIME_CUGRAPH) + if(USE_CUDA) + if(USE_GRAPH_RUNTIME_CUGRAPH) + message(STATUS "Build with Graph runtime cuGraph support...") + file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) + set_source_files_properties(${RUNTIME_GRAPH_SRCS} + PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_CUGRAPH") + endif(USE_GRAPH_RUNTIME_CUGRAPH) + endif(USE_CUDA) endif(USE_GRAPH_RUNTIME) if(USE_VM_PROFILER) diff --git a/tests/python/unittest/test_runtime_graph_cugraph.py b/tests/python/unittest/test_runtime_graph_cugraph.py new file mode 100644 index 000000000000..f4999cc96446 --- /dev/null +++ b/tests/python/unittest/test_runtime_graph_cugraph.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import os +import re +import sys +import time + +import pytest + +import tvm +import tvm.testing +from tvm import te +import numpy as np + +from tvm.contrib import utils, graph_runtime +from tvm.contrib.cu_graph import cugraph_runtime + + +bx = te.thread_axis("blockIdx.x") +tx = te.thread_axis("threadIdx.x") + + +@tvm.testing.requires_cuda +def test_graph_simple(): + n = 32 + A = te.placeholder((n,), name="A") + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + s = te.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], factor=8) + s[B].bind(xo, bx) + s[B].bind(xi, tx) + + node0 = {"op": "null", "name": "x", "inputs": []} + node1 = { + "op": "tvm_op", + "name": "add", + "inputs": [[0, 0, 0]], + "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"}, + } + nodes = [node0, node1] + arg_nodes = [0] + node_row_ptr = [0, 1, 2] + outputs = [[1, 0, 0]] + shape = (n,) + attrs = { + "shape": ["list_shape", [shape, shape]], + "dltype": ["list_str", ["float32", "float32"]], + "storage_id": ["list_int", [0, 1]], + } + graph = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "node_row_ptr": node_row_ptr, + "heads": outputs, + "attrs": attrs, + } + graph = json.dumps(graph) + + def check_verify(): + mlib = tvm.build(s, [A, B], "cuda", name="myadd") + ctx = tvm.gpu(0) + try: + mod = cugraph_runtime.create(graph, mlib, ctx) + except ValueError: + return + mod.capture_cuda_graph() + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.set_input(x=a) + mod.run_cuda_graph() + out = mod.get_output(0, tvm.nd.empty((n,))) + np.testing.assert_equal(out.asnumpy(), a + 1) + + check_verify() + + +if __name__ == "__main__": + test_graph_simple() From 520e532e380d415528b7c01181001e60385404a8 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 16 Mar 2021 19:30:15 +0800 Subject: [PATCH 04/17] fix review comments --- CMakeLists.txt | 10 --- cmake/modules/CUDA.cmake | 11 +++ .../tvm/contrib/cu_graph/cugraph_runtime.py | 79 ++++++++++++++++--- python/tvm/contrib/nvcc.py | 13 +++ python/tvm/testing.py | 17 ++++ .../graph/cugraph/graph_runtime_cugraph.cc | 38 ++++++--- src/runtime/graph/graph_runtime_factory.cc | 33 ++++++++ src/runtime/graph/graph_runtime_factory.h | 8 ++ .../unittest/test_runtime_graph_cugraph.py | 10 ++- .../test_runtime_module_based_interface.py | 30 +++++++ 10 files changed, 218 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d0bc19f4cb0..1f80fc07ee5d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -322,16 +322,6 @@ if(USE_GRAPH_RUNTIME) set_source_files_properties(${RUNTIME_GRAPH_SRCS} PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG") endif(USE_GRAPH_RUNTIME_DEBUG) - - if(USE_CUDA) - if(USE_GRAPH_RUNTIME_CUGRAPH) - message(STATUS "Build with Graph runtime cuGraph support...") - file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) - set_source_files_properties(${RUNTIME_GRAPH_SRCS} - PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_CUGRAPH") - endif(USE_GRAPH_RUNTIME_CUGRAPH) - endif(USE_CUDA) endif(USE_GRAPH_RUNTIME) if(USE_VM_PROFILER) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 1e104218a456..59e5f9700025 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -64,6 +64,17 @@ if(USE_CUDA) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) endif(USE_THRUST) + if(USE_GRAPH_RUNTIME_CUGRAPH) + if(NOT USE_GRAPH_RUNTIME) + message(FATAL_ERROR "CUDA Graph is only supported by graph runtime, should set USE_GRAPH_RUNTIME=ON") + endif() + if(CUDAToolkit_VERSION_MAJOR LESS "10") + message(FATAL_ERROR "CUDA Graph requires at least CUDA 10, got=" ${CUDAToolkit_VERSION}) + endif() + message(STATUS "Build with Graph runtime cuGraph support...") + file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) + endif() else(USE_CUDA) list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc) endif(USE_CUDA) diff --git a/python/tvm/contrib/cu_graph/cugraph_runtime.py b/python/tvm/contrib/cu_graph/cugraph_runtime.py index 539f3fdf7f89..90bfc95b2aa3 100644 --- a/python/tvm/contrib/cu_graph/cugraph_runtime.py +++ b/python/tvm/contrib/cu_graph/cugraph_runtime.py @@ -22,41 +22,100 @@ def create(graph_json_str, libmod, ctx): + """Create a runtime executor module given a graph and module. + + Parameters + ---------- + graph_json_str : str + The graph to be deployed in json format output by json graph. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. + + libmod : tvm.runtime.Module + The module of the corresponding function + + ctx : TVMContext + The context to deploy the module, only supports CUDA GPU + + Returns + ------- + graph_module : GraphModuleCuGraph + CUDA graph runtime module that can be used to execute the graph. + + Note + ---- + See also :py:class:`tvm.contrib.cu_graph.GraphModuleCuGraph` + for examples to directly construct a GraphModuleCuGraph from an exported + relay compiled library. + """ assert isinstance(graph_json_str, string_types) try: ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): - pass + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cugraph.create") else: fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create") except ValueError: raise ValueError( "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in " - "config.cmake and rebuild TVM to enable cu_graph test mode" + "config.cmake and rebuild TVM to enable CUDA graph support" ) - func_obj = fcreate(graph_json_str, libmod, *device_type_id) - return GraphModuleCuGraph(func_obj, ctx, graph_json_str) + return GraphModuleCuGraph(fcreate(graph_json_str, libmod, *device_type_id)) class GraphModuleCuGraph(graph_runtime.GraphModule): - def __init__(self, module, ctx, graph_json_str): + """CUDA graph runtime module. + + This is a CUDA graph runtime wrapper over the TVM runtime. + Runtime interfaces are wrapped with CUDA graph functionalities. + + Parameters + ---------- + module : Module + The internal tvm module that holds the actual graph functions. + """ + def __init__(self, module): self._start_capture = module["start_capture"] self._end_capture = module["end_capture"] self._run_cuda_graph = module["run_cuda_graph"] - + self._cuda_graph_captured = False graph_runtime.GraphModule.__init__(self, module) def capture_cuda_graph(self): - self._run() # call cuModuleLoadData before cudaStream API + """Capture a CUDA graph for tvm_op graph - print("====== Start Stream Capture ======") + This should be called before run_cuda_graph() to capture and + instantiate a CUDA graph instance. + """ + self._run() # call cuModuleLoadData before cudaStream API self._start_capture() - print("====== Start Run Ops On Stream ======") self._run() - print("====== End Stream Capture ======") self._end_capture() + self._cuda_graph_captured = True def run_cuda_graph(self): + """Run the CUDA graph for tvm_op graph + + Run the captured CUDA graph instance instead of the + for-loop kernel launch of default graph runtime + """ self._run_cuda_graph() + + def run(self, **input_dict): + """A run wrapper for graph capture / launch, user can just + change default graph runtime to cuda graph runtime, and + the first call will capture a cuda graph for future launch + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be feed to + """ + if input_dict: + self.set_input(**input_dict) + if not self._cuda_graph_captured: + self.capture_cuda_graph() + else: + self._run_cuda_graph() diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index f33603b923a5..f6eecde0ba4f 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -308,6 +308,19 @@ def have_tensorcore(compute_version=None, target=None): return False +def have_cudagraph(): + """Either CUDA Graph support is provided""" + try: + cuda_path = find_cuda_path() + cuda_ver = get_cuda_version(cuda_path) + if cuda_ver < 10.0: + return False + return True + except RuntimeError: + warnings.warn("Cannot find cuda path") + return False + + def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not diff --git a/python/tvm/testing.py b/python/tvm/testing.py index d65ab23677b5..a4e7e640765b 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -514,6 +514,23 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_cudagraph(*args): + """Mark a test as requiring the CUDA Graph Feature + + This also marks the test as requiring cuda + + Parameters + ---------- + f : function + Function to mark + """ + _requires_cudagraph = [ + pytest.mark.skipif(not nvcc.have_cudagraph(), reason="CUDA Graph not support"), + *requires_cuda(), + ] + return _compose(args, _requires_cudagraph) + + def requires_opencl(*args): """Mark a test as requiring the OpenCL runtime. diff --git a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc index 22def8246827..59af561d8b54 100644 --- a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc +++ b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc @@ -29,9 +29,22 @@ namespace tvm { namespace runtime { +/*! + * \brief Graph runtime with CUDA Graph Support. + * + * This is the extension of GraphRuntime class used for CUDA graph launch + * instead of CUDA kernel launch. CUDA graph launch requires CUDA 10.0 or + * above, currently there are two ways of constructing CUDA graphs: + * (1) Using CUDA stream capture API to capture a series of operations on + * CUDA stream, and automatically generates a graph (2) Building a graph + * using CUDA graph API manually. This implementation uses stream capture. + */ class GraphRuntimeCuGraph : public GraphRuntime { public: - int StartCapture() { + /*! + * \brief Begin CUDA graph capture on stream, the stream enters capture mode. + */ + void StartCapture() { const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx; TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_); @@ -39,17 +52,22 @@ class GraphRuntimeCuGraph : public GraphRuntime { CUDA_CALL(cudaStreamBeginCapture(static_cast(capture_stream_), cudaStreamCaptureModeGlobal)); - return 0; } - int RunCudaGraph() { + /*! + * \brief Launch the instantiated graph on stream + */ + void RunCudaGraph() { cudaStream_t cuStream = static_cast(capture_stream_); CUDA_CALL(cudaGraphLaunch(cu_graph_exec_, cuStream)); CUDA_CALL(cudaStreamSynchronize(cuStream)); - return 0; } - int EndCapture() { + /*! + * \brief End CUDA graph capture on stream, a graph will be created and + * instantiated. + */ + void EndCapture() { cudaGraph_t graph; CUDA_CALL(cudaStreamEndCapture(static_cast(capture_stream_), &graph)); @@ -59,7 +77,6 @@ class GraphRuntimeCuGraph : public GraphRuntime { LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes; CUDA_CALL(cudaGraphInstantiate(&cu_graph_exec_, graph, NULL, NULL, 0)); - return 0; } /*! @@ -70,7 +87,9 @@ class GraphRuntimeCuGraph : public GraphRuntime { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); private: + /*! \brief The Cuda stream on which to capture a CUDA graph. */ TVMStreamHandle capture_stream_; + /*! \brief The captured CUDA graph will be instantiated to this. */ cudaGraphExec_t cu_graph_exec_; }; @@ -78,13 +97,12 @@ PackedFunc GraphRuntimeCuGraph::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "run_cuda_graph") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->RunCudaGraph(); }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); }); } else if (name == "start_capture") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->StartCapture(); }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); }); } else if (name == "end_capture") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->EndCapture(); }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); }); } else { return GraphRuntime::GetFunction(name, sptr_to_self); } diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 4d3993a9a36f..97010fc8cd15 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -72,6 +72,14 @@ PackedFunc GraphRuntimeFactory::GetFunction( exec->Import(this->imports_[0]); *rv = Module(exec); }); + } else if (name == "cuda_graph_create") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::vector contexts; + for (int i = 0; i < args.num_args; ++i) { + contexts.emplace_back(args[i].operator TVMContext()); + } + *rv = this->CudaGraphRuntimeCreate(contexts); + }); } else { return PackedFunc(); } @@ -130,6 +138,31 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector& ct return mod; } +Module GraphRuntimeFactory::CudaGraphRuntimeCreate(const std::vector& ctxs) { + const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_cugraph.create"); + ICHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_cugraph.create in registry. " + "Do you enable cuda graph runtime build?"; + std::vector unpacked_ctxs; + for (const auto& ctx : ctxs) { + unpacked_ctxs.emplace_back(ctx.device_type); + unpacked_ctxs.emplace_back(ctx.device_id); + } + size_t args_size = unpacked_ctxs.size() + 2; + std::vector values(args_size); + std::vector codes(args_size); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, this->graph_json_); + setter(1, this->imports_[0]); + for (size_t i = 0; i < unpacked_ctxs.size(); ++i) { + setter(i + 2, unpacked_ctxs[i]); + } + TVMRetValue rv; + pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); + Module mod = rv.operator Module(); + SetParams(const_cast(mod.as()), this->params_); + return mod; +} + Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string graph_json; diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 98fb27c43ea2..f2f11ee66802 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -89,6 +89,14 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module DebugRuntimeCreate(const std::vector& ctxs); + /*! + * \brief Create a specific cuda graph runtime module + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created cuda graph runtime module + */ + Module CudaGraphRuntimeCreate(const std::vector& ctx); + /*! * \brief Set params. * \param graph_runtime The graph runtime we want to set the params into. diff --git a/tests/python/unittest/test_runtime_graph_cugraph.py b/tests/python/unittest/test_runtime_graph_cugraph.py index f4999cc96446..32dee733b184 100644 --- a/tests/python/unittest/test_runtime_graph_cugraph.py +++ b/tests/python/unittest/test_runtime_graph_cugraph.py @@ -35,7 +35,7 @@ tx = te.thread_axis("threadIdx.x") -@tvm.testing.requires_cuda +@tvm.testing.requires_cudagraph def test_graph_simple(): n = 32 A = te.placeholder((n,), name="A") @@ -78,6 +78,14 @@ def check_verify(): mod = cugraph_runtime.create(graph, mlib, ctx) except ValueError: return + + for i in range(3): + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.run(x=a) # The first run captured a CUDA graph + out = mod.get_output(0, tvm.nd.empty((n,))) + np.testing.assert_equal(out.asnumpy(), a + 1) + + # capture / run CUDA graph manually mod.capture_cuda_graph() a = np.random.uniform(size=(n,)).astype(A.dtype) mod.set_input(x=a) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index a34fe4a062cb..6aba4414c253 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -20,6 +20,7 @@ import tvm from tvm.contrib import graph_runtime from tvm.contrib.debugger import debug_runtime +from tvm.contrib.cu_graph import cugraph_runtime import tvm.testing @@ -538,6 +539,35 @@ def test_debug_graph_runtime(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.requires_cudagraph +def test_cuda_graph_runtime(): + mod, params = relay.testing.synthetic.get_workload() + with tvm.transform.PassContext(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") + + ctx = tvm.gpu() + try: + gmod = complied_graph_lib["cuda_graph_create"](ctx) + except: + print("Skip because cuda_graph not enabled") + return + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # cuda graph runtime wrapper + cu_gmod = cugraph_runtime.GraphModuleCuGraph(gmod) + cu_gmod.set_input("data", data) + cu_gmod.run() + out = cu_gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def test_multiple_imported_modules(): def make_func(symbol): n = tvm.te.size_var("n") From 96f94eba94f3de6fa80fb06d4a916e4f833346fc Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 16 Mar 2021 19:50:00 +0800 Subject: [PATCH 05/17] Update CMakeLists.txt Co-authored-by: Cody Yu --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 126a32fbe376..fc64e8054941 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ tvm_option(USE_THREADS "Build with thread support" ON) tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF) tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) -tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime cuGraph launch mode" OFF) +tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime with cuGraph for GPUs" OFF) tvm_option(USE_PROFILER "Build profiler for the VM and graph runtime" ON) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) From f286711e4126c696860be3ec3d82400ca8542bd5 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 16 Mar 2021 20:19:04 +0800 Subject: [PATCH 06/17] build cuda graph runtime in gpu test --- tests/scripts/task_config_build_gpu.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 13dfb4136547..e8ae3b9cbbb4 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -26,6 +26,7 @@ cp ../cmake/config.cmake . echo set\(USE_CUBLAS ON\) >> config.cmake echo set\(USE_CUDNN ON\) >> config.cmake echo set\(USE_CUDA ON\) >> config.cmake +echo set\(USE_GRAPH_RUNTIME_CUGRAPH ON\) >> config.cmake echo set\(USE_OPENGL ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake From cc36e34c1e9b13a416d98b0d3474ee7d0d7ec62f Mon Sep 17 00:00:00 2001 From: zhuochen Date: Tue, 16 Mar 2021 20:27:50 +0800 Subject: [PATCH 07/17] Revert "build cuda graph runtime in gpu test" This reverts commit f286711e4126c696860be3ec3d82400ca8542bd5. --- tests/scripts/task_config_build_gpu.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index e8ae3b9cbbb4..13dfb4136547 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -26,7 +26,6 @@ cp ../cmake/config.cmake . echo set\(USE_CUBLAS ON\) >> config.cmake echo set\(USE_CUDNN ON\) >> config.cmake echo set\(USE_CUDA ON\) >> config.cmake -echo set\(USE_GRAPH_RUNTIME_CUGRAPH ON\) >> config.cmake echo set\(USE_OPENGL ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake From 604980d652a240e46275e85bd15bd5b1a68f9953 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:28:25 +0800 Subject: [PATCH 08/17] rename cuGraph to CUDA Graph --- CMakeLists.txt | 2 +- cmake/config.cmake | 4 +- cmake/modules/CUDA.cmake | 12 +- .../tvm/contrib/cu_graph/cugraph_runtime.py | 121 ---------------- .../graph/cugraph/graph_runtime_cugraph.cc | 134 ------------------ src/runtime/graph/graph_runtime_factory.cc | 4 +- .../unittest/test_runtime_graph_cugraph.py | 100 ------------- .../test_runtime_module_based_interface.py | 4 +- 8 files changed, 13 insertions(+), 368 deletions(-) delete mode 100644 python/tvm/contrib/cu_graph/cugraph_runtime.py delete mode 100644 src/runtime/graph/cugraph/graph_runtime_cugraph.cc delete mode 100644 tests/python/unittest/test_runtime_graph_cugraph.py diff --git a/CMakeLists.txt b/CMakeLists.txt index fc64e8054941..16968ce41f70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ tvm_option(USE_THREADS "Build with thread support" ON) tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF) tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) -tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime with cuGraph for GPUs" OFF) +tvm_option(USE_GRAPH_RUNTIME_CUDA_GRAPH "Build with tiny graph runtime with CUDA Graph for GPUs" OFF) tvm_option(USE_PROFILER "Build profiler for the VM and graph runtime" ON) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) diff --git a/cmake/config.cmake b/cmake/config.cmake index 81f148d1a389..60c718c97bc1 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -99,8 +99,8 @@ set(USE_STACKVM_RUNTIME OFF) # Whether enable tiny embedded graph runtime. set(USE_GRAPH_RUNTIME ON) -# Whether enable tiny graph runtime for cudaGraph Launch -set(USE_GRAPH_RUNTIME_CUGRAPH OFF) +# Whether enable tiny graph runtime with CUDA Graph +set(USE_GRAPH_RUNTIME_CUDA_GRAPH OFF) # Whether to enable the profiler for the graph runtime and vm set(USE_PROFILER ON) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 593f9f6e23ad..262a4e6e7123 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -65,16 +65,16 @@ if(USE_CUDA) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) endif(USE_THRUST) - if(USE_GRAPH_RUNTIME_CUGRAPH) + if(USE_GRAPH_RUNTIME_CUDA_GRAPH) if(NOT USE_GRAPH_RUNTIME) - message(FATAL_ERROR "CUDA Graph is only supported by graph runtime, should set USE_GRAPH_RUNTIME=ON") + message(FATAL_ERROR "CUDA Graph is only supported by graph runtime, please set USE_GRAPH_RUNTIME=ON") endif() if(CUDAToolkit_VERSION_MAJOR LESS "10") - message(FATAL_ERROR "CUDA Graph requires at least CUDA 10, got=" ${CUDAToolkit_VERSION}) + message(FATAL_ERROR "CUDA Graph requires CUDA 10 or above, got=" ${CUDAToolkit_VERSION}) endif() - message(STATUS "Build with Graph runtime cuGraph support...") - file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) + message(STATUS "Build with Graph runtime with CUDA Graph support...") + file(GLOB RUNTIME_CUDA_GRAPH_SRCS src/runtime/graph/cuda_graph/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_GRAPH_SRCS}) endif() else(USE_CUDA) list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc) diff --git a/python/tvm/contrib/cu_graph/cugraph_runtime.py b/python/tvm/contrib/cu_graph/cugraph_runtime.py deleted file mode 100644 index 90bfc95b2aa3..000000000000 --- a/python/tvm/contrib/cu_graph/cugraph_runtime.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Graph runtime test cuGraph""" -import tvm._ffi - -from tvm._ffi.base import string_types -from tvm.contrib import graph_runtime - - -def create(graph_json_str, libmod, ctx): - """Create a runtime executor module given a graph and module. - - Parameters - ---------- - graph_json_str : str - The graph to be deployed in json format output by json graph. - The graph can contain operator(tvm_op) that points to the name - of PackedFunc in the libmod. - - libmod : tvm.runtime.Module - The module of the corresponding function - - ctx : TVMContext - The context to deploy the module, only supports CUDA GPU - - Returns - ------- - graph_module : GraphModuleCuGraph - CUDA graph runtime module that can be used to execute the graph. - - Note - ---- - See also :py:class:`tvm.contrib.cu_graph.GraphModuleCuGraph` - for examples to directly construct a GraphModuleCuGraph from an exported - relay compiled library. - """ - assert isinstance(graph_json_str, string_types) - try: - ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) - if num_rpc_ctx == len(ctx): - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cugraph.create") - else: - fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create") - except ValueError: - raise ValueError( - "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in " - "config.cmake and rebuild TVM to enable CUDA graph support" - ) - - return GraphModuleCuGraph(fcreate(graph_json_str, libmod, *device_type_id)) - - -class GraphModuleCuGraph(graph_runtime.GraphModule): - """CUDA graph runtime module. - - This is a CUDA graph runtime wrapper over the TVM runtime. - Runtime interfaces are wrapped with CUDA graph functionalities. - - Parameters - ---------- - module : Module - The internal tvm module that holds the actual graph functions. - """ - - def __init__(self, module): - self._start_capture = module["start_capture"] - self._end_capture = module["end_capture"] - self._run_cuda_graph = module["run_cuda_graph"] - self._cuda_graph_captured = False - graph_runtime.GraphModule.__init__(self, module) - - def capture_cuda_graph(self): - """Capture a CUDA graph for tvm_op graph - - This should be called before run_cuda_graph() to capture and - instantiate a CUDA graph instance. - """ - self._run() # call cuModuleLoadData before cudaStream API - self._start_capture() - self._run() - self._end_capture() - self._cuda_graph_captured = True - - def run_cuda_graph(self): - """Run the CUDA graph for tvm_op graph - - Run the captured CUDA graph instance instead of the - for-loop kernel launch of default graph runtime - """ - self._run_cuda_graph() - - def run(self, **input_dict): - """A run wrapper for graph capture / launch, user can just - change default graph runtime to cuda graph runtime, and - the first call will capture a cuda graph for future launch - - Parameters - ---------- - input_dict: dict of str to NDArray - List of input values to be feed to - """ - if input_dict: - self.set_input(**input_dict) - if not self._cuda_graph_captured: - self.capture_cuda_graph() - else: - self._run_cuda_graph() diff --git a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc deleted file mode 100644 index 59af561d8b54..000000000000 --- a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file graph_runtime_cugraph.cc - */ - -#include - -#include "../../cuda/cuda_common.h" -#include "../graph_runtime.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Graph runtime with CUDA Graph Support. - * - * This is the extension of GraphRuntime class used for CUDA graph launch - * instead of CUDA kernel launch. CUDA graph launch requires CUDA 10.0 or - * above, currently there are two ways of constructing CUDA graphs: - * (1) Using CUDA stream capture API to capture a series of operations on - * CUDA stream, and automatically generates a graph (2) Building a graph - * using CUDA graph API manually. This implementation uses stream capture. - */ -class GraphRuntimeCuGraph : public GraphRuntime { - public: - /*! - * \brief Begin CUDA graph capture on stream, the stream enters capture mode. - */ - void StartCapture() { - const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx; - - TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_); - TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_); - - CUDA_CALL(cudaStreamBeginCapture(static_cast(capture_stream_), - cudaStreamCaptureModeGlobal)); - } - - /*! - * \brief Launch the instantiated graph on stream - */ - void RunCudaGraph() { - cudaStream_t cuStream = static_cast(capture_stream_); - CUDA_CALL(cudaGraphLaunch(cu_graph_exec_, cuStream)); - CUDA_CALL(cudaStreamSynchronize(cuStream)); - } - - /*! - * \brief End CUDA graph capture on stream, a graph will be created and - * instantiated. - */ - void EndCapture() { - cudaGraph_t graph; - CUDA_CALL(cudaStreamEndCapture(static_cast(capture_stream_), &graph)); - - cudaGraphNode_t* nodes = NULL; - size_t numNodes = 0; - CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes)); - LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes; - - CUDA_CALL(cudaGraphInstantiate(&cu_graph_exec_, graph, NULL, NULL, 0)); - } - - /*! - * \brief GetFunction Get the function based on input. - * \param name The function which needs to be invoked. - * \param sptr_to_self Packed function pointer. - */ - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - - private: - /*! \brief The Cuda stream on which to capture a CUDA graph. */ - TVMStreamHandle capture_stream_; - /*! \brief The captured CUDA graph will be instantiated to this. */ - cudaGraphExec_t cu_graph_exec_; -}; - -PackedFunc GraphRuntimeCuGraph::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - if (name == "run_cuda_graph") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); }); - } else if (name == "start_capture") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); }); - } else if (name == "end_capture") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); }); - } else { - return GraphRuntime::GetFunction(name, sptr_to_self); - } -} - -Module GraphRuntimeCuGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m, - const std::vector& ctxs, - PackedFunc lookup_linked_param_func) { - auto exec = make_object(); - exec->Init(sym_json, m, ctxs, lookup_linked_param_func); - return Module(exec); -} - -TVM_REGISTER_GLOBAL("tvm.graph_runtime_cugraph.create").set_body([](TVMArgs args, TVMRetValue* rv) { - ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - PackedFunc lookup_linked_param_func; - int ctx_start_arg = 2; - if (args[2].type_code() == kTVMPackedFuncHandle) { - lookup_linked_param_func = args[2]; - ctx_start_arg++; - } - - *rv = GraphRuntimeCuGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), - lookup_linked_param_func); -}); -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 7562e34d2ffc..c04347927f6b 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -139,8 +139,8 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector& ct } Module GraphRuntimeFactory::CudaGraphRuntimeCreate(const std::vector& ctxs) { - const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_cugraph.create"); - ICHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_cugraph.create in registry. " + const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_cuda_graph.create"); + ICHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_cuda_graph.create in registry. " "Do you enable cuda graph runtime build?"; std::vector unpacked_ctxs; for (const auto& ctx : ctxs) { diff --git a/tests/python/unittest/test_runtime_graph_cugraph.py b/tests/python/unittest/test_runtime_graph_cugraph.py deleted file mode 100644 index 32dee733b184..000000000000 --- a/tests/python/unittest/test_runtime_graph_cugraph.py +++ /dev/null @@ -1,100 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -import os -import re -import sys -import time - -import pytest - -import tvm -import tvm.testing -from tvm import te -import numpy as np - -from tvm.contrib import utils, graph_runtime -from tvm.contrib.cu_graph import cugraph_runtime - - -bx = te.thread_axis("blockIdx.x") -tx = te.thread_axis("threadIdx.x") - - -@tvm.testing.requires_cudagraph -def test_graph_simple(): - n = 32 - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") - s = te.create_schedule(B.op) - xo, xi = s[B].split(B.op.axis[0], factor=8) - s[B].bind(xo, bx) - s[B].bind(xi, tx) - - node0 = {"op": "null", "name": "x", "inputs": []} - node1 = { - "op": "tvm_op", - "name": "add", - "inputs": [[0, 0, 0]], - "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"}, - } - nodes = [node0, node1] - arg_nodes = [0] - node_row_ptr = [0, 1, 2] - outputs = [[1, 0, 0]] - shape = (n,) - attrs = { - "shape": ["list_shape", [shape, shape]], - "dltype": ["list_str", ["float32", "float32"]], - "storage_id": ["list_int", [0, 1]], - } - graph = { - "nodes": nodes, - "arg_nodes": arg_nodes, - "node_row_ptr": node_row_ptr, - "heads": outputs, - "attrs": attrs, - } - graph = json.dumps(graph) - - def check_verify(): - mlib = tvm.build(s, [A, B], "cuda", name="myadd") - ctx = tvm.gpu(0) - try: - mod = cugraph_runtime.create(graph, mlib, ctx) - except ValueError: - return - - for i in range(3): - a = np.random.uniform(size=(n,)).astype(A.dtype) - mod.run(x=a) # The first run captured a CUDA graph - out = mod.get_output(0, tvm.nd.empty((n,))) - np.testing.assert_equal(out.asnumpy(), a + 1) - - # capture / run CUDA graph manually - mod.capture_cuda_graph() - a = np.random.uniform(size=(n,)).astype(A.dtype) - mod.set_input(x=a) - mod.run_cuda_graph() - out = mod.get_output(0, tvm.nd.empty((n,))) - np.testing.assert_equal(out.asnumpy(), a + 1) - - check_verify() - - -if __name__ == "__main__": - test_graph_simple() diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 6aba4414c253..930011d4fd33 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -20,7 +20,7 @@ import tvm from tvm.contrib import graph_runtime from tvm.contrib.debugger import debug_runtime -from tvm.contrib.cu_graph import cugraph_runtime +from tvm.contrib.cuda_graph import cuda_graph_runtime import tvm.testing @@ -561,7 +561,7 @@ def test_cuda_graph_runtime(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) # cuda graph runtime wrapper - cu_gmod = cugraph_runtime.GraphModuleCuGraph(gmod) + cu_gmod = cuda_graph_runtime.GraphModuleCudaGraph(gmod) cu_gmod.set_input("data", data) cu_gmod.run() out = cu_gmod.get_output(0).asnumpy() From 41e6b9a820f6e568399cd39bcf4c8376b705350b Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:43:30 +0800 Subject: [PATCH 09/17] rename cuda_graph --- python/tvm/contrib/cuda_graph/__init__.py | 16 +++ .../contrib/cuda_graph/cuda_graph_runtime.py | 121 ++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 python/tvm/contrib/cuda_graph/__init__.py create mode 100644 python/tvm/contrib/cuda_graph/cuda_graph_runtime.py diff --git a/python/tvm/contrib/cuda_graph/__init__.py b/python/tvm/contrib/cuda_graph/__init__.py new file mode 100644 index 000000000000..d216be4ddc94 --- /dev/null +++ b/python/tvm/contrib/cuda_graph/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py new file mode 100644 index 000000000000..810b16a2d322 --- /dev/null +++ b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Graph runtime with CUDA Graph""" +import tvm._ffi + +from tvm._ffi.base import string_types +from tvm.contrib import graph_runtime +tvm.contrib.graph_runtime.GraphModule + +def create(graph_json_str, libmod, ctx): + """Create a runtime executor module given a graph and module. + + Parameters + ---------- + graph_json_str : str + The graph to be deployed in json format output by json graph. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. + + libmod : tvm.runtime.Module + The module of the corresponding function + + ctx : TVMContext + The context to deploy the module, only supports CUDA GPU + + Returns + ------- + graph_module : GraphModuleCudaGraph + CUDA graph runtime module that can be used to execute the graph. + + Note + ---- + See also :py:class:`tvm.contrib.cuda_graph.cuda_graph_runtime.GraphModuleCudaGraph` + for examples to directly construct a GraphModuleCudaGraph from an exported + relay compiled library. + """ + assert isinstance(graph_json_str, string_types) + try: + ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) + if num_rpc_ctx == len(ctx): + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cuda_graph.create") + else: + fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create") + except ValueError: + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_CUDA_GRAPH ON)' in " + "config.cmake and rebuild TVM to enable CUDA graph support" + ) + + return GraphModuleCudaGraph(fcreate(graph_json_str, libmod, *device_type_id)) + + +class GraphModuleCudaGraph(graph_runtime.GraphModule): + """CUDA graph runtime module. + + This is a CUDA graph runtime wrapper over the TVM runtime. + Runtime interfaces are wrapped with CUDA graph functionalities. + + Parameters + ---------- + module : Module + The internal tvm module that holds the actual graph functions. + """ + + def __init__(self, module): + self._start_capture = module["start_capture"] + self._end_capture = module["end_capture"] + self._run_cuda_graph = module["run_cuda_graph"] + self._cuda_graph_captured = False + graph_runtime.GraphModule.__init__(self, module) + + def capture_cuda_graph(self): + """Capture a CUDA graph for tvm_op graph + + This should be called before run_cuda_graph() to capture and + instantiate a CUDA graph instance. + """ + self._run() # call cuModuleLoadData before cudaStream API + self._start_capture() + self._run() + self._end_capture() + self._cuda_graph_captured = True + + def run_cuda_graph(self): + """Run the CUDA graph for tvm_op graph + + Run the captured CUDA graph instance instead of the + for-loop kernel launch of default graph runtime + """ + self._run_cuda_graph() + + def run(self, **input_dict): + """A run wrapper for graph capture / launch, user can just + change default graph runtime to cuda graph runtime, and + the first call will capture a cuda graph for future launch + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be feed to + """ + if input_dict: + self.set_input(**input_dict) + if not self._cuda_graph_captured: + self.capture_cuda_graph() + else: + self._run_cuda_graph() From f437a2d928c5d1f4cd01b51dff4893352e0225bc Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:43:56 +0800 Subject: [PATCH 10/17] rename cuda_graph --- .../cuda_graph/graph_runtime_cuda_graph.cc | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc diff --git a/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc new file mode 100644 index 000000000000..4a743611152b --- /dev/null +++ b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime_cuda_graph.cc + */ + +#include + +#include "../../cuda/cuda_common.h" +#include "../graph_runtime.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Graph runtime with CUDA Graph Support. + * + * This is the extension of GraphRuntime class used for CUDA graph launch + * instead of CUDA kernel launch. CUDA graph launch requires CUDA 10.0 or + * above, currently there are two ways of constructing CUDA graphs: + * (1) Using CUDA stream capture API to capture a series of operations on + * CUDA stream, and automatically generates a graph (2) Building a graph + * using CUDA graph API manually. This implementation uses stream capture. + */ +class GraphRuntimeCudaGraph : public GraphRuntime { + public: + /*! + * \brief Begin CUDA graph capture on stream, the stream enters capture mode. + */ + void StartCapture() { + const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx; + + TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_); + TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_); + + CUDA_CALL(cudaStreamBeginCapture(static_cast(capture_stream_), + cudaStreamCaptureModeGlobal)); + } + + /*! + * \brief Launch the instantiated graph on stream + */ + void RunCudaGraph() { + cudaStream_t cuStream = static_cast(capture_stream_); + CUDA_CALL(cudaGraphLaunch(cuda_graph_exec_, cuStream)); + CUDA_CALL(cudaStreamSynchronize(cuStream)); + } + + /*! + * \brief End CUDA graph capture on stream, a graph will be created and + * instantiated. + */ + void EndCapture() { + cudaGraph_t graph; + CUDA_CALL(cudaStreamEndCapture(static_cast(capture_stream_), &graph)); + + cudaGraphNode_t* nodes = NULL; + size_t numNodes = 0; + CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes)); + LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes; + + CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec_, graph, NULL, NULL, 0)); + } + + /*! + * \brief GetFunction Get the function based on input. + * \param name The function which needs to be invoked. + * \param sptr_to_self Packed function pointer. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + private: + /*! \brief The Cuda stream on which to capture a CUDA graph. */ + TVMStreamHandle capture_stream_; + /*! \brief The captured CUDA graph will be instantiated to this. */ + cudaGraphExec_t cuda_graph_exec_; +}; + +PackedFunc GraphRuntimeCudaGraph::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "run_cuda_graph") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); }); + } else if (name == "start_capture") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); }); + } else if (name == "end_capture") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); }); + } else { + return GraphRuntime::GetFunction(name, sptr_to_self); + } +} + +Module GraphRuntimeCudaGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m, + const std::vector& ctxs, + PackedFunc lookup_linked_param_func) { + auto exec = make_object(); + exec->Init(sym_json, m, ctxs, lookup_linked_param_func); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_cuda_graph.create") + .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " << args.num_args; + PackedFunc lookup_linked_param_func; + int ctx_start_arg = 2; + if (args[2].type_code() == kTVMPackedFuncHandle) { + lookup_linked_param_func = args[2]; + ctx_start_arg++; + } + + *rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), + lookup_linked_param_func); +}); +} // namespace runtime +} // namespace tvm \ No newline at end of file From 88812acfb8e85017f3f9501a8337a59d80fa3eef Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:53:32 +0800 Subject: [PATCH 11/17] lint format --- .../cuda_graph/graph_runtime_cuda_graph.cc | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc index 4a743611152b..ee5e50a3b9d4 100644 --- a/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc +++ b/src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc @@ -117,18 +117,19 @@ Module GraphRuntimeCudaGraphCreate(const std::string& sym_json, const tvm::runti } TVM_REGISTER_GLOBAL("tvm.graph_runtime_cuda_graph.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " << args.num_args; - PackedFunc lookup_linked_param_func; - int ctx_start_arg = 2; - if (args[2].type_code() == kTVMPackedFuncHandle) { - lookup_linked_param_func = args[2]; - ctx_start_arg++; - } - - *rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), - lookup_linked_param_func); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + PackedFunc lookup_linked_param_func; + int ctx_start_arg = 2; + if (args[2].type_code() == kTVMPackedFuncHandle) { + lookup_linked_param_func = args[2]; + ctx_start_arg++; + } + + *rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), + lookup_linked_param_func); + }); } // namespace runtime -} // namespace tvm \ No newline at end of file +} // namespace tvm From f128580d420e7b08bf9fc58b3fbdfde462c40694 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:57:08 +0800 Subject: [PATCH 12/17] Update src/runtime/graph/graph_runtime_factory.cc Co-authored-by: Cody Yu --- src/runtime/graph/graph_runtime_factory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index c04347927f6b..1682afa8464a 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -141,7 +141,7 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector& ct Module GraphRuntimeFactory::CudaGraphRuntimeCreate(const std::vector& ctxs) { const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_cuda_graph.create"); ICHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_cuda_graph.create in registry. " - "Do you enable cuda graph runtime build?"; + "Did you set(USE_GRAPH_RUNTIME_CUGRAPH=ON)?"; std::vector unpacked_ctxs; for (const auto& ctx : ctxs) { unpacked_ctxs.emplace_back(ctx.device_type); From 67080c30b7f65a77eebc7f5441773c15919f31e6 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 15:57:24 +0800 Subject: [PATCH 13/17] Update python/tvm/testing.py Co-authored-by: Cody Yu --- python/tvm/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index a4e7e640765b..d81279ea1b77 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -525,7 +525,7 @@ def requires_cudagraph(*args): Function to mark """ _requires_cudagraph = [ - pytest.mark.skipif(not nvcc.have_cudagraph(), reason="CUDA Graph not support"), + pytest.mark.skipif(not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment"), *requires_cuda(), ] return _compose(args, _requires_cudagraph) From ab4f8c3c22c8f234e48416f186d3ce95c47ff0b4 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 16:59:12 +0800 Subject: [PATCH 14/17] fix lint error --- python/tvm/contrib/cuda_graph/__init__.py | 2 +- python/tvm/contrib/cuda_graph/cuda_graph_runtime.py | 2 +- python/tvm/testing.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/cuda_graph/__init__.py b/python/tvm/contrib/cuda_graph/__init__.py index d216be4ddc94..13a83393a912 100644 --- a/python/tvm/contrib/cuda_graph/__init__.py +++ b/python/tvm/contrib/cuda_graph/__init__.py @@ -13,4 +13,4 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. \ No newline at end of file +# under the License. diff --git a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py index 810b16a2d322..ba0ec3558f20 100644 --- a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py +++ b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py @@ -19,7 +19,7 @@ from tvm._ffi.base import string_types from tvm.contrib import graph_runtime -tvm.contrib.graph_runtime.GraphModule + def create(graph_json_str, libmod, ctx): """Create a runtime executor module given a graph and module. diff --git a/python/tvm/testing.py b/python/tvm/testing.py index d81279ea1b77..2fe29701b810 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -525,7 +525,9 @@ def requires_cudagraph(*args): Function to mark """ _requires_cudagraph = [ - pytest.mark.skipif(not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment"), + pytest.mark.skipif( + not nvcc.have_cudagraph(), + reason="CUDA Graph is not supported in this environment"), *requires_cuda(), ] return _compose(args, _requires_cudagraph) From 6ae69c57d51c8490c9aee65329a7fa9a4314278d Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 17:05:50 +0800 Subject: [PATCH 15/17] remove unnecessary warn --- python/tvm/contrib/nvcc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 13744aa2584c..99844f799d7a 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -358,7 +358,6 @@ def have_cudagraph(): return False return True except RuntimeError: - warnings.warn("Cannot find cuda path") return False From eac3ce93e8bebae7fffccffd30d2d8ed25b50d79 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 17:23:22 +0800 Subject: [PATCH 16/17] add test, fix lint --- .../contrib/cuda_graph/cuda_graph_runtime.py | 4 +- python/tvm/testing.py | 4 +- .../unittest/test_runtime_graph_cuda_graph.py | 100 ++++++++++++++++++ 3 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 tests/python/unittest/test_runtime_graph_cuda_graph.py diff --git a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py index ba0ec3558f20..ef07bbbaaf2e 100644 --- a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py +++ b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py @@ -57,8 +57,8 @@ def create(graph_json_str, libmod, ctx): fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create") except ValueError: raise ValueError( - "Please set '(USE_GRAPH_RUNTIME_CUDA_GRAPH ON)' in " - "config.cmake and rebuild TVM to enable CUDA graph support" + "To enable CUDA graph support (experimental), please set " + "'(USE_GRAPH_RUNTIME_CUGRAPH ON)' in config.cmake and rebuild TVM" ) return GraphModuleCudaGraph(fcreate(graph_json_str, libmod, *device_type_id)) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 2fe29701b810..1cb43b29c521 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -526,8 +526,8 @@ def requires_cudagraph(*args): """ _requires_cudagraph = [ pytest.mark.skipif( - not nvcc.have_cudagraph(), - reason="CUDA Graph is not supported in this environment"), + not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment" + ), *requires_cuda(), ] return _compose(args, _requires_cudagraph) diff --git a/tests/python/unittest/test_runtime_graph_cuda_graph.py b/tests/python/unittest/test_runtime_graph_cuda_graph.py new file mode 100644 index 000000000000..4a31873cb93c --- /dev/null +++ b/tests/python/unittest/test_runtime_graph_cuda_graph.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import os +import re +import sys +import time + +import pytest + +import tvm +import tvm.testing +from tvm import te +import numpy as np + +from tvm.contrib import utils, graph_runtime +from tvm.contrib.cuda_graph import cuda_graph_runtime + + +bx = te.thread_axis("blockIdx.x") +tx = te.thread_axis("threadIdx.x") + + +@tvm.testing.requires_cudagraph +def test_graph_simple(): + n = 32 + A = te.placeholder((n,), name="A") + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + s = te.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], factor=8) + s[B].bind(xo, bx) + s[B].bind(xi, tx) + + node0 = {"op": "null", "name": "x", "inputs": []} + node1 = { + "op": "tvm_op", + "name": "add", + "inputs": [[0, 0, 0]], + "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"}, + } + nodes = [node0, node1] + arg_nodes = [0] + node_row_ptr = [0, 1, 2] + outputs = [[1, 0, 0]] + shape = (n,) + attrs = { + "shape": ["list_shape", [shape, shape]], + "dltype": ["list_str", ["float32", "float32"]], + "storage_id": ["list_int", [0, 1]], + } + graph = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "node_row_ptr": node_row_ptr, + "heads": outputs, + "attrs": attrs, + } + graph = json.dumps(graph) + + def check_verify(): + mlib = tvm.build(s, [A, B], "cuda", name="myadd") + ctx = tvm.gpu(0) + try: + mod = cuda_graph_runtime.create(graph, mlib, ctx) + except ValueError: + return + + for i in range(3): + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.run(x=a) # The first run captured a CUDA graph + out = mod.get_output(0, tvm.nd.empty((n,))) + np.testing.assert_equal(out.asnumpy(), a + 1) + + # capture / run CUDA graph manually + mod.capture_cuda_graph() + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.set_input(x=a) + mod.run_cuda_graph() + out = mod.get_output(0, tvm.nd.empty((n,))) + np.testing.assert_equal(out.asnumpy(), a + 1) + + check_verify() + + +if __name__ == "__main__": + test_graph_simple() From 262783136f7f6200de5b33decd42ddf41b8996b7 Mon Sep 17 00:00:00 2001 From: zhuochen Date: Wed, 17 Mar 2021 18:26:18 +0800 Subject: [PATCH 17/17] fix lint W0223 --- python/tvm/contrib/cuda_graph/cuda_graph_runtime.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py index ef07bbbaaf2e..45ec89d37b3d 100644 --- a/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py +++ b/python/tvm/contrib/cuda_graph/cuda_graph_runtime.py @@ -119,3 +119,16 @@ def run(self, **input_dict): self.capture_cuda_graph() else: self._run_cuda_graph() + + def debug_get_output(self, node, out): + """Run graph up to node and get the output to out + + Parameters + ---------- + node : int / str + The node index or name + + out : NDArray + The output array container + """ + raise NotImplementedError("Please use debugger.debug_runtime as graph_runtime instead.")