From c2c487191b2bc048f6880efd2469363297747f9b Mon Sep 17 00:00:00 2001 From: Yuwei-EdgeCortix Date: Thu, 11 Jul 2024 20:41:07 +0800 Subject: [PATCH 1/2] GraphExecutor: Fix wild pointer assign when input and output are reshape --- src/runtime/graph_executor/graph_executor.cc | 22 +++++++++ src/runtime/graph_executor/graph_executor.h | 2 + .../test_runtime_module_based_interface.py | 48 +++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 5bd7967cab37..107613e5a28c 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -230,6 +230,16 @@ void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { // check the consistency of output CheckExternalDLTensor(data_ref, output_node_eid); + if (nodes_[output_node.node_id].op_type == "tvm_op" && + nodes_[output_node.node_id].param.func_name == "__nop") { + const NodeEntry& input_node = nodes_[output_node.node_id].inputs[0]; + output_node_eid = this->entry_id(input_node); + ICHECK_NE(node_output_dltensors_[output_node_eid].size(), 0); + for (DLTensor* t : node_output_dltensors_[output_node_eid]) { + t->data = static_cast(data_ref->data) + data_ref->byte_offset; + } + } + // Update the data pointer for output op for (DLTensor* t : output_dltensors_[output_node_eid]) { t->data = static_cast(data_ref->data) + data_ref->byte_offset; @@ -540,6 +550,13 @@ void GraphExecutor::SetupOpExecs() { input_dltensors_[input_eid].push_back( const_cast(data_entry_[eid].operator->())); } + } else { + const auto& arg_node = nodes_[inode.inputs[i].node_id]; + if (arg_node.op_type == "tvm_op" && arg_node.param.func_name == "__nop") { + uint32_t arg_input_eid = this->entry_id(arg_node.inputs[0]); + input_dltensors_[arg_input_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); + } } // check if any model output is the input of the op if (output_node_eids.count(input_eid) > 0) { @@ -554,6 +571,11 @@ void GraphExecutor::SetupOpExecs() { if (output_node_eids.count(output_eid) > 0) { output_dltensors_[output_eid].push_back( static_cast(op_args->arg_values[i].v_handle)); + } else { + // If the node is not an output, keep its output for record and support set_output_zero_copy + // of reshape __nop nodes. + node_output_dltensors_[output_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); } } } diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 08e06f4e6bf3..53e2801d574e 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -464,6 +464,8 @@ class TVM_DLL GraphExecutor : public ModuleNode { std::vector> output_dltensors_; /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */ std::vector> both_output_opinput_dltensors_; + /*! \brief Used for quick node output DLTensor* lookup given a nop's input eid. */ + std::unordered_map> node_output_dltensors_; /*! \brief Used for quick entry_id lookup given an storage_id. */ std::vector> sid_to_eid_; /*! \brief Used for quick entry indexing. */ diff --git a/tests/python/runtime/test_runtime_module_based_interface.py b/tests/python/runtime/test_runtime_module_based_interface.py index 0751e2ea3d42..bc8d49a9e424 100644 --- a/tests/python/runtime/test_runtime_module_based_interface.py +++ b/tests/python/runtime/test_runtime_module_based_interface.py @@ -735,6 +735,53 @@ def test_graph_module_zero_copy(): tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy()) +@tvm.testing.requires_llvm +def test_reshape_zero_copy(): + shape0 = (56, 224) + shape1 = (112, 112) + in_name0 = "infeats0" + in_name1 = "infeats1" + x0 = relay.var(in_name0, shape=shape0, dtype="float32") + x0 = relay.reshape(x0, shape1) + + x1 = relay.var(in_name1, shape=shape1, dtype="float32") + mat = relay.nn.matmul(x0, x1) + _y = relay.reshape(mat, (-1)) + func = relay.Function(relay.analysis.free_vars(_y), _y) + mod = tvm.IRModule.from_expr(func) + + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target="llvm") + m = graph_executor.GraphModule(lib["default"](tvm.cpu(0))) + + data_ndarray0 = tvm.nd.array( + np.random.random(shape0).astype(np.float32), device=tvm.device("llvm", 0) + ) + data_ndarray1 = tvm.nd.array( + np.random.random(shape1).astype(np.float32), device=tvm.device("llvm", 0) + ) + + def expected(): + m.set_input(in_name0, data_ndarray0) + m.set_input(in_name1, data_ndarray1) + m.run() + return m.get_output(0).numpy() + + def zero_copy(): + from tvm.relay.frontend.common import infer_shape + outshape = infer_shape(_y) + output_view = tvm.nd.empty(outshape, device=tvm.device("llvm", 0)) + m.set_input_zero_copy(in_name0, data_ndarray0) + m.set_input_zero_copy(in_name1, data_ndarray1) + m.set_output_zero_copy(0, output_view) + m.run() + return output_view.numpy() + + golden_out = expected() + out = zero_copy() + np.testing.assert_equal(golden_out, out) + + if __name__ == "__main__": test_legacy_compatibility() test_cpu() @@ -747,3 +794,4 @@ def test_graph_module_zero_copy(): test_cpu_get_graph_params_run() test_cpu_get_graph_params_compare() test_graph_module_zero_copy() + test_reshape_zero_copy() From 9c2d0d406a3f3bb4d44b933cdda39d5f368fd6b2 Mon Sep 17 00:00:00 2001 From: Yuwei Hu Date: Fri, 12 Jul 2024 13:30:37 +0800 Subject: [PATCH 2/2] lint fix --- tests/python/runtime/test_runtime_module_based_interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/runtime/test_runtime_module_based_interface.py b/tests/python/runtime/test_runtime_module_based_interface.py index bc8d49a9e424..3f712587684d 100644 --- a/tests/python/runtime/test_runtime_module_based_interface.py +++ b/tests/python/runtime/test_runtime_module_based_interface.py @@ -769,6 +769,7 @@ def expected(): def zero_copy(): from tvm.relay.frontend.common import infer_shape + outshape = infer_shape(_y) output_view = tvm.nd.empty(outshape, device=tvm.device("llvm", 0)) m.set_input_zero_copy(in_name0, data_ndarray0) @@ -779,7 +780,7 @@ def zero_copy(): golden_out = expected() out = zero_copy() - np.testing.assert_equal(golden_out, out) + tvm.testing.assert_allclose(golden_out, out) if __name__ == "__main__":