From 44eb222be36dd1cccc4021782d4679529577727f Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Thu, 5 Nov 2020 00:04:22 +0000
Subject: [PATCH 1/3] Allocate data buffers for gpu fix

---
 .../contrib/tensorrt/tensorrt_builder.cc | 44 +++++++++++++++----
 .../contrib/tensorrt/tensorrt_builder.h  | 25 +++++++++--
 .../contrib/tensorrt/tensorrt_runtime.cc | 38 ++++++++++++----
 tests/python/contrib/test_tensorrt.py    | 29 +++++++++---
 4 files changed, 111 insertions(+), 25 deletions(-)

diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index d308200eba05..9425cad84f2d 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -37,9 +37,12 @@ namespace tvm {
 namespace runtime {
 namespace contrib {
 
-TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size,
-                                 bool use_implicit_batch, bool use_fp16, int batch_size)
-    : max_workspace_size_(max_workspace_size),
+TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
+                                 const std::vector<const DLTensor*>& data_entry,
+                                 size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
+                                 int batch_size)
+    : data_entry_(data_entry),
+      max_workspace_size_(max_workspace_size),
       use_implicit_batch_(use_implicit_batch),
       use_fp16_(use_fp16),
       batch_size_(batch_size) {
@@ -63,7 +66,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_si
 #endif
 }
 
-void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
+void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node) {
   auto node_name = node.GetOpName();
   auto shapes = node.GetOpShape();
   auto dtypes = node.GetOpDataType();
@@ -80,7 +83,8 @@ void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
     ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
     auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
     node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
-    network_input_names_.push_back(input_tensor->getName());
+    network_input_names_.push_back(name);
+    entry_id_map_[name] = entry_id + i;
   }
 }
 
@@ -94,14 +98,15 @@ void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) {
   node_output_map_[nid] = {TensorRTOpInput(weight, shape)};
 }
 
-void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node) {
+void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_id) {
   auto it = node_output_map_.find(node.id_);
   ICHECK(it != node_output_map_.end()) << "Output was not found.";
   auto out_tensor = it->second[node.index_].tensor;
   std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size());
   out_tensor->setName(name.c_str());
   network_->markOutput(*out_tensor);
-  network_output_names_.push_back(out_tensor->getName());
+  network_output_names_.push_back(name);
+  entry_id_map_[name] = entry_id;
 }
 
 void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
@@ -168,7 +173,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   ICHECK_EQ(engine->getNbBindings(), network_input_names_.size() + network_output_names_.size());
   nvinfer1::IExecutionContext* context = engine->createExecutionContext();
   CleanUp();
-  return {engine, context, network_input_names_, network_output_names_};
+
+  // Allocate I/O buffers on GPU for TVM inputs which are on a different context.
+  std::vector<runtime::NDArray> device_buffers(engine->getNbBindings());
+  for (size_t i = 0; i < network_input_names_.size(); ++i) {
+    AllocateDeviceBufferIfNeeded(engine, network_input_names_[i], &device_buffers);
+  }
+  for (size_t i = 0; i < network_output_names_.size(); ++i) {
+    AllocateDeviceBufferIfNeeded(engine, network_output_names_[i], &device_buffers);
+  }
+  return {engine, context, network_input_names_, network_output_names_, device_buffers};
 }
 
 nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
@@ -217,6 +231,20 @@ void TensorRTBuilder::CleanUp() {
   }
 }
 
+void TensorRTBuilder::AllocateDeviceBufferIfNeeded(nvinfer1::ICudaEngine* engine,
+                                                   const std::string& name,
+                                                   std::vector<runtime::NDArray>* device_buffers) {
+  const uint32_t entry_id = entry_id_map_[name];
+  if (data_entry_[entry_id]->ctx.device_type != kDLGPU) {
+    const int binding_index = engine->getBindingIndex(name.c_str());
+    ICHECK_NE(binding_index, -1);
+    std::vector<int64_t> shape(data_entry_[entry_id]->shape,
+                               data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
+    device_buffers->at(binding_index) =
+        runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLGPU, 0});
+  }
+}
+
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index efb4d8175650..30f01e73cfd0 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -25,6 +25,8 @@
 #ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
 #define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
 
+#include <tvm/runtime/ndarray.h>
+
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -50,6 +52,8 @@ struct TensorRTEngineAndContext {
   nvinfer1::IExecutionContext* context;
   std::vector<std::string> inputs;
   std::vector<std::string> outputs;
+  /*! \brief GPU buffers for inputs and outputs. */
+  std::vector<runtime::NDArray> device_buffers;
 };
 
 /*!
@@ -69,15 +73,17 @@ class TensorRTBuilder {
    * \param use_fp16 Whether to use implicit batch mode (default)
    * \param batch_size If use_implicit_batch,
   */
-  TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, bool use_implicit_batch,
-                  bool use_fp16, int batch_size);
+  TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
+                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
+                  int batch_size);
 
   /*!
    * \brief Add TensorRT input(s) for input node in network definition.
    * \param nid The input node id.
+   * \param entry_id The index into data_entry_ for first entry in node.
    * \param node The input node.
    */
-  void AddInput(int nid, const JSONGraphNode& node);
+  void AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node);
 
   /*!
   * \brief Add TensorRT weight for input constant in network definition.
@@ -96,8 +102,9 @@ class TensorRTBuilder {
   /*!
   * \brief Mark TensorRT output in network definition.
   * \param entry The output node entry.
+  * \param entry_id The output node entry id.
   */
-  void AddOutput(const JSONGraphNodeEntry& entry);
+  void AddOutput(const JSONGraphNodeEntry& entry, uint32_t entry_id);
 
   /*!
   * \brief Takes network definition and "compiles" a TensorRT engine which can be used for
@@ -116,6 +123,10 @@ class TensorRTBuilder {
   /*! \brief Clean up resources used to create engine. */
   void CleanUp();
 
+  /*! \brief If the input DLTensor is not on the GPU, allocate a buffer for it. */
+  void AllocateDeviceBufferIfNeeded(nvinfer1::ICudaEngine* engine, const std::string& name,
+                                    std::vector<runtime::NDArray>* device_buffers);
+
   /*! \brief Maps a node to its outputs. */
   std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;
 
@@ -133,6 +144,12 @@ class TensorRTBuilder {
   /*! \brief List of all weights held in memory. */
   std::vector<nvinfer1::Weights> trt_weights_;
 
+  /*! \brief Input and output tensors from TVM. */
+  const std::vector<const DLTensor*>& data_entry_;
+
+  /*! \brief Map TensorRT binding name to index in data_entry_. */
+  std::unordered_map<std::string, uint32_t> entry_id_map_;
+
   /*! \brief Max workspace size in bytes for TRT. */
   size_t max_workspace_size_;
 
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index f183e2f24449..e7d16c0c4b5c 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -78,8 +78,6 @@ class TensorRTRuntime : public JSONRuntimeBase {
     LoadGlobalAttributes();
     if (GetCachedEnginesFromDisk()) return;
     SetupConstants(consts);
-    BuildEngine();
-    CacheEngineToDisk();
   }
 
   void LoadGlobalAttributes() {
@@ -106,9 +104,11 @@ class TensorRTRuntime : public JSONRuntimeBase {
 #ifdef TVM_GRAPH_RUNTIME_TENSORRT
   /*! \brief Run inference using built engine. */
   void Run() override {
+    BuildEngine();
     auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
     auto engine = engine_and_context.engine;
    auto context = engine_and_context.context;
+    auto& device_buffers = engine_and_context.device_buffers;
     std::vector<void*> bindings(engine->getNbBindings(), nullptr);
 
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
@@ -119,7 +119,12 @@ class TensorRTRuntime : public JSONRuntimeBase {
           const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
           int binding_index = engine->getBindingIndex(name.c_str());
           ICHECK_NE(binding_index, -1);
-          bindings[binding_index] = data_entry_[eid]->data;
+          if (data_entry_[eid]->ctx.device_type == kDLGPU) {
+            bindings[binding_index] = data_entry_[eid]->data;
+          } else {
+            device_buffers[binding_index].CopyFrom(data_entry_[eid]);
+            bindings[binding_index] = reinterpret_cast<void*>(device_buffers[binding_index]->data);
+          }
         }
       }
     }
@@ -129,7 +134,11 @@ class TensorRTRuntime : public JSONRuntimeBase {
       const std::string& name = engine_and_context.outputs[i];
       int binding_index = engine->getBindingIndex(name.c_str());
       ICHECK_NE(binding_index, -1);
-      bindings[binding_index] = data_entry_[eid]->data;
+      if (data_entry_[eid]->ctx.device_type == kDLGPU) {
+        bindings[binding_index] = reinterpret_cast<void*>(data_entry_[eid]->data);
+      } else {
+        bindings[binding_index] = reinterpret_cast<void*>(device_buffers[binding_index]->data);
+      }
     }
 
 #if TRT_VERSION_GE(6, 0, 1)
@@ -141,6 +150,17 @@ class TensorRTRuntime : public JSONRuntimeBase {
 #else
     ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed.";
 #endif
+
+    // Copy outputs from GPU buffers if needed.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      uint32_t eid = EntryID(outputs_[i]);
+      const std::string& name = engine_and_context.outputs[i];
+      int binding_index = engine->getBindingIndex(name.c_str());
+      ICHECK_NE(binding_index, -1);
+      if (data_entry_[eid]->ctx.device_type != kDLGPU) {
+        device_buffers[binding_index].CopyTo(const_cast<DLTensor*>(data_entry_[eid]));
+      }
+    }
   }
 
  private:
   /*!
   * \brief Build TensorRT engine from JSON representation.
   */
  void BuildEngine() {
+    if (trt_engine_cache_.count(symbol_name_)) return;
    DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
    const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
    batch_size_ = GetBatchSize();
-    TensorRTBuilder builder(&logger_, max_workspace_size_, use_implicit_batch_, use_fp16,
-                            batch_size_);
+    TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
+                            use_fp16, batch_size_);
 
    // Add inputs and constants.
    for (size_t i = 0; i < input_nodes_.size(); ++i) {
@@ -160,7 +181,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
      const auto& node = nodes_[nid];
      std::string name = node.GetOpName();
      if (node.GetOpType() == "input") {
-        builder.AddInput(nid, node);
+        builder.AddInput(nid, EntryID(nid, 0), node);
      } else {
        ICHECK_EQ(node.GetOpType(), "const");
        uint32_t eid = EntryID(nid, 0);
@@ -177,12 +198,13 @@ class TensorRTRuntime : public JSONRuntimeBase {
 
    // Add outputs.
    for (size_t i = 0; i < outputs_.size(); ++i) {
-      builder.AddOutput(outputs_[i]);
+      builder.AddOutput(outputs_[i], EntryID(outputs_[i]));
    }
 
    // Build engine.
    trt_engine_cache_[symbol_name_] = builder.BuildEngine();
    DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
+    CacheEngineToDisk();
  }
 
  /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 9faf51f397f3..8e8e54e8650a 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -46,7 +46,7 @@ def skip_runtime_test():
     return False
 
 
-def run_and_verify_func(config):
+def run_and_verify_func(config, target="cuda"):
     """Test a Relay func by compiling, running, and comparing TVM and TRT outputs.
 
     Parameters
@@ -70,10 +70,11 @@ def run_and_verify_func(config):
     mod["main"] = f
     mod, config = tensorrt.partition_for_tensorrt(mod, params)
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        graph, lib, graph_params = relay.build(mod, target, params=params)
     if skip_runtime_test():
         return
-    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    ctx = tvm.context(target)
+    mod = graph_runtime.create(graph, lib, ctx=ctx)
     mod.set_input(**graph_params)
     mod.run(**input_dict)
     results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
@@ -82,8 +83,8 @@ def run_and_verify_func(config):
     mod = tvm.IRModule()
     mod["main"] = f
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, graph_params = relay.build(mod, "cuda", params=params)
-    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        graph, lib, graph_params = relay.build(mod, target, params=params)
+    mod = graph_runtime.create(graph, lib, ctx=ctx)
     mod.set_input(**graph_params)
     mod.run(**input_dict)
     ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
@@ -188,6 +189,23 @@ def test_tensorrt_simple():
     results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
 
 
+def test_tensorrt_simple_cpu_io():
+    def get_graph():
+        dtype = "float32"
+        x_shape = (1, 3, 2, 2)
+        y_shape = (1, 3, 1, 1)
+        z_shape = (1, 1, 1, 1)
+        x = relay.var("x", shape=(x_shape), dtype=dtype)
+        y = relay.var("y", shape=(y_shape), dtype=dtype)
+        z = relay.var("z", shape=(z_shape), dtype=dtype)
+        w = z * (x + y)
+        out = relay.nn.relu(w)
+        f = relay.Function([x, y, z], out)
+        return f, {"x": x_shape, "y": y_shape, "z": z_shape}, ["y"]
+
+    run_and_verify_func(get_graph(), target="llvm")
+
+
 def test_tensorrt_not_compatible():
     if skip_codegen_test():
         return
@@ -859,6 +877,7 @@ def test_densenet121():
 if __name__ == "__main__":
     test_tensorrt_not_compatible()
     test_tensorrt_simple()
+    test_tensorrt_simple_cpu_io()
     test_tensorrt_serialize()
 
     # Op tests

From 5a02216002455bee1896d20e99b78594c8ad6744 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Sat, 7 Nov 2020 16:53:24 +0000
Subject: [PATCH 2/3] Rename AllocateDeviceBuffer, update docstrings

---
 src/runtime/contrib/tensorrt/tensorrt_builder.cc | 9 ++++-----
 src/runtime/contrib/tensorrt/tensorrt_builder.h  | 8 +++++---
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 3 ++-
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 9425cad84f2d..4060b240cf8e 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -177,10 +177,10 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
 
   // Allocate I/O buffers on GPU for TVM inputs which are on a different context.
   std::vector<runtime::NDArray> device_buffers(engine->getNbBindings());
   for (size_t i = 0; i < network_input_names_.size(); ++i) {
-    AllocateDeviceBufferIfNeeded(engine, network_input_names_[i], &device_buffers);
+    AllocateDeviceBuffer(engine, network_input_names_[i], &device_buffers);
   }
   for (size_t i = 0; i < network_output_names_.size(); ++i) {
-    AllocateDeviceBufferIfNeeded(engine, network_output_names_[i], &device_buffers);
+    AllocateDeviceBuffer(engine, network_output_names_[i], &device_buffers);
   }
   return {engine, context, network_input_names_, network_output_names_, device_buffers};
 }
@@ -231,9 +231,8 @@ void TensorRTBuilder::CleanUp() {
   }
 }
 
-void TensorRTBuilder::AllocateDeviceBufferIfNeeded(nvinfer1::ICudaEngine* engine,
-                                                   const std::string& name,
-                                                   std::vector<runtime::NDArray>* device_buffers) {
+void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
+                                           std::vector<runtime::NDArray>* device_buffers) {
   const uint32_t entry_id = entry_id_map_[name];
   if (data_entry_[entry_id]->ctx.device_type != kDLGPU) {
     const int binding_index = engine->getBindingIndex(name.c_str());
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index 30f01e73cfd0..4926a4d02685 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -123,9 +123,11 @@ class TensorRTBuilder {
   /*! \brief Clean up resources used to create engine. */
   void CleanUp();
 
-  /*! \brief If the input DLTensor is not on the GPU, allocate a buffer for it. */
-  void AllocateDeviceBufferIfNeeded(nvinfer1::ICudaEngine* engine, const std::string& name,
-                                    std::vector<runtime::NDArray>* device_buffers);
+  /*! \brief Allocate a GPU buffer for input or output DLTensor, only if the context is not GPU
+   * already. Inputs that are already on the GPU can be passed directly to TensorRT and will not
+   * need a buffer. */
+  void AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
+                            std::vector<runtime::NDArray>* device_buffers);
 
   /*! \brief Maps a node to its outputs. */
   std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index e7d16c0c4b5c..cfc2a3dde3a5 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -165,7 +165,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
 
  private:
  /*!
-   * \brief Build TensorRT engine from JSON representation.
+   * \brief Build TensorRT engine from JSON representation and cache it. If engine is already built,
+   * do nothing.
   */
  void BuildEngine() {
    if (trt_engine_cache_.count(symbol_name_)) return;

From 0f630c7a72e2f5079c0db57b79720976dde101c0 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Sat, 7 Nov 2020 09:47:09 -0800
Subject: [PATCH 3/3] Remove unneeded cast

---
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index cfc2a3dde3a5..445010321668 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -123,7 +123,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
            bindings[binding_index] = data_entry_[eid]->data;
          } else {
            device_buffers[binding_index].CopyFrom(data_entry_[eid]);
-            bindings[binding_index] = reinterpret_cast<void*>(device_buffers[binding_index]->data);
+            bindings[binding_index] = device_buffers[binding_index]->data;
          }
        }
      }
@@ -135,9 +135,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
      int binding_index = engine->getBindingIndex(name.c_str());
      ICHECK_NE(binding_index, -1);
      if (data_entry_[eid]->ctx.device_type == kDLGPU) {
-        bindings[binding_index] = reinterpret_cast<void*>(data_entry_[eid]->data);
+        bindings[binding_index] = data_entry_[eid]->data;
      } else {
-        bindings[binding_index] = reinterpret_cast<void*>(device_buffers[binding_index]->data);
+        bindings[binding_index] = device_buffers[binding_index]->data;
      }
    }
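
Taken together, the series lets the TensorRT BYOC runtime work when the module's I/O tensors do not live on the GPU (e.g. target="llvm" with CPU-resident inputs and outputs, as exercised by test_tensorrt_simple_cpu_io): BuildEngine() allocates one GPU-resident NDArray per engine binding whose TVM tensor is on another context, Run() copies such inputs into their staging buffer before execution and copies staged outputs back afterwards, while tensors already on the GPU keep being bound zero-copy. The sketch below is a minimal condensation of that staging pattern, assuming the TVM runtime headers used by the files above; BindInput and UnbindOutput are illustrative helper names that do not exist in the patch.

    // Sketch only: condensed staging logic, assuming <tvm/runtime/ndarray.h> is available.
    // BindInput/UnbindOutput are hypothetical helpers, not functions added by this series.
    #include <tvm/runtime/ndarray.h>

    using tvm::runtime::NDArray;

    // Pick the device pointer TensorRT should see for one binding. A tensor already on the
    // GPU is bound zero-copy; otherwise its contents are copied into the staging NDArray
    // that AllocateDeviceBuffer() created on {kDLGPU, 0} at engine build time.
    void* BindInput(const DLTensor* entry, NDArray* staging) {
      if (entry->ctx.device_type == kDLGPU) {
        return entry->data;
      }
      staging->CopyFrom(entry);  // host (or other context) -> GPU staging buffer
      return (*staging)->data;   // no cast needed: DLTensor::data is already void* (see PATCH 3/3)
    }

    // After execution, copy a staged output back to the TVM tensor if it is not on the GPU.
    void UnbindOutput(const NDArray& staging, DLTensor* entry) {
      if (entry->ctx.device_type != kDLGPU) {
        staging.CopyTo(entry);   // GPU staging buffer -> original context
      }
    }

In Run() this corresponds to binding each input with bindings[binding_index] = BindInput(data_entry_[eid], &device_buffers[binding_index]) before execution and calling UnbindOutput(device_buffers[binding_index], const_cast<DLTensor*>(data_entry_[eid])) for each output afterwards; allocating the staging buffers once per engine in BuildEngine() keeps the per-inference overhead down to the unavoidable copies.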