diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst index a39d9c8edea7..7950fcfbdbc9 100644 --- a/docs/deploy/tensorrt.rst +++ b/docs/deploy/tensorrt.rst @@ -166,6 +166,14 @@ There are some additional options which can be configured at runtime using envir model can use. It is generally best to use the highest value which does not cause you to run out of memory. You can use ``TVM_TENSORRT_MAX_WORKSPACE_SIZE`` to override this by specifying the workspace size in bytes you would like to use. +* For models which contain a dynamic batch dimension, the varaible ``TVM_TENSORRT_MULTI_ENGINE`` + can be used to determine how TensorRT engines will be created at runtime. The default mode, + ``TVM_TENSORRT_MULTI_ENGINE=0``, will maintain only one engine in memory at a time. If an input + is encountered with a higher batch size, the engine will be rebuilt with the new max_batch_size + setting. That engine will be compatible with all batch sizes from 1 to max_batch_size. This mode + reduces the amount of memory used at runtime. The second mode, ``TVM_TENSORRT_MULTI_ENGINE=1`` + will build a unique TensorRT engine which is optimized for each batch size that is encountered. + This will give greater performance, but will consume more memory. Operator support diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index b8d6f6cd9ff0..d8182b0e8378 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -178,15 +178,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); - // Allocate I/O buffers on GPU for TVM inputs which are on a different context. - std::vector device_buffers(engine->getNbBindings()); - for (size_t i = 0; i < network_input_names_.size(); ++i) { - AllocateDeviceBuffer(engine, network_input_names_[i], &device_buffers); - } - for (size_t i = 0; i < network_output_names_.size(); ++i) { - AllocateDeviceBuffer(engine, network_output_names_[i], &device_buffers); - } - return {engine, context, network_input_names_, network_output_names_, device_buffers}; + return {engine, context, network_input_names_, network_output_names_}; } nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, @@ -245,19 +237,6 @@ void TensorRTBuilder::CleanUp() { } } -void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name, - std::vector* device_buffers) { - const uint32_t entry_id = entry_id_map_[name]; - if (data_entry_[entry_id]->device.device_type != kDLCUDA) { - const int binding_index = engine->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); - std::vector shape(data_entry_[entry_id]->shape, - data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); - device_buffers->at(binding_index) = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); - } -} - } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h index 4926a4d02685..0b1c3997ec57 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -52,8 +52,6 @@ struct TensorRTEngineAndContext { nvinfer1::IExecutionContext* context; std::vector inputs; std::vector outputs; - /*! \brief GPU buffers for inputs and outputs. */ - std::vector device_buffers; }; /*! @@ -123,12 +121,6 @@ class TensorRTBuilder { /*! \brief Clean up resources used to create engine. */ void CleanUp(); - /*! \brief Allocate a GPU buffer for input or output DLTensor, only if the context is not GPU - * already. Inputs that are already on the GPU can be passed directly to TensorRT and will not - * need a buffer. */ - void AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name, - std::vector* device_buffers); - /*! \brief Maps a node to its outputs. */ std::unordered_map> node_output_map_; diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index e96359481ddb..6358e59ce3bc 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -64,7 +64,9 @@ class TensorRTRuntime : public JSONRuntimeBase { const Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), use_implicit_batch_(true), - max_workspace_size_(size_t(1) << 30) {} + max_workspace_size_(size_t(1) << 30), + max_batch_size_(-1), + multi_engine_mode_(false) {} /*! * \brief The type key of the module. @@ -85,6 +87,7 @@ class TensorRTRuntime : public JSONRuntimeBase { LoadGlobalAttributes(); if (GetCachedEnginesFromDisk()) return; SetupConstants(consts); + multi_engine_mode_ = dmlc::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false); } void LoadGlobalAttributes() { @@ -110,23 +113,25 @@ class TensorRTRuntime : public JSONRuntimeBase { #ifdef TVM_GRAPH_EXECUTOR_TENSORRT /*! \brief Destroy engines and contexts. */ - ~TensorRTRuntime() { + void DestroyEngines() { for (auto& it : trt_engine_cache_) { it.second.context->destroy(); it.second.engine->destroy(); } + trt_engine_cache_.clear(); } + ~TensorRTRuntime() { DestroyEngines(); } + /*! \brief Run inference using built engine. */ void Run() override { - BuildEngine(); - batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; - if (batch_size_ == 0) return; - auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_)); + auto& engine_and_context = GetOrBuildEngine(); + int batch_size = GetBatchSize(); + if (batch_size == 0) return; auto engine = engine_and_context.engine; auto context = engine_and_context.context; - auto& device_buffers = engine_and_context.device_buffers; std::vector bindings(engine->getNbBindings(), nullptr); + // Setup input bindings. for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { @@ -138,13 +143,14 @@ class TensorRTRuntime : public JSONRuntimeBase { if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { - device_buffers[binding_index].CopyFrom(data_entry_[eid]); - bindings[binding_index] = device_buffers[binding_index]->data; + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + device_buffer.CopyFrom(data_entry_[eid]); + bindings[binding_index] = device_buffer->data; } } } } - + // Setup output bindings. for (size_t i = 0; i < outputs_.size(); ++i) { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; @@ -153,18 +159,19 @@ class TensorRTRuntime : public JSONRuntimeBase { if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { - bindings[binding_index] = device_buffers[binding_index]->data; + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + bindings[binding_index] = device_buffer->data; } } #if TRT_VERSION_GE(6, 0, 1) if (use_implicit_batch_) { - ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; } else { ICHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; } #else - ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; #endif // Copy outputs from GPU buffers if needed. @@ -174,25 +181,58 @@ class TensorRTRuntime : public JSONRuntimeBase { int binding_index = engine->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type != kDLCUDA) { - device_buffers[binding_index].CopyTo(const_cast(data_entry_[eid])); + auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + device_buffer.CopyTo(const_cast(data_entry_[eid])); } } } private: + /*! \brief Get batch size for engine from the runtime input shapes. */ + int GetBatchSize() { + return data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0]; + } + + /*! \brief Find an engine in the cache which we can reuse depending on the mode. If no compatible + * engine exists, return false to indicate that a new one should be built. */ + bool FindCompatibleEngine(int batch_size, int* compatible_engine_batch_size) { + if (multi_engine_mode_) { + // Exact match is required for multi engine mode. + if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size))) { + *compatible_engine_batch_size = batch_size; + return true; + } + return false; + } + // Check for engine with compatible max_batch_size. + if (batch_size <= max_batch_size_) { + *compatible_engine_batch_size = max_batch_size_; + return true; + } + return false; + } + /*! - * \brief Build TensorRT engine from JSON representation and cache it. If engine is already built, - * do nothing. + * \brief Build TensorRT engine from JSON representation and cache it. If compatible engine is + * already built, do nothing. */ - void BuildEngine() { - batch_size_ = - data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0]; - if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return; + TensorRTEngineAndContext& GetOrBuildEngine() { + int batch_size = GetBatchSize(); + int compatible_engine_batch_size = -1; + if (FindCompatibleEngine(batch_size, &compatible_engine_batch_size)) { + // A compatible engine already exists. + return trt_engine_cache_.at(std::make_pair(symbol_name_, compatible_engine_batch_size)); + } + // For single engine mode, remove previous engine and update max_batch_size. + if (!multi_engine_mode_) { + DestroyEngines(); + max_batch_size_ = batch_size; + } DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_ - << " with batch size " << batch_size_; + << " with batch size " << batch_size; const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_, - use_fp16, batch_size_); + use_fp16, batch_size); // Add inputs and constants. for (size_t i = 0; i < input_nodes_.size(); ++i) { @@ -221,10 +261,11 @@ class TensorRTRuntime : public JSONRuntimeBase { } // Build engine. - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine(); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = builder.BuildEngine(); DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_ - << " with batch size " << batch_size_; + << " with batch size " << batch_size; CacheEngineToDisk(); + return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size)); } /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for @@ -268,7 +309,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * directory so it can be loaded later. */ void CacheEngineToDisk() { - batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; + int batch_size = GetBatchSize(); std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); if (cache_dir.empty()) return; std::string key = GetSubgraphKey(); @@ -276,7 +317,7 @@ class TensorRTRuntime : public JSONRuntimeBase { DLOG(INFO) << "Caching TensorRT engine to " << path; // Serialize engine to disk nvinfer1::IHostMemory* serialized_engine = - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize(); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].engine->serialize(); SaveBinaryToFile(path, std::string(static_cast(serialized_engine->data()), serialized_engine->size())); serialized_engine->destroy(); @@ -285,9 +326,9 @@ class TensorRTRuntime : public JSONRuntimeBase { dmlc::JSONWriter writer(&os); writer.BeginObject(); writer.WriteObjectKeyValue("inputs", - trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs); - writer.WriteObjectKeyValue( - "outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs); + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].inputs); + writer.WriteObjectKeyValue("outputs", + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].outputs); writer.EndObject(); std::string meta_path = cache_dir + "/" + key + ".meta"; SaveBinaryToFile(meta_path, os.str()); @@ -300,29 +341,41 @@ class TensorRTRuntime : public JSONRuntimeBase { return symbol_name_ + (dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false) ? "_fp16" : "_fp32"); } - /*! \brief Get the batch size when in implicit_batch mode. */ - int GetBatchSize() { - if (!use_implicit_batch_) return -1; - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto nid = input_nodes_[i]; - if (nodes_[nid].GetOpType() == "input") { - // Get batch size from first input. - return nodes_[nid].GetOpShape()[0][0]; + /*! \brief Retreive a GPU buffer for input or output or allocate if needed. */ + NDArray GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { + std::vector shape(data_entry_[entry_id]->shape, + data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); + if (device_buffers_.count(binding_index)) { + // Buffer is already initialized. + if (shape[0] > device_buffers_[binding_index]->shape[0]) { + // Buffer is too small. Need to allocate bigger buffer. + device_buffers_[binding_index] = + runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { + // Buffer is too large. Create view. + return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); } + } else { + // Buffer not initialized yet. + device_buffers_[binding_index] = + runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } - return -1; + return device_buffers_.at(binding_index); } - /*! \brief Map of function name to TRT engine if built already. */ + /*! \brief Map of function name and max batch size to TRT engine if built already. */ std::unordered_map, TensorRTEngineAndContext, PairHash> trt_engine_cache_; + /*! \brief Map of inding index to GPU buffers for inputs and outputs. Only used when target device + * is not "cuda". Since TensorRT execution can only read data from GPU, we need to copy data from + * the runtime device to these buffers first. These will be allocated for the highest batch size + * used by all engines. */ + std::unordered_map device_buffers_; + /*! \brief TensorRT logger. */ TensorRTLogger logger_; - /*! \brief Batch size that the engine is optimized for. */ - int batch_size_; - #else void Run() override { LOG(FATAL) << "TensorRT runtime is not enabled. " @@ -342,6 +395,17 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_implicit_batch_; size_t max_workspace_size_; + + /*! \brief Highest batch size that an engine has been built for, used in single-engine mode only + * (multi_engine_mode=false). */ + int max_batch_size_; + + /*! \brief The strategy to use for dynamic batching. With multi_engine_mode=true, a new TensorRT + * engine is created for each unique batch size encountered. With multi_engine_mode=false, only + * one TensorRT engine is alive at any given time. It is replaced if a higher batch size is + * encountered. Multi-engine mode should give better performance, at a cost of higher memory usage + * and more time spent building engines. */ + bool multi_engine_mode_; }; runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index f9912c9674e5..b54da208b33d 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1246,12 +1246,12 @@ def test_tensorrt_dynamic_batch(): def test_tensorrt_dynamic_batch_conv(): if skip_codegen_test(): return - batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2] + batches_to_test = [1, 5, 1, 0, 2, 3, 0, 1, 3, 2] x_shape = (relay.Any(), 32, 8, 8) x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") k_shape = (16, 32, 3, 3) params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")} - result_arr = [{} for _ in range(len(batches_to_test))] + result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] for use_trt in [True, False]: x = relay.var("x", shape=x_shape, dtype="float32") kernel = relay.var("kernel", shape=k_shape, dtype="float32") @@ -1263,15 +1263,21 @@ def test_tensorrt_dynamic_batch_conv(): mod, _ = tensorrt.partition_for_tensorrt(mod, params) if not skip_runtime_test(): - with relay.build_config(opt_level=3): - relay_exec = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm") + for target in ["llvm", "cuda"]: + with relay.build_config(opt_level=3): + relay_exec = relay.create_executor( + "vm", mod=mod, device=tvm.cpu(0), target="llvm" + ) - for i, batch_size in enumerate(batches_to_test): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...], **params) + for i, batch_size in enumerate(batches_to_test): + result_arr[i][target][use_trt] = relay_exec.evaluate()( + x_data[:batch_size, ...], **params + ) if not skip_runtime_test(): for i in range(len(batches_to_test)): - assert_result_dict_holds(result_arr[i]) + for target in ["llvm", "cuda"]: + assert_result_dict_holds(result_arr[i][target]) def test_maskrcnn_resnet50() -> None: