From 7827eb127cfdcce434b8de73a7ad13f02b922ba9 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Wed, 14 Dec 2022 22:02:54 -0800
Subject: [PATCH 01/16] add set_output and test for set_output_zero_copy in python

---
 python/tvm/contrib/graph_executor.py          | 48 ++++++++++++++++++-
 src/runtime/graph_executor/graph_executor.cc  | 19 ++++++++
 src/runtime/graph_executor/graph_executor.h   |  6 +++
 .../test_runtime_module_based_interface.py    | 39 +++++++++++++++
 4 files changed, 111 insertions(+), 1 deletion(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 08dae307a89e..9b4d714f4bf8 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -153,6 +153,9 @@ class GraphModule(object):
     def __init__(self, module):
         self.module = module
         self._set_input = module["set_input"]
+        self._set_input_zero_copy = module["set_input_zero_copy"]
+        self._set_output = module["set_output"]
+        self._set_output_zero_copy = module["set_output_zero_copy"]
         self._run = module["run"]
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
@@ -172,7 +175,7 @@ def set_input(self, key=None, value=None, **params):
             The input key
 
         value : the input value.
-            The input key
+            The input value
 
         params : dict of str to NDArray
             Additional arguments
@@ -195,6 +198,49 @@
             if val:
                 self._get_input(k).copyfrom(params[k])
 
+    def set_output(self, key, value):
+        """Set outputs to the module
+
+        Parameters
+        ----------
+        key : int or str
+            The output key
+
+        value : the output value
+            The output value
+        """
+        self._set_output(key, value)
+
+    def set_input_zero_copy(self, key, value, **params):
+        """Set inputs to the module via kwargs with zero memory copy
+
+        Parameters
+        ----------
+        key : int or str
+            The input key
+
+        value : the input value in DLPack
+            The input key
+
+        params : dict of str to NDArray
+            Additional arguments
+        """
+        self._set_input_zero_copy(key, value)
+        self.set_input(None, None, **params)
+
+    def set_output_zero_copy(self, key, value):
+        """Set outputs to the module with zero memory copy
+
+        Parameters
+        ----------
+        key : int or str
+            The output key
+
+        value : the output value in DLPack
+            The output value
+        """
+        self._set_output_zero_copy(key, value)
+
     def run(self, **input_dict):
         """Run forward execution of the graph
diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc
index d805abfc658a..007c99038657 100644
--- a/src/runtime/graph_executor/graph_executor.cc
+++ b/src/runtime/graph_executor/graph_executor.cc
@@ -160,6 +160,16 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) {
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
   data_entry_[eid].CopyFrom(data_in);
 }
+/*!
+ * \brief set index-th output to the graph.
+ * \param index The output index.
+ * \param data_in The output data.
+ */
+void GraphExecutor::SetOutput(int index, DLTensor* data_in) {
+  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
+  uint32_t output_node_eid = this->entry_id(outputs_[index]);
+  data_entry_[output_node_eid].CopyFrom(data_in);
+}
 /*!
  * \brief Check the legality of external DLTensor*.
  * \param external The external DLTensor*.
@@ -583,6 +593,15 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
         this->SetInput(args[0], args[1]);
       }
     });
+  } else if (name == "set_output") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        int in_idx = this->GetInputIndex(args[0].operator String());
+        if (in_idx >= 0) this->SetOutput(in_idx, args[1]);
+      } else {
+        this->SetOutput(args[0], args[1]);
+      }
+    });
   } else if (name == "set_input_zero_copy") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       if (String::CanConvertFrom(args[0])) {
diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h
index bbe94636b3a1..9022b50b83b9 100644
--- a/src/runtime/graph_executor/graph_executor.h
+++ b/src/runtime/graph_executor/graph_executor.h
@@ -130,6 +130,12 @@ class TVM_DLL GraphExecutor : public ModuleNode {
    * \param data_in The input data.
    */
   void SetInput(int index, DLTensor* data_in);
+  /*!
+   * \brief set index-th output to the graph.
+   * \param index The input index.
+   * \param data_in The input data.
+   */
+  void SetOutput(int index, DLTensor* data_in);
   /*!
    * \brief set index-th input to the graph without copying the data
    * \param index The input index.
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index c7ce5abfbd92..43dc24613eae 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -29,6 +29,10 @@ def input_shape(mod):
     return [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
 
 
+def output_shape(mod):
+    return [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
+
+
 def verify(data):
     if not tvm.runtime.enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -688,6 +692,41 @@ def test_num_threads():
     assert reported == hardware_threads or reported == hardware_threads // 2
 
 
+@tvm.testing.requires_llvm
+def test_graph_module_zero_copy():
+    mod = tvm.IRModule()
+    params = {}
+    dev = tvm.cpu()
+    x = relay.var("x", shape=(1, 10))
+    y = relay.var("y", shape=(1, 10))
+    z = relay.add(x, y)
+    mod["main"] = relay.Function([x, y], z)
+
+    import torch
+
+    compiled_graph_lib = relay.build(mod, target="llvm", params=params)
+    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
+    x_data = torch.rand((1, 10))
+    y_data = torch.rand((1, 10))
+    z_data = torch.zeros((1, 10))
+    z_torch = x_data + y_data
+    # regular run
+    gm.set_input("x", tvm.nd.array(x_data.numpy()))
+    gm.set_input("y", tvm.nd.array(y_data.numpy()))
+    gm.set_output("z", tvm.nd.array(z_data.numpy()))
+    gm.run()
+
+    tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())
+
+    # zero copy run
+    gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_data))
+    gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
+    gm.set_output_zero_copy("z", tvm.nd.from_dlpack(z_data))
+    gm.run()
+
+    tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())
+
+
 if __name__ == "__main__":
     test_legacy_compatibility()
     test_cpu()

From b00e801c9059fe8b5a562986307ceb7d95785436 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Wed, 14 Dec 2022 22:10:53 -0800
Subject: [PATCH 02/16] clean up

---
 tests/python/unittest/test_runtime_module_based_interface.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 43dc24613eae..26b9cbfdd5ce 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -29,10 +29,6 @@ def input_shape(mod):
     return [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
 
 
-def output_shape(mod):
-    return [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
-
-
 def verify(data):
     if not tvm.runtime.enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -723,7 +719,6 @@ def test_graph_module_zero_copy():
     gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
     gm.set_output_zero_copy("z", tvm.nd.from_dlpack(z_data))
     gm.run()
-
     tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())

From 2c4ece9a2a4055ee1522605b57235b6c7dfc51c4 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Wed, 14 Dec 2022 22:22:29 -0800
Subject: [PATCH 03/16] clean up test

---
 tests/python/unittest/test_runtime_module_based_interface.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 26b9cbfdd5ce..d86c27f2df87 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -719,6 +719,7 @@ def test_graph_module_zero_copy():
     gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
     gm.set_output_zero_copy("z", tvm.nd.from_dlpack(z_data))
     gm.run()
+
     tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())
 
 
@@ -733,3 +734,4 @@ def test_graph_module_zero_copy():
     test_cpu_get_graph_json()
     test_cpu_get_graph_params_run()
     test_cpu_get_graph_params_compare()
+    test_graph_module_zero_copy()

From 42829389168a0e942a894a8ce5b7237d961ec44d Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Wed, 14 Dec 2022 23:03:10 -0800
Subject: [PATCH 04/16] test finished

---
 .../python/unittest/test_runtime_module_based_interface.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index d86c27f2df87..80ff29f0487a 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -698,6 +698,7 @@ def test_graph_module_zero_copy():
     z = relay.add(x, y)
     mod["main"] = relay.Function([x, y], z)
 
+    # need torch to do the from_dlpack trick
     import torch
 
     compiled_graph_lib = relay.build(mod, target="llvm", params=params)
     gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
     x_data = torch.rand((1, 10))
     y_data = torch.rand((1, 10))
     z_data = torch.zeros((1, 10))
     z_torch = x_data + y_data
+
     # regular run
     gm.set_input("x", tvm.nd.array(x_data.numpy()))
     gm.set_input("y", tvm.nd.array(y_data.numpy()))
-    gm.set_output("z", tvm.nd.array(z_data.numpy()))
+    gm.set_output(0, tvm.nd.array(z_data.numpy()))
     gm.run()
 
     tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())
 
     # zero copy run
+    assert not np.allclose(z_data.numpy(), z_torch.numpy())
     gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_data))
     gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
-    gm.set_output_zero_copy("z", tvm.nd.from_dlpack(z_data))
+    gm.set_output_zero_copy(0, tvm.nd.from_dlpack(z_data))
     gm.run()
 
     tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())

From e0868ddd8438c107985df8b838b68a03c21774a8 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 10:19:19 -0800
Subject: [PATCH 05/16] remove set output

---
 python/tvm/contrib/graph_executor.py          | 14 --------------
 src/runtime/graph_executor/graph_executor.cc  | 19 -------------------
 .../test_runtime_module_based_interface.py    |  8 --------
 3 files changed, 41 deletions(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 9b4d714f4bf8..91dab94149f1 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -154,7 +154,6 @@ def __init__(self, module):
         self.module = module
         self._set_input = module["set_input"]
         self._set_input_zero_copy = module["set_input_zero_copy"]
-        self._set_output = module["set_output"]
         self._set_output_zero_copy = module["set_output_zero_copy"]
         self._run = module["run"]
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
@@ -198,19 +197,6 @@ def set_input(self, key=None, value=None, **params):
             if val:
                 self._get_input(k).copyfrom(params[k])
 
-    def set_output(self, key, value):
-        """Set outputs to the module
-
-        Parameters
-        ----------
-        key : int or str
-            The output key
-
-        value : the output value
-            The output value
-        """
-        self._set_output(key, value)
-
     def set_input_zero_copy(self, key, value, **params):
         """Set inputs to the module via kwargs with zero memory copy
diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc
index 007c99038657..d805abfc658a 100644
--- a/src/runtime/graph_executor/graph_executor.cc
+++ b/src/runtime/graph_executor/graph_executor.cc
@@ -160,16 +160,6 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) {
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
   data_entry_[eid].CopyFrom(data_in);
 }
-/*!
- * \brief set index-th output to the graph.
- * \param index The output index.
- * \param data_in The output data.
- */
-void GraphExecutor::SetOutput(int index, DLTensor* data_in) {
-  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
-  uint32_t output_node_eid = this->entry_id(outputs_[index]);
-  data_entry_[output_node_eid].CopyFrom(data_in);
-}
 /*!
  * \brief Check the legality of external DLTensor*.
  * \param external The external DLTensor*.
@@ -593,15 +583,6 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
         this->SetInput(args[0], args[1]);
       }
     });
-  } else if (name == "set_output") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      if (String::CanConvertFrom(args[0])) {
-        int in_idx = this->GetInputIndex(args[0].operator String());
-        if (in_idx >= 0) this->SetOutput(in_idx, args[1]);
-      } else {
-        this->SetOutput(args[0], args[1]);
-      }
-    });
   } else if (name == "set_input_zero_copy") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       if (String::CanConvertFrom(args[0])) {
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 80ff29f0487a..c3a0d3286071 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -708,14 +708,6 @@ def test_graph_module_zero_copy():
     z_data = torch.zeros((1, 10))
     z_torch = x_data + y_data
 
-    # regular run
-    gm.set_input("x", tvm.nd.array(x_data.numpy()))
-    gm.set_input("y", tvm.nd.array(y_data.numpy()))
-    gm.set_output(0, tvm.nd.array(z_data.numpy()))
-    gm.run()
-
-    tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())
-
     # zero copy run
     assert not np.allclose(z_data.numpy(), z_torch.numpy())

From 393958bd88d3767fa842ef30b117db18b7a66097 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 10:22:30 -0800
Subject: [PATCH 06/16] remove setoutput from header

---
 src/runtime/graph_executor/graph_executor.h | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h
index 9022b50b83b9..bbe94636b3a1 100644
--- a/src/runtime/graph_executor/graph_executor.h
+++ b/src/runtime/graph_executor/graph_executor.h
@@ -130,12 +130,6 @@ class TVM_DLL GraphExecutor : public ModuleNode {
    * \param data_in The input data.
    */
   void SetInput(int index, DLTensor* data_in);
-  /*!
-   * \brief set index-th output to the graph.
-   * \param index The input index.
-   * \param data_in The input data.
-   */
-  void SetOutput(int index, DLTensor* data_in);
   /*!
    * \brief set index-th input to the graph without copying the data
    * \param index The input index.

From 880f960a604f88ee4b56832f3d0a1e58c6c56f9b Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 10:32:26 -0800
Subject: [PATCH 07/16] use zero copy for params

---
 python/tvm/contrib/graph_executor.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 91dab94149f1..0318836c0a54 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -212,7 +212,17 @@ def set_input_zero_copy(self, key, value, **params):
             Additional arguments
         """
         self._set_input_zero_copy(key, value)
-        self.set_input(None, None, **params)
+        if params:
+            # upload big arrays first to avoid memory issue in rpc mode
+            keys = list(params.keys())
+            keys.sort(key=lambda x: -np.prod(params[x].shape))
+            for k in keys:
+                # TODO(zhiics) Skip the weights for submodule in a better way.
+                # We should use ConstLoaderModule for initialization and remove
+                # params from set_input
+                val = self._get_input(k)
+                if val:
+                    self._set_input_zero_copy(k, params[k])
 
     def set_output_zero_copy(self, key, value):
         """Set outputs to the module with zero memory copy
 
         Parameters
         ----------
         key : int or str
             The output key
 
         value : the output value in DLPack
             The output value
         """
         self._set_output_zero_copy(key, value)

From 11375151af7cfaa18b08c6dd112570e6dc9653b2 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 10:34:31 -0800
Subject: [PATCH 08/16] fix typo

---
 python/tvm/contrib/graph_executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 0318836c0a54..49e946ed8c58 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -206,7 +206,7 @@ def set_input_zero_copy(self, key, value, **params):
             The input key
 
         value : the input value in DLPack
-            The input key
+            The input value
 
         params : dict of str to NDArray
             Additional arguments

From 1735d8a94ffa9800ae9fe47dcdd6d471654d33c9 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 10:49:52 -0800
Subject: [PATCH 09/16] address comments

---
 python/tvm/contrib/graph_executor.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 49e946ed8c58..5d0f101eb94a 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -211,11 +211,12 @@ def set_input_zero_copy(self, key, value, **params):
         params : dict of str to NDArray
             Additional arguments
         """
-        self._set_input_zero_copy(key, value)
+        if key is not None:
+            self._set_input_zero_copy(key, value)
+
         if params:
-            # upload big arrays first to avoid memory issue in rpc mode
             keys = list(params.keys())
-            keys.sort(key=lambda x: -np.prod(params[x].shape))
+
             for k in keys:
                 # TODO(zhiics) Skip the weights for submodule in a better way.
                # We should use ConstLoaderModule for initialization and remove
                # params from set_input
                val = self._get_input(k)
                if val:
                    self._set_input_zero_copy(k, params[k])

From 31f147da7ed626e491a63b17f0eb3c3b02699b2b Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Thu, 15 Dec 2022 11:28:44 -0800
Subject: [PATCH 10/16] address comments

---
 python/tvm/contrib/graph_executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 5d0f101eb94a..ec2920b88fd2 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -197,7 +197,7 @@ def set_input(self, key=None, value=None, **params):
             if val:
                 self._get_input(k).copyfrom(params[k])
 
-    def set_input_zero_copy(self, key, value, **params):
+    def set_input_zero_copy(self, key=None, value=None, **params):
         """Set inputs to the module via kwargs with zero memory copy

From a14a9375a5b30cfe361a0855a9fa8fe5f53c0e7b Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 12:22:13 -0800
Subject: [PATCH 11/16] add second test for set_input params

---
 .../unittest/test_runtime_module_based_interface.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index c3a0d3286071..83d04cef21ce 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -705,7 +705,7 @@ def test_graph_module_zero_copy():
     gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
     x_data = torch.rand((1, 10))
     y_data = torch.rand((1, 10))
-    z_data = torch.zeros((1, 10))
+    z_data = torch.rand((1, 10))
     z_torch = x_data + y_data
 
     # zero copy run
@@ -717,6 +717,13 @@ def test_graph_module_zero_copy():
 
     tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())
 
+    # zero input copy with params
+    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
+    gm.set_input_zero_copy(x=tvm.nd.from_dlpack(x_data), y=tvm.nd.from_dlpack(y_data))
+    gm.run()
+
+    tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())

From 69b8612e59a4994b5ce360b649bfbb9f3d1c9e08 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 12:32:41 -0800
Subject: [PATCH 12/16] add requires_torch

---
 .../unittest/test_runtime_module_based_interface.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 83d04cef21ce..106773f30631 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 import os
 from tvm import relay, runtime
 from tvm.relay import testing
@@ -688,6 +689,19 @@ def test_num_threads():
     assert reported == hardware_threads or reported == hardware_threads // 2
 
 
+def _has_torch():
+    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
+
+    spec = importlib.util.find_spec("torch")
+    return spec is not None
+
+
+# TODO(shingjan): put requires_torch in tvm.testing
+requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
+
+
 @tvm.testing.requires_llvm
+@requires_torch
 def test_graph_module_zero_copy():
     mod = tvm.IRModule()
     params = {}

From ab0a6f2ff02f5546040e3c70e7e26d61dc17512c Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 12:37:21 -0800
Subject: [PATCH 13/16] add requires torch

---
 .../test_meta_schedule_relay_integration.py   | 20 +++++---------------
 .../test_runtime_module_based_interface.py    | 13 +------------
 2 files changed, 6 insertions(+), 27 deletions(-)

diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 604f337099b0..76d6323f309a 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -54,16 +54,6 @@ def main(a: T.handle, b: T.handle) -> None:  # type: ignore
 # pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
 
 
-def _has_torch():
-    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
-
-    spec = importlib.util.find_spec("torch")
-    return spec is not None
-
-
-requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
-
-
 def test_meta_schedule_dynamic_loop_extent():
     a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
     b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -72,7 +62,7 @@ def test_meta_schedule_dynamic_loop_extent():
     assert not extracted_tasks
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_resnet():
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params)
@@ -108,7 +98,7 @@ def test_meta_schedule_integration_extract_from_resnet():
     assert t.task_name in expected_task_names, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_task_extraction_winograd_tensorcore():
     mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
     seq = tvm.transform.Sequential(
@@ -126,7 +116,7 @@ def test_task_extraction_winograd_tensorcore():
     assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_task_extraction_anchor_block():
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.relay_integration.extract_tasks(
@@ -161,7 +151,7 @@ def test_task_extraction_anchor_block():
     assert t.task_name in expected_task_names, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_bert_base():
     pytest.importorskip(
         "transformers", reason="transformers package is required to import bert_base"
@@ -259,7 +249,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
     assert expected_shape == shape, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
     @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True)
     def filter_func(args, _) -> bool:
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 106773f30631..e12708727286 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -689,19 +689,8 @@ def test_num_threads():
     assert reported == hardware_threads or reported == hardware_threads // 2
 
 
-def _has_torch():
-    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
-
-    spec = importlib.util.find_spec("torch")
-    return spec is not None
-
-
-# TODO(shingjan): put requires_torch in tvm.testing
-requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
-
-
 @tvm.testing.requires_llvm
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_graph_module_zero_copy():
     mod = tvm.IRModule()
     params = {}

From 710528cfe7215f3527265197bf707d8782167c84 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 12:39:55 -0800
Subject: [PATCH 14/16] remove pytest

---
 tests/python/unittest/test_runtime_module_based_interface.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index e12708727286..0ed097ddf563 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
-import pytest
 import os
 from tvm import relay, runtime
 from tvm.relay import testing

From ea76cef3dff8eaf3324ba3b24ca413880e828998 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 14:41:06 -0800
Subject: [PATCH 15/16] add error handling for c graph executor

---
 python/tvm/contrib/graph_executor.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index ec2920b88fd2..cf4301d36e0c 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -153,8 +153,21 @@ class GraphModule(object):
     def __init__(self, module):
         self.module = module
         self._set_input = module["set_input"]
-        self._set_input_zero_copy = module["set_input_zero_copy"]
-        self._set_output_zero_copy = module["set_output_zero_copy"]
+
+        # TODO(shingjan): The graph_executor in C doesn't have
+        # set_input/output_zero_copy implemented.
+        try:
+            self._set_input_zero_copy = module["set_input_zero_copy"]
+        except AttributeError:
+            self._set_input_zero_copy = lambda: (_ for _ in ()).throw(
+                Exception("set_input_zero_copy is not implemented for C graph executor")
+            )
+        try:
+            self._set_output_zero_copy = module["set_output_zero_copy"]
+        except AttributeError:
+            self._set_output_zero_copy = lambda: (_ for _ in ()).throw(
+                Exception("set_output_zero_copy is not implemented for C graph executor")
+            )
         self._run = module["run"]
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]

From f9854d7fd7737729eb362c42b57cc30b6af2e4c0 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 16 Dec 2022 14:44:06 -0800
Subject: [PATCH 16/16] better handling

---
 python/tvm/contrib/graph_executor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index cf4301d36e0c..161ca5ffd08c 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -159,13 +159,13 @@ def __init__(self, module):
         try:
             self._set_input_zero_copy = module["set_input_zero_copy"]
         except AttributeError:
-            self._set_input_zero_copy = lambda: (_ for _ in ()).throw(
+            self._set_input_zero_copy = lambda *_: (_ for _ in ()).throw(
                 Exception("set_input_zero_copy is not implemented for C graph executor")
             )
         try:
             self._set_output_zero_copy = module["set_output_zero_copy"]
         except AttributeError:
-            self._set_output_zero_copy = lambda: (_ for _ in ()).throw(
+            self._set_output_zero_copy = lambda *_: (_ for _ in ()).throw(
                 Exception("set_output_zero_copy is not implemented for C graph executor")
             )
         self._run = module["run"]
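
Reviewer note (illustration only, not part of the series): below is a minimal
usage sketch of the API that remains after patch 16, distilled from
test_graph_module_zero_copy above. It assumes a TVM build with the LLVM
backend and a local torch install; the toy add-graph and the (1, 10) shapes
are illustrative, and torch is used purely to produce DLPack-backed tensors
for tvm.nd.from_dlpack.

    import torch
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    # Build a trivial graph: z = x + y.
    x = relay.var("x", shape=(1, 10))
    y = relay.var("y", shape=(1, 10))
    mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.add(x, y)))
    lib = relay.build(mod, target="llvm")

    dev = tvm.cpu()
    gm = graph_executor.GraphModule(lib["default"](dev))

    # DLPack lets the torch tensors share their buffers with TVM, so the
    # zero-copy setters bind those buffers directly instead of copying.
    x_data = torch.rand((1, 10))
    y_data = torch.rand((1, 10))
    z_data = torch.zeros((1, 10))

    gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_data))
    gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
    gm.set_output_zero_copy(0, tvm.nd.from_dlpack(z_data))
    gm.run()

    # run() wrote the result straight into z_data's storage.
    assert torch.allclose(z_data, x_data + y_data)

Because the output buffer is bound before run(), the result lands in z_data
with no copy on either side; on the C graph executor, patches 15/16 turn the
same calls into a clear runtime error rather than a failed function lookup.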