diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc
index bc73a5988377..dbd072a68fb5 100644
--- a/src/runtime/graph_executor/graph_executor.cc
+++ b/src/runtime/graph_executor/graph_executor.cc
@@ -91,6 +91,11 @@ void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module mod
     std::string& name = nodes_[nid].name;
     input_map_[name] = i;
   }
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    const uint32_t nid = outputs_[i].node_id;
+    std::string& name = nodes_[nid].name;
+    output_map_[name] = i;
+  }
 }
 /*!
  * \brief Get the input index given the name of input.
  * \param name The name of the input.
  * \return The index of input.
  */
@@ -104,6 +109,18 @@ int GraphExecutor::GetInputIndex(const std::string& name) {
   }
   return -1;
 }
+/*!
+ * \brief Get the output index given the name of output.
+ * \param name The name of the output.
+ * \return The index of output.
+ */
+int GraphExecutor::GetOutputIndex(const std::string& name) {
+  auto it = output_map_.find(name);
+  if (it != output_map_.end()) {
+    return it->second;
+  }
+  return -1;
+}
 /*!
  * \brief set index-th input to the graph.
  * \param index The input index.
@@ -114,6 +131,23 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) {
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
   data_entry_[eid].CopyFrom(data_in);
 }
+/*!
+ * \brief Check the legality of external DLTensor*.
+ * \param external The external DLTensor*.
+ * \param eid The data_entry_ index.
+ */
+void GraphExecutor::CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const {
+  const DLTensor* internal = data_entry_[eid].operator->();
+
+  ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*external));
+  ICHECK_EQ(reinterpret_cast<size_t>(external->data) % kAllocAlignment, 0);
+  ICHECK_EQ(internal->ndim, static_cast<size_t>(external->ndim));
+  ICHECK_EQ(internal->device.device_type, external->device.device_type);
+  ICHECK_EQ(internal->device.device_id, external->device.device_id);
+  for (auto i = 0; i < external->ndim; ++i) {
+    ICHECK_EQ(internal->shape[i], external->shape[i]);
+  }
+}
 /*!
  * \brief set index-th input to the graph without copying the data.
  * \param index The input index.
@@ -122,23 +156,37 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) {
 void GraphExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) {
   ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
-  const DLTensor* old_t = data_entry_[eid].operator->();
-
-  // check the consistency of input
-  ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref));
-  ICHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0);
-  ICHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim));
-  ICHECK_EQ(old_t->device.device_type, data_ref->device.device_type);
-  ICHECK_EQ(old_t->device.device_id, data_ref->device.device_id);
-  for (auto i = 0; i < data_ref->ndim; ++i) {
-    ICHECK_EQ(old_t->shape[i], data_ref->shape[i]);
-  }
-
+  CheckExternalDLTensor(data_ref, eid);
   // Update the data pointer for each argument of each op
   for (DLTensor* t : input_dltensors_[eid]) {
     t->data = data_ref->data;
   }
 }
+/*!
+ * \brief set index-th output to the graph without copying the data.
+ * \param index The output index.
+ * \param data_ref The output data that is referred.
+ */
+void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
+  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
+  ICHECK_LT(static_cast<size_t>(index), output_dltensors_.size());
+  const NodeEntry& output_node = outputs_[index];
+  uint32_t output_node_eid = this->entry_id(output_node);
+
+  // check the consistency of output
+  CheckExternalDLTensor(data_ref, output_node_eid);
+
+  // Update the data pointer for output op
+  for (DLTensor* t : output_dltensors_[output_node_eid]) {
+    t->data = data_ref->data;
+  }
+
+  // Update the input of the op connected to the output
+  for (DLTensor* t : both_output_opinput_dltensors_[output_node_eid]) {
+    t->data = data_ref->data;
+  }
+}
 /*!
  * \brief Get the number of outputs
  *
@@ -358,11 +406,17 @@ void GraphExecutor::SetupStorage() {
 void GraphExecutor::SetupOpExecs() {
   op_execs_.resize(this->GetNumOfNodes());
   input_dltensors_.resize(num_node_entries());
+  output_dltensors_.resize(num_node_entries());
+  both_output_opinput_dltensors_.resize(num_node_entries());
   std::unordered_set<uint32_t> input_node_eids;
   for (size_t i = 0; i < input_nodes_.size(); i++) {
     uint32_t nid = input_nodes_[i];
     input_node_eids.insert(entry_id(nid, 0));
   }
+  std::unordered_set<uint32_t> output_node_eids;
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    output_node_eids.insert(entry_id(outputs_[i]));
+  }
 
   // setup the array and requirements.
   for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
@@ -383,10 +437,25 @@ void GraphExecutor::SetupOpExecs() {
     std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args);
 
     for (size_t i = 0; i < inode.inputs.size(); i++) {
-      uint32_t eid = this->entry_id(inode.inputs[i]);
+      uint32_t input_eid = this->entry_id(inode.inputs[i]);
       // check if op input is model input
-      if (input_node_eids.count(eid) > 0) {
-        input_dltensors_[eid].push_back(static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+      if (input_node_eids.count(input_eid) > 0) {
+        input_dltensors_[input_eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+      }
+      // check if any model output is the input of the op
+      if (output_node_eids.count(input_eid) > 0) {
+        both_output_opinput_dltensors_[input_eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+      }
+    }
+
+    for (uint32_t i = inode.inputs.size(); i < inode.inputs.size() + inode.param.num_outputs; ++i) {
+      uint32_t output_eid = this->entry_id(nid, i - inode.inputs.size());
+      // check if op output is model output
+      if (output_node_eids.count(output_eid) > 0) {
+        output_dltensors_[output_eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
       }
     }
   }
@@ -462,6 +531,15 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
         this->SetInputZeroCopy(args[0], args[1]);
       }
     });
+  } else if (name == "set_output_zero_copy") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        int out_idx = this->GetOutputIndex(args[0].operator String());
+        if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
+      } else {
+        this->SetOutputZeroCopy(args[0], args[1]);
+      }
+    });
   } else if (name == "get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       if (args.num_args == 2) {
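For reference, the new entry point is driven through the module's PackedFunc interface just like set_input_zero_copy. A minimal usage sketch, not part of the patch: it assumes gmod is an already-created graph executor module and out is a pre-allocated NDArray whose shape, dtype, device, and alignment match the graph output (otherwise CheckExternalDLTensor aborts with an ICHECK failure):

    // Bind an external buffer to output 0, then run.
    tvm::runtime::PackedFunc set_output = gmod.GetFunction("set_output_zero_copy");
    tvm::runtime::PackedFunc run = gmod.GetFunction("run");
    set_output(0, const_cast<DLTensor*>(out.operator->()));  // by index
    // Passing a String instead of an int dispatches through GetOutputIndex(),
    // which looks up the name of the node producing the output in output_map_.
    run();  // the op producing output 0 now writes directly into out->data
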
diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h
index 42b5c405b406..87e8aa3cee34 100644
--- a/src/runtime/graph_executor/graph_executor.h
+++ b/src/runtime/graph_executor/graph_executor.h
@@ -107,6 +107,13 @@ class TVM_DLL GraphExecutor : public ModuleNode {
    */
   int GetInputIndex(const std::string& name);
 
+  /*!
+   * \brief Get the output index given the name of output.
+   * \param name The name of the output.
+   * \return The index of output.
+   */
+  int GetOutputIndex(const std::string& name);
+
   /*!
    * \brief set index-th input to the graph.
    * \param index The input index.
@@ -119,6 +126,12 @@ class TVM_DLL GraphExecutor : public ModuleNode {
    * \param data_ref The input data that is referred.
    */
   void SetInputZeroCopy(int index, DLTensor* data_ref);
+  /*!
+   * \brief set index-th output to the graph without copying the data.
+   * \param index The output index.
+   * \param data_ref The output data that is referred.
+   */
+  void SetOutputZeroCopy(int index, DLTensor* data_ref);
   /*!
    * \brief Get the number of outputs
    *
@@ -193,6 +206,9 @@ class TVM_DLL GraphExecutor : public ModuleNode {
     uint32_t node_id;
     uint32_t index;
     uint32_t version;
+    inline bool operator==(const NodeEntry& other) const {
+      return node_id == other.node_id && index == other.index && version == other.version;
+    }
     // JSON Loader
     void Load(dmlc::JSONReader* reader) {
       reader->BeginArray();
@@ -377,6 +393,12 @@ class TVM_DLL GraphExecutor : public ModuleNode {
   void SetupStorage();
   /*! \brief Setup the executors. */
   void SetupOpExecs();
+  /*!
+   * \brief Check the legality of external DLTensor*.
+   * \param external The external DLTensor*.
+   * \param eid The data_entry_ index.
+   */
+  void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const;
   /*!
    * \brief Create an execution function given input.
    * \param attrs The node attributes.
@@ -397,8 +419,14 @@ class TVM_DLL GraphExecutor : public ModuleNode {
   std::vector<uint32_t> input_nodes_;
   /*! \brief Map of input names to input indices. */
   std::unordered_map<std::string, uint32_t> input_map_;
+  /*! \brief Map of output names to output indices. */
+  std::unordered_map<std::string, uint32_t> output_map_;
   /*! \brief Used for quick node input DLTensor* lookup given an input eid. */
   std::vector<std::vector<DLTensor*>> input_dltensors_;
+  /*! \brief Used for quick node output DLTensor* lookup given an output eid. */
+  std::vector<std::vector<DLTensor*>> output_dltensors_;
+  /*! \brief Used for quick node (both model output and op input) DLTensor* lookup given an eid. */
+  std::vector<std::vector<DLTensor*>> both_output_opinput_dltensors_;
   /*! \brief Used for quick entry indexing. */
   std::vector<uint32_t> node_row_ptr_;
   /*! \brief Output entries. */
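The checks in CheckExternalDLTensor define the contract a caller-owned buffer has to meet. Below is a hedged sketch of wrapping a raw buffer in a DLTensor that satisfies them, assuming a CPU target, fp32 data, and TVM's default allocation alignment of 64 bytes (kAllocAlignment); WrapExternalBuffer is an illustrative helper, not a TVM API:

    #include <dlpack/dlpack.h>
    #include <cstdlib>

    DLTensor WrapExternalBuffer(void* data, int64_t* shape, int ndim) {
      DLTensor t{};
      t.data = data;                // must satisfy data % kAllocAlignment == 0
      t.device = {kDLCPU, 0};       // must match the internal entry's device
      t.ndim = ndim;                // must match the internal entry's ndim
      t.dtype = {kDLFloat, 32, 1};  // dtype of the bound entry, fp32 assumed here
      t.shape = shape;              // each dimension is compared with ICHECK_EQ
      t.strides = nullptr;          // compact layout
      t.byte_offset = 0;
      return t;
    }
    // e.g. data = std::aligned_alloc(64, nbytes) with nbytes rounded up to a
    // multiple of 64 (C++17); an unaligned pointer fails the kAllocAlignment check.
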
diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc
new file mode 100644
index 000000000000..6dbcd61b8c37
--- /dev/null
+++ b/tests/cpp/runtime_test.cc
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/module.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/te/operation.h>
+#include <tvm/topi/broadcast.h>
+#include <tvm/topi/generic/injective.h>
+
+using namespace tvm;
+using namespace tvm::relay;
+
+TVM_REGISTER_GLOBAL("runtime_test.strategy")
+    .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
+                       const Target& target) {
+      FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) -> Array<te::Tensor> {
+        ICHECK_EQ(inputs.size(), 2U);
+        return {topi::add(inputs[0], inputs[1])};
+      };
+      FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs,
+                                  const Target& target) {
+        With<Target> target_scope(target);
+        return topi::generic::schedule_injective(target, outs);
+      };
+
+      auto n = make_object<OpStrategyNode>();
+      auto strategy = tvm::relay::OpStrategy(std::move(n));
+      strategy.AddImplementation(fcompute, fschedule, "runtime_test.strategy", 10);
+      return strategy;
+    });
+
+TEST(Runtime, ZeroCopy) {
+  auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
+  auto a = relay::Var("a", tensor_type);
+  auto b = relay::Var("b", tensor_type);
+  auto add_op = relay::Op::Get("add");
+  auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {});
+  auto c = relay::Var("c", tensor_type);
+  auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {});
+  auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
+  auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+
+  auto pA = static_cast<float*>(A->data);
+  auto pB = static_cast<float*>(B->data);
+  auto pC = static_cast<float*>(C->data);
+  auto pY = static_cast<float*>(Y->data);
+
+  for (int i = 0; i < 6; ++i) {
+    pA[i] = i;
+    pB[i] = i + 1;
+    pC[i] = i + 2;
+  }
+  // get schedule
+  auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr");
+  if (!reg) {
+    LOG(FATAL) << "no _Register";
+  }
+  auto fs = tvm::runtime::Registry::Get("runtime_test.strategy");
+  if (!fs) {
+    LOG(FATAL) << "No test_strategy registered.";
+  }
+  auto fgeneric = GenericFunc::Get("runtime_test.strategy_generic").set_default(*fs);
+  (*reg)("add", "FTVMStrategy", fgeneric, 10);
+  Array<Integer> dep;
+  dep.push_back(0);
+  (*reg)("add", "TShapeDataDependent", dep, 10);
+  // build
+  auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
+  tvm::runtime::Module build_mod = (*pfb)();
+  auto build_f = build_mod.GetFunction("build", false);
+  auto json_f = build_mod.GetFunction("get_graph_json", false);
+  auto mod_f = build_mod.GetFunction("get_module", false);
+  Map<tvm::Integer, tvm::Target> targets;
+  Target llvm_tgt = Target("llvm");
+  targets.Set(0, llvm_tgt);
+  auto relay_mod = tvm::IRModule::FromExpr(func);
+  ICHECK(relay_mod.defined()) << "Module must be defined";
+  build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, "");
+  // create graph executor
+  std::string json = json_f();
+  tvm::runtime::Module mod = mod_f();
+  auto dev = A->device;
+  auto pfr = tvm::runtime::Registry::Get("tvm.graph_executor.create");
+  ICHECK(mod.defined()) << "Module must be defined";
+  tvm::runtime::Module run_mod =
+      (*pfr)(json, mod, static_cast<int>(dev.device_type), dev.device_id);
+  // get function
+  auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false);
+  auto set_output_f = run_mod.GetFunction("set_output_zero_copy", false);
+  auto run_f = run_mod.GetFunction("run", false);
+  // set input zero copy
+  set_input_f("a", const_cast<DLTensor*>(A.operator->()));
+  set_input_f("b", const_cast<DLTensor*>(B.operator->()));
+  set_input_f("c", const_cast<DLTensor*>(C.operator->()));
+  // set output zero copy
+  set_output_f(0, const_cast<DLTensor*>(Y.operator->()));
+  run_f();
+  // check correctness
+  for (int i = 0; i < 6; ++i) {
+    ICHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
+  }
+  // mutate the input a bit and run it again
+  for (int i = 0; i < 6; ++i) {
+    pB[i] = i + 3;
+  }
+  run_f();
+  // check correctness
+  for (int i = 0; i < 6; ++i) {
+    ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 2))), 1e-4);
+  }
+  // attach a different input and run it again
+  auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto pC2 = static_cast<float*>(C2->data);
+  for (int i = 0; i < 6; ++i) {
+    pC2[i] = i + 4;
+  }
+  set_input_f("c", const_cast<DLTensor*>(C2.operator->()));
+  run_f();
+  // check correctness
+  for (int i = 0; i < 6; ++i) {
+    ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 4))), 1e-4);
+  }
+}
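A hedged follow-on to the test above (Y2 and pY2 are new names introduced here, not from the patch): because SetOutputZeroCopy simply rebinds the cached argument pointers on every call, the output buffer can be swapped between runs just like the inputs:

    auto Y2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
    auto pY2 = static_cast<float*>(Y2->data);
    set_output_f(0, const_cast<DLTensor*>(Y2.operator->()));
    run_f();  // this run stores the result through pY2 instead of pY
    for (int i = 0; i < 6; ++i) {
      ICHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 4))), 1e-4);
    }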