diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index 08dae307a89e..161ca5ffd08c 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -153,6 +153,21 @@ class GraphModule(object):
     def __init__(self, module):
         self.module = module
         self._set_input = module["set_input"]
+
+        # TODO(shingjan): The graph_executor in C doesn't have
+        # set_input/output_zero_copy implemented.
+        try:
+            self._set_input_zero_copy = module["set_input_zero_copy"]
+        except AttributeError:
+            self._set_input_zero_copy = lambda *_: (_ for _ in ()).throw(
+                Exception("set_input_zero_copy is not implemented for C graph executor")
+            )
+        try:
+            self._set_output_zero_copy = module["set_output_zero_copy"]
+        except AttributeError:
+            self._set_output_zero_copy = lambda *_: (_ for _ in ()).throw(
+                Exception("set_output_zero_copy is not implemented for C graph executor")
+            )
         self._run = module["run"]
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
@@ -172,7 +187,7 @@ def set_input(self, key=None, value=None, **params):
             The input key
 
         value : the input value.
-            The input key
+            The input value
 
         params : dict of str to NDArray
             Additional arguments
@@ -195,6 +210,47 @@ def set_input(self, key=None, value=None, **params):
                 if val:
                     self._get_input(k).copyfrom(params[k])
 
+    def set_input_zero_copy(self, key=None, value=None, **params):
+        """Set inputs to the module via kwargs with zero memory copy
+
+        Parameters
+        ----------
+        key : int or str
+            The input key
+
+        value : the input value in DLPack
+            The input value
+
+        params : dict of str to NDArray
+            Additional arguments
+        """
+        if key is not None:
+            self._set_input_zero_copy(key, value)
+
+        if params:
+            keys = list(params.keys())
+
+            for k in keys:
+                # TODO(zhiics) Skip the weights for submodule in a better way.
+                # We should use ConstLoaderModule for initialization and remove
+                # params from set_input
+                val = self._get_input(k)
+                if val:
+                    self._set_input_zero_copy(k, params[k])
+
+    def set_output_zero_copy(self, key, value):
+        """Set outputs to the module with zero memory copy
+
+        Parameters
+        ----------
+        key : int or str
+            The output key
+
+        value : the output value in DLPack
+            The output value
+        """
+        self._set_output_zero_copy(key, value)
+
     def run(self, **input_dict):
         """Run forward execution of the graph
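# --- Editor's usage sketch (not part of the patch) ---------------------------
# A minimal illustration of how the two new zero-copy setters are meant to be
# called, assuming an LLVM build of a trivial Relay module; the variable names
# below are illustrative, not taken from the patch. set_input_zero_copy makes
# the executor alias the caller's buffer instead of copying into its own
# storage, and set_output_zero_copy makes run() write directly into the
# caller's buffer.
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

x = relay.var("x", shape=(1, 4), dtype="float32")
func = relay.Function([x], x + relay.const(1.0, "float32"))
lib = relay.build(tvm.IRModule.from_expr(func), target="llvm")
gm = graph_executor.GraphModule(lib["default"](tvm.cpu()))

x_nd = tvm.nd.array(np.zeros((1, 4), "float32"))
out_nd = tvm.nd.empty((1, 4), "float32")
gm.set_input_zero_copy("x", x_nd)   # executor reads x_nd's memory directly
gm.set_output_zero_copy(0, out_nd)  # run() writes straight into out_nd
gm.run()
np.testing.assert_allclose(out_nd.numpy(), np.ones((1, 4), "float32"))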
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 604f337099b0..76d6323f309a 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -54,16 +54,6 @@ def main(a: T.handle, b: T.handle) -> None:  # type: ignore
 # pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
 
 
-def _has_torch():
-    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
-
-    spec = importlib.util.find_spec("torch")
-    return spec is not None
-
-
-requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
-
-
 def test_meta_schedule_dynamic_loop_extent():
     a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
     b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -72,7 +62,7 @@ def test_meta_schedule_dynamic_loop_extent():
     assert not extracted_tasks
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_resnet():
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params)
@@ -108,7 +98,7 @@ def test_meta_schedule_integration_extract_from_resnet():
         assert t.task_name in expected_task_names, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_task_extraction_winograd_tensorcore():
     mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
     seq = tvm.transform.Sequential(
@@ -126,7 +116,7 @@ def test_task_extraction_winograd_tensorcore():
     assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_task_extraction_anchor_block():
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.relay_integration.extract_tasks(
@@ -161,7 +151,7 @@ def test_task_extraction_anchor_block():
         assert t.task_name in expected_task_names, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_bert_base():
     pytest.importorskip(
         "transformers", reason="transformers package is required to import bert_base"
@@ -259,7 +249,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
         assert expected_shape == shape, t.task_name
 
 
-@requires_torch
+@tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
     @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True)
     def filter_func(args, _) -> bool:
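# --- Editor's note (not part of the patch) -----------------------------------
# tvm.testing.requires_package generalizes the deleted ad-hoc _has_torch
# helper: it is a pytest marker factory that skips a test when the named
# package cannot be imported, so individual test files no longer carry their
# own skipif logic. A minimal sketch of how it is applied (the test name is
# hypothetical):
import tvm.testing


@tvm.testing.requires_package("torch")
def test_something_with_torch():
    import torch  # safe: the marker guarantees torch is importable

    assert torch.zeros(1).numel() == 1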
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index c7ce5abfbd92..0ed097ddf563 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -688,6 +688,44 @@ def test_num_threads():
     assert reported == hardware_threads or reported == hardware_threads // 2
 
 
+@tvm.testing.requires_llvm
+@tvm.testing.requires_package("torch")
+def test_graph_module_zero_copy():
+    mod = tvm.IRModule()
+    params = {}
+    dev = tvm.cpu()
+    x = relay.var("x", shape=(1, 10))
+    y = relay.var("y", shape=(1, 10))
+    z = relay.add(x, y)
+    mod["main"] = relay.Function([x, y], z)
+
+    # need torch to do the from_dlpack trick
+    import torch
+
+    compiled_graph_lib = relay.build(mod, target="llvm", params=params)
+    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
+    x_data = torch.rand((1, 10))
+    y_data = torch.rand((1, 10))
+    z_data = torch.rand((1, 10))
+    z_torch = x_data + y_data
+
+    # zero copy run
+    assert not np.allclose(z_data.numpy(), z_torch.numpy())
+    gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_data))
+    gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
+    gm.set_output_zero_copy(0, tvm.nd.from_dlpack(z_data))
+    gm.run()
+
+    tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())
+
+    # zero input copy with params
+    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
+    gm.set_input_zero_copy(x=tvm.nd.from_dlpack(x_data), y=tvm.nd.from_dlpack(y_data))
+    gm.run()
+
+    tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())
+
+
 if __name__ == "__main__":
     test_legacy_compatibility()
     test_cpu()
@@ -699,3 +737,4 @@ def test_num_threads():
     test_cpu_get_graph_json()
     test_cpu_get_graph_params_run()
     test_cpu_get_graph_params_compare()
+    test_graph_module_zero_copy()
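# --- Editor's note (not part of the patch) -----------------------------------
# The test above hinges on DLPack aliasing: tvm.nd.from_dlpack wraps the torch
# tensor's storage without copying, so when run() writes the output through
# set_output_zero_copy, the result shows up in the original torch tensor
# (z_data). A minimal, self-contained demonstration of that sharing:
import torch
import tvm

t = torch.zeros(4)
nd = tvm.nd.from_dlpack(t)           # nd shares t's memory, no copy is made
nd.copyfrom(torch.ones(4).numpy())   # write through the TVM view
assert t.sum().item() == 4.0         # the write is visible through t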