diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h
index 77c289b04c6d..d57eeb08df10 100644
--- a/src/runtime/contrib/json/json_node.h
+++ b/src/runtime/contrib/json/json_node.h
@@ -256,7 +256,7 @@ class JSONGraphNode {
    */
   template <typename T>
   T GetAttr(const std::string& key) const {
-    ICHECK_GT(attrs_.count(key), 0U) << "Key: " << key << "is not found";
+    ICHECK_GT(attrs_.count(key), 0U) << "Key: " << key << " is not found";
     return dmlc::get<T>(attrs_.at(key));
   }
 
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index ee47e67001f3..95c6bd18ef84 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -171,6 +171,11 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   CleanUp();
 
   // Allocate I/O buffers on GPU for TVM inputs which are on a different context.
+  std::vector<runtime::NDArray> device_buffers = CreateDeviceBuffers(engine);
+  return {engine, context, network_input_names_, network_output_names_, device_buffers};
+}
+
+std::vector<runtime::NDArray> TensorRTBuilder::CreateDeviceBuffers(nvinfer1::ICudaEngine* engine) {
   std::vector<runtime::NDArray> device_buffers(engine->getNbBindings());
   for (size_t i = 0; i < network_input_names_.size(); ++i) {
     AllocateDeviceBuffer(engine, network_input_names_[i], &device_buffers);
@@ -178,7 +183,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   for (size_t i = 0; i < network_output_names_.size(); ++i) {
     AllocateDeviceBuffer(engine, network_output_names_[i], &device_buffers);
   }
-  return {engine, context, network_input_names_, network_output_names_, device_buffers};
+  return device_buffers;
 }
 
 nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index 4926a4d02685..fcae52f7e994 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -113,6 +113,12 @@ class TensorRTBuilder {
    */
  TensorRTEngineAndContext BuildEngine();
 
+  /*!
+   * \brief Create device buffers for the engine's I/O bindings.
+   * \param engine The TensorRT engine to allocate buffers for.
+   */
+  std::vector<runtime::NDArray> CreateDeviceBuffers(nvinfer1::ICudaEngine* engine);
+
  private:
   /*! \brief Convert a DLTensor to a TensorRT weight. */
  nvinfer1::Weights GetDLTensorAsWeights(const DLTensor* dptr, DLDeviceType src_device);
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 3f87f8d00ee6..79c01e018d82 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -83,8 +83,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
     ICHECK_EQ(consts.size(), const_idx_.size())
         << "The number of input constants must match the number of required.";
     LoadGlobalAttributes();
-    if (GetCachedEnginesFromDisk()) return;
     SetupConstants(consts);
+    GetCachedEnginesFromDisk();
   }
 
   void LoadGlobalAttributes() {
@@ -178,7 +178,13 @@ class TensorRTRuntime : public JSONRuntimeBase {
    */
   void BuildEngine() {
     batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
-    if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
+    if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) {
+      TensorRTEngineAndContext& engine_and_context =
+          trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_));
+      if (!engine_and_context.device_buffers.empty()) {
+        return;
+      }
+    }
     DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
                << " with batch size " << batch_size_;
     const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
@@ -211,6 +217,16 @@ class TensorRTRuntime : public JSONRuntimeBase {
       builder.AddOutput(outputs_[i], EntryID(outputs_[i]));
     }
 
+    // Allocate device buffers for an engine that was restored from the disk cache.
+    if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) {
+      TensorRTEngineAndContext& engine_and_context =
+          trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_));
+      if (engine_and_context.device_buffers.empty()) {
+        engine_and_context.device_buffers = builder.CreateDeviceBuffers(engine_and_context.engine);
+        return;
+      }
+    }
+
     // Build engine.
     trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine();
     DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
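
For context: the patch splits device-buffer allocation out of TensorRTBuilder::BuildEngine so that an engine restored by GetCachedEnginesFromDisk, which arrives without device buffers, gets them allocated lazily on the first BuildEngine call instead of being rebuilt. Below is a minimal standalone sketch of that control flow; Engine, Buffer, and the cache types are illustrative stand-ins, not the real TensorRT or TVM APIs.

    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Stand-in types; the real code uses nvinfer1::ICudaEngine and runtime::NDArray.
    struct Engine { std::string name; };
    struct Buffer { int binding; };

    struct EngineAndContext {
      Engine engine;
      std::vector<Buffer> device_buffers;  // empty when restored from the disk cache
    };

    using CacheKey = std::pair<std::string, int>;  // (symbol_name, batch_size)
    static std::map<CacheKey, EngineAndContext> engine_cache;

    // Stand-in for TensorRTBuilder::CreateDeviceBuffers: one buffer per I/O binding.
    static std::vector<Buffer> CreateDeviceBuffers(const Engine&, int num_bindings) {
      std::vector<Buffer> buffers;
      for (int i = 0; i < num_bindings; ++i) buffers.push_back(Buffer{i});
      return buffers;
    }

    // Mirrors the patched BuildEngine: a cached engine that is missing its
    // buffers is completed in place rather than rebuilt from scratch.
    static void BuildEngine(const std::string& symbol_name, int batch_size) {
      const CacheKey key{symbol_name, batch_size};
      auto it = engine_cache.find(key);
      if (it != engine_cache.end()) {
        if (!it->second.device_buffers.empty()) return;  // fully initialized
        it->second.device_buffers = CreateDeviceBuffers(it->second.engine, 2);
        return;  // buffers filled in for the disk-cached engine
      }
      // No cached engine: build one and allocate its buffers in the same step.
      EngineAndContext entry{Engine{symbol_name}, {}};
      entry.device_buffers = CreateDeviceBuffers(entry.engine, 2);
      engine_cache[key] = entry;
    }

    int main() {
      // Simulate a disk-cache hit: engine present, no device buffers yet.
      engine_cache[{"subgraph_0", 1}] = EngineAndContext{Engine{"subgraph_0"}, {}};
      BuildEngine("subgraph_0", 1);  // first call allocates the missing buffers
      std::cout << engine_cache[{"subgraph_0", 1}].device_buffers.size()
                << " device buffers allocated\n";
    }

In the real runtime the first branch corresponds to an entry loaded from disk, whose device_buffers field stays empty until the first inference; this is also why Init now calls GetCachedEnginesFromDisk unconditionally rather than returning early.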