43 changes: 35 additions & 8 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -37,9 +37,12 @@ namespace tvm {
namespace runtime {
namespace contrib {

TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size,
bool use_implicit_batch, bool use_fp16, int batch_size)
: max_workspace_size_(max_workspace_size),
TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>& data_entry,
size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
int batch_size)
: data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
batch_size_(batch_size) {
@@ -63,7 +66,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_si
#endif
}

void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node) {
auto node_name = node.GetOpName();
auto shapes = node.GetOpShape();
auto dtypes = node.GetOpDataType();
@@ -80,7 +83,8 @@ void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
network_input_names_.push_back(input_tensor->getName());
network_input_names_.push_back(name);
entry_id_map_[name] = entry_id + i;
}
}

@@ -94,14 +98,15 @@ void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) {
node_output_map_[nid] = {TensorRTOpInput(weight, shape)};
}

void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node) {
void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_id) {
auto it = node_output_map_.find(node.id_);
ICHECK(it != node_output_map_.end()) << "Output was not found.";
auto out_tensor = it->second[node.index_].tensor;
std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size());
out_tensor->setName(name.c_str());
network_->markOutput(*out_tensor);
network_output_names_.push_back(out_tensor->getName());
network_output_names_.push_back(name);
entry_id_map_[name] = entry_id;
}

void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
@@ -168,7 +173,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
ICHECK_EQ(engine->getNbBindings(), network_input_names_.size() + network_output_names_.size());
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
CleanUp();
return {engine, context, network_input_names_, network_output_names_};

// Allocate I/O buffers on GPU for TVM inputs which are on a different context.
std::vector<runtime::NDArray> device_buffers(engine->getNbBindings());
for (size_t i = 0; i < network_input_names_.size(); ++i) {
AllocateDeviceBuffer(engine, network_input_names_[i], &device_buffers);
}
for (size_t i = 0; i < network_output_names_.size(); ++i) {
AllocateDeviceBuffer(engine, network_output_names_[i], &device_buffers);
}
return {engine, context, network_input_names_, network_output_names_, device_buffers};
}

nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
@@ -217,6 +231,19 @@ void TensorRTBuilder::CleanUp() {
}
}

void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
std::vector<runtime::NDArray>* device_buffers) {
const uint32_t entry_id = entry_id_map_[name];
if (data_entry_[entry_id]->ctx.device_type != kDLGPU) {
const int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
std::vector<int64_t> shape(data_entry_[entry_id]->shape,
data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
device_buffers->at(binding_index) =
runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLGPU, 0});
}
}

} // namespace contrib
} // namespace runtime
} // namespace tvm
27 changes: 23 additions & 4 deletions src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -25,6 +25,8 @@
#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_

#include <tvm/runtime/ndarray.h>

#include <string>
#include <unordered_map>
#include <vector>
@@ -50,6 +52,8 @@ struct TensorRTEngineAndContext {
nvinfer1::IExecutionContext* context;
std::vector<std::string> inputs;
std::vector<std::string> outputs;
/*! \brief GPU buffers for inputs and outputs. */
std::vector<NDArray> device_buffers;
};

/*!
@@ -69,15 +73,17 @@ class TensorRTBuilder {
* \param use_fp16 Whether to automatically convert the network to FP16 precision.
* \param batch_size The batch size to build the engine with, used when use_implicit_batch is true.
*/
TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, bool use_implicit_batch,
bool use_fp16, int batch_size);
TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
int batch_size);

/*!
* \brief Add TensorRT input(s) for input node in network definition.
* \param nid The input node id.
* \param entry_id The index into data_entry_ for first entry in node.
* \param node The input node.
*/
void AddInput(int nid, const JSONGraphNode& node);
void AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node);

/*!
* \brief Add TensorRT weight for input constant in network definition.
@@ -96,8 +102,9 @@ class TensorRTBuilder {
/*!
* \brief Mark TensorRT output in network definition.
* \param entry The output node entry.
* \param entry_id The output node entry id.
*/
void AddOutput(const JSONGraphNodeEntry& entry);
void AddOutput(const JSONGraphNodeEntry& entry, uint32_t entry_id);

/*!
* \brief Takes network definition and "compiles" a TensorRT engine which can be used for
@@ -116,6 +123,12 @@ class TensorRTBuilder {
/*! \brief Clean up resources used to create engine. */
void CleanUp();

/*! \brief Allocate a GPU buffer for input or output DLTensor, only if the context is not GPU
* already. Inputs that are already on the GPU can be passed directly to TensorRT and will not
* need a buffer. */
void AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
std::vector<runtime::NDArray>* device_buffers);

/*! \brief Maps a node to its outputs. */
std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;

@@ -133,6 +146,12 @@ class TensorRTBuilder {
/*! \brief List of all weights held in memory. */
std::vector<nvinfer1::Weights> trt_weights_;

/*! \brief Input and output tensors from TVM. */
const std::vector<const DLTensor*>& data_entry_;

/*! \brief Map TensorRT binding name to index in data_entry_. */
std::unordered_map<std::string, uint32_t> entry_id_map_;

/*! \brief Max workspace size in bytes for TRT. */
size_t max_workspace_size_;

41 changes: 32 additions & 9 deletions src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -78,8 +78,6 @@ class TensorRTRuntime : public JSONRuntimeBase {
LoadGlobalAttributes();
if (GetCachedEnginesFromDisk()) return;
SetupConstants(consts);
BuildEngine();
CacheEngineToDisk();
}

void LoadGlobalAttributes() {
@@ -106,9 +104,11 @@ class TensorRTRuntime : public JSONRuntimeBase {
#ifdef TVM_GRAPH_RUNTIME_TENSORRT
/*! \brief Run inference using built engine. */
void Run() override {
BuildEngine();
[Review comment from a Contributor]

Is the reason for moving BuildEngine from Init to Run that you need subgraph-specific information (e.g., I/O data entry IDs) to allocate device buffers?

[Reply from the Contributor Author]

Thanks @comaniac for the review! Yes, to allocate the device buffers we need the DLTensor context and shape. data_entry_ in the JSON runtime isn't initialized until Run(), so I had to move BuildEngine.

In the future, we plan to build engines dynamically for different input shapes in order to handle subgraphs with dynamic input sizes, so moving BuildEngine would be needed for that anyway.
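
The pattern discussed in this thread can be illustrated without any TVM or TensorRT dependencies. Below is a minimal, self-contained C++ sketch (names such as FakeEngine and LazyEngineRuntime are illustrative, not TVM APIs): the engine for a subgraph is built lazily on the first Run() call, once input placement is known; staging buffers are created only for inputs that are not already on the GPU; and the result is cached by symbol name so later calls reuse it.

// Minimal sketch of the lazy build-and-cache pattern (plain C++,
// no TVM/TensorRT dependencies; all names are illustrative).
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stands in for TensorRTEngineAndContext: per-subgraph state, including any
// staging buffers allocated for inputs/outputs that are not on the GPU.
struct FakeEngine {
  std::vector<std::string> device_buffers;
};

class LazyEngineRuntime {
 public:
  explicit LazyEngineRuntime(std::string symbol) : symbol_name_(std::move(symbol)) {}

  // Input devices are only known at Run time (not at Init), so this is the
  // earliest point at which the engine and its buffers can be created.
  void Run(const std::vector<std::string>& input_devices) {
    BuildEngine(input_devices);
    const FakeEngine& engine = engine_cache_.at(symbol_name_);
    std::cout << "Running " << symbol_name_ << " with "
              << engine.device_buffers.size() << " staging buffer(s)\n";
  }

 private:
  // Build the engine and cache it; if it is already built, do nothing.
  void BuildEngine(const std::vector<std::string>& input_devices) {
    if (engine_cache_.count(symbol_name_)) return;
    FakeEngine engine;
    for (const std::string& dev : input_devices) {
      // Only inputs that are not already on the GPU need a staging buffer.
      if (dev != "gpu") {
        engine.device_buffers.push_back("staging buffer for " + dev + " input");
      }
    }
    engine_cache_[symbol_name_] = std::move(engine);
  }

  std::string symbol_name_;
  std::unordered_map<std::string, FakeEngine> engine_cache_;
};

int main() {
  LazyEngineRuntime rt("tensorrt_subgraph_0");
  rt.Run({"cpu", "gpu"});  // first call builds and caches the engine
  rt.Run({"cpu", "gpu"});  // later calls reuse the cached engine
  return 0;
}

In the actual change below, the cached entry also holds the nvinfer1 engine/context and NDArray device buffers, and CacheEngineToDisk() is called right after the engine is first built.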

auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
auto engine = engine_and_context.engine;
auto context = engine_and_context.context;
auto& device_buffers = engine_and_context.device_buffers;
std::vector<void*> bindings(engine->getNbBindings(), nullptr);

for (size_t i = 0; i < input_nodes_.size(); ++i) {
@@ -119,7 +119,12 @@ class TensorRTRuntime : public JSONRuntimeBase {
const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
bindings[binding_index] = data_entry_[eid]->data;
if (data_entry_[eid]->ctx.device_type == kDLGPU) {
bindings[binding_index] = data_entry_[eid]->data;
} else {
device_buffers[binding_index].CopyFrom(data_entry_[eid]);
bindings[binding_index] = device_buffers[binding_index]->data;
}
}
}
}
@@ -129,7 +134,11 @@ class TensorRTRuntime : public JSONRuntimeBase {
const std::string& name = engine_and_context.outputs[i];
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
bindings[binding_index] = data_entry_[eid]->data;
if (data_entry_[eid]->ctx.device_type == kDLGPU) {
bindings[binding_index] = data_entry_[eid]->data;
} else {
bindings[binding_index] = device_buffers[binding_index]->data;
}
}

#if TRT_VERSION_GE(6, 0, 1)
@@ -141,26 +150,39 @@ class TensorRTRuntime : public JSONRuntimeBase {
#else
ICHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed.";
#endif

// Copy outputs from GPU buffers if needed.
for (size_t i = 0; i < outputs_.size(); ++i) {
uint32_t eid = EntryID(outputs_[i]);
const std::string& name = engine_and_context.outputs[i];
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
if (data_entry_[eid]->ctx.device_type != kDLGPU) {
device_buffers[binding_index].CopyTo(const_cast<DLTensor*>(data_entry_[eid]));
}
}
}

private:
/*!
* \brief Build TensorRT engine from JSON representation.
* \brief Build TensorRT engine from JSON representation and cache it. If engine is already built,
* do nothing.
*/
void BuildEngine() {
if (trt_engine_cache_.count(symbol_name_)) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
batch_size_ = GetBatchSize();
TensorRTBuilder builder(&logger_, max_workspace_size_, use_implicit_batch_, use_fp16,
batch_size_);
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
use_fp16, batch_size_);

// Add inputs and constants.
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
const auto& node = nodes_[nid];
std::string name = node.GetOpName();
if (node.GetOpType() == "input") {
builder.AddInput(nid, node);
builder.AddInput(nid, EntryID(nid, 0), node);
} else {
ICHECK_EQ(node.GetOpType(), "const");
uint32_t eid = EntryID(nid, 0);
@@ -177,12 +199,13 @@ class TensorRTRuntime : public JSONRuntimeBase {

// Add outputs.
for (size_t i = 0; i < outputs_.size(); ++i) {
builder.AddOutput(outputs_[i]);
builder.AddOutput(outputs_[i], EntryID(outputs_[i]));
}

// Build engine.
trt_engine_cache_[symbol_name_] = builder.BuildEngine();
DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
CacheEngineToDisk();
}

/*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for
29 changes: 24 additions & 5 deletions tests/python/contrib/test_tensorrt.py
@@ -46,7 +46,7 @@ def skip_runtime_test():
return False


def run_and_verify_func(config):
def run_and_verify_func(config, target="cuda"):
"""Test a Relay func by compiling, running, and comparing TVM and TRT outputs.

Parameters
@@ -70,10 +70,11 @@ def run_and_verify_func(config):
mod["main"] = f
mod, config = tensorrt.partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
graph, lib, graph_params = relay.build(mod, "cuda", params=params)
graph, lib, graph_params = relay.build(mod, target, params=params)
if skip_runtime_test():
return
mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
ctx = tvm.context(target)
mod = graph_runtime.create(graph, lib, ctx=ctx)
mod.set_input(**graph_params)
mod.run(**input_dict)
results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
@@ -82,8 +83,8 @@ def run_and_verify_func(config):
mod = tvm.IRModule()
mod["main"] = f
with tvm.transform.PassContext(opt_level=3):
graph, lib, graph_params = relay.build(mod, "cuda", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
graph, lib, graph_params = relay.build(mod, target, params=params)
mod = graph_runtime.create(graph, lib, ctx=ctx)
mod.set_input(**graph_params)
mod.run(**input_dict)
ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
@@ -188,6 +189,23 @@ def test_tensorrt_simple():
results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]


def test_tensorrt_simple_cpu_io():
def get_graph():
dtype = "float32"
x_shape = (1, 3, 2, 2)
y_shape = (1, 3, 1, 1)
z_shape = (1, 1, 1, 1)
x = relay.var("x", shape=(x_shape), dtype=dtype)
y = relay.var("y", shape=(y_shape), dtype=dtype)
z = relay.var("z", shape=(z_shape), dtype=dtype)
w = z * (x + y)
out = relay.nn.relu(w)
f = relay.Function([x, y, z], out)
return f, {"x": x_shape, "y": y_shape, "z": z_shape}, ["y"]

run_and_verify_func(get_graph(), target="llvm")


def test_tensorrt_not_compatible():
if skip_codegen_test():
return
@@ -859,6 +877,7 @@ def test_densenet121():
if __name__ == "__main__":
test_tensorrt_not_compatible()
test_tensorrt_simple()
test_tensorrt_simple_cpu_io()
test_tensorrt_serialize()

# Op tests