From e76980815383d459990f0452cf4929e34295216e Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 18:34:45 -0800 Subject: [PATCH 1/6] properly return and unflatten outputs from GraphExecutor --- python/tvm/relay/build_module.py | 37 +++++++++++++++---- .../relay/test_backend_graph_runtime.py | 21 +++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 79eb7e4f19ff..a5ce7bdaeda6 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -19,6 +19,7 @@ from a Relay expression. """ import warnings +import copy import numpy as np from tvm.ir import IRModule @@ -391,10 +392,35 @@ def _make_executor(self, expr=None): ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) - num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 mod = build(self.mod, target=self.target) gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) + def _write_prefix(data, prefix, value): + while len(prefix) > 1: + data = data[prefix[0]] + prefix = prefix[1:] + data[prefix[0]] = value + + def _build_index(ty, prefix, structure, index_map, cur_index=0): + if isinstance(ty, _ty.TensorType): + index_map[cur_index] = prefix + _write_prefix(structure, prefix, None) + return structure, index_map, cur_index + 1 + elif isinstance(ty, _ty.TupleType): + _write_prefix(structure, prefix, [None] * len(ty.fields)) + for i, field_ty in enumerate(ty.fields): + structure, index_map, cur_index = _build_index( + field_ty, prefix + [i], structure, index_map, cur_index=cur_index + ) + return structure, index_map, cur_index + else: + raise ValueError("Return type", ret_type, "contains unsupported type", ty) + + # output_structure has the unflattened structure of outputs according to ret_type + # index_map takes the flattened index to a list of indices indexing into output_structure + output_structure, index_map, num_outputs = _build_index(ret_type, [0], [None], {}) + assert num_outputs == gmodule.get_num_outputs() + def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) # Create map of inputs. @@ -402,13 +428,10 @@ def _graph_wrapper(*args, **kwargs): gmodule.set_input(i, arg) # Run the module, and fetch the output. gmodule.run() - # make a copy so multiple invocation won't hurt perf. - if num_outputs == 1: - return gmodule.get_output(0).copyto(_nd.cpu(0)) - outputs = [] + outputs = copy.deepcopy(output_structure) for i in range(num_outputs): - outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0))) - return outputs + _write_prefix(outputs, index_map[i], gmodule.get_output(i).copyto(_nd.cpu(0))) + return outputs[0] return _graph_wrapper diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 3c42b7b4196f..68708aaeb413 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -209,6 +209,27 @@ def test_compile_nested_tuples(): ref = ref + 1 +def test_graph_executor_nested_tuples(): + x, y, z, w = [relay.var(c, shape=(2, 3), dtype="float32") for c in "xyzw"] + out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])]) + func = relay.Function([x, y, z, w], out) + + exe = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(func), ctx=tvm.cpu(0), target="llvm" + ) + f = exe.evaluate() + + data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"] + out = f(*data) + assert len(out) == 2 + tvm.testing.assert_allclose(out[0].asnumpy(), data[0]) + assert len(out[1]) == 2 + tvm.testing.assert_allclose(out[1][0].asnumpy(), data[1]) + assert len(out[1][1]) == 2 + tvm.testing.assert_allclose(out[1][1][0].asnumpy(), data[2]) + tvm.testing.assert_allclose(out[1][1][1].asnumpy(), data[3]) + + if __name__ == "__main__": test_plan_memory() test_with_params() From edee0f99e7c88f5332b5b0796019b90f37c69025 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 18:51:31 -0800 Subject: [PATCH 2/6] lint --- python/tvm/relay/build_module.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index a5ce7bdaeda6..a8fcbdac2d9d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -401,20 +401,19 @@ def _write_prefix(data, prefix, value): prefix = prefix[1:] data[prefix[0]] = value - def _build_index(ty, prefix, structure, index_map, cur_index=0): - if isinstance(ty, _ty.TensorType): + def _build_index(cur_type, prefix, structure, index_map, cur_index=0): + if isinstance(cur_type, _ty.TensorType): index_map[cur_index] = prefix _write_prefix(structure, prefix, None) return structure, index_map, cur_index + 1 - elif isinstance(ty, _ty.TupleType): - _write_prefix(structure, prefix, [None] * len(ty.fields)) - for i, field_ty in enumerate(ty.fields): + if isinstance(cur_type, _ty.TupleType): + _write_prefix(structure, prefix, [None] * len(cur_type.fields)) + for i, field_type in enumerate(cur_type.fields): structure, index_map, cur_index = _build_index( - field_ty, prefix + [i], structure, index_map, cur_index=cur_index + field_type, prefix + [i], structure, index_map, cur_index=cur_index ) return structure, index_map, cur_index - else: - raise ValueError("Return type", ret_type, "contains unsupported type", ty) + raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) # output_structure has the unflattened structure of outputs according to ret_type # index_map takes the flattened index to a list of indices indexing into output_structure From fba14e411978dfd13103ebdfc57589e7b68337f0 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 19:08:13 -0800 Subject: [PATCH 3/6] cleaner approach, not sure what I was thinking before --- python/tvm/relay/build_module.py | 37 +++++++++++--------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index a8fcbdac2d9d..d81ed10e8cf0 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -395,31 +395,17 @@ def _make_executor(self, expr=None): mod = build(self.mod, target=self.target) gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) - def _write_prefix(data, prefix, value): - while len(prefix) > 1: - data = data[prefix[0]] - prefix = prefix[1:] - data[prefix[0]] = value - - def _build_index(cur_type, prefix, structure, index_map, cur_index=0): + def _unflatten(flattened, cur_type, cur_index): if isinstance(cur_type, _ty.TensorType): - index_map[cur_index] = prefix - _write_prefix(structure, prefix, None) - return structure, index_map, cur_index + 1 + return flattened[cur_index], cur_index + 1 if isinstance(cur_type, _ty.TupleType): - _write_prefix(structure, prefix, [None] * len(cur_type.fields)) - for i, field_type in enumerate(cur_type.fields): - structure, index_map, cur_index = _build_index( - field_type, prefix + [i], structure, index_map, cur_index=cur_index - ) - return structure, index_map, cur_index + fields = [] + for field_type in cur_type.fields: + field, cur_index = _unflatten(flattened, field_type, cur_index) + fields.append(field) + return fields, cur_index raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) - # output_structure has the unflattened structure of outputs according to ret_type - # index_map takes the flattened index to a list of indices indexing into output_structure - output_structure, index_map, num_outputs = _build_index(ret_type, [0], [None], {}) - assert num_outputs == gmodule.get_num_outputs() - def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) # Create map of inputs. @@ -427,10 +413,11 @@ def _graph_wrapper(*args, **kwargs): gmodule.set_input(i, arg) # Run the module, and fetch the output. gmodule.run() - outputs = copy.deepcopy(output_structure) - for i in range(num_outputs): - _write_prefix(outputs, index_map[i], gmodule.get_output(i).copyto(_nd.cpu(0))) - return outputs[0] + flattened = [] + for i in range(gmodule.get_num_outputs()): + flattened.append(gmodule.get_output(i)) + unflattened, _ = _unflatten(flattened, ret_type, 0) + return unflattened return _graph_wrapper From bbd5f2be77ac3c230624760954093da142e56084 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 19:09:04 -0800 Subject: [PATCH 4/6] remove unused import --- python/tvm/relay/build_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d81ed10e8cf0..3e6f215f598d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -19,7 +19,6 @@ from a Relay expression. """ import warnings -import copy import numpy as np from tvm.ir import IRModule From 10e89ce8cf9f0e3cf4b15fde9b2df5e4b2fb4ac4 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 19:11:24 -0800 Subject: [PATCH 5/6] forgot copyto cpu --- python/tvm/relay/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 3e6f215f598d..85bc53df942d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -414,7 +414,7 @@ def _graph_wrapper(*args, **kwargs): gmodule.run() flattened = [] for i in range(gmodule.get_num_outputs()): - flattened.append(gmodule.get_output(i)) + flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0))) unflattened, _ = _unflatten(flattened, ret_type, 0) return unflattened From 4ba2f4a87ff586560633df314644c7cd6bbe1a6c Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 5 Mar 2021 20:56:07 -0800 Subject: [PATCH 6/6] make solution even cleaner using iterator --- python/tvm/relay/build_module.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 85bc53df942d..4c9a898f2374 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -394,15 +394,15 @@ def _make_executor(self, expr=None): mod = build(self.mod, target=self.target) gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) - def _unflatten(flattened, cur_type, cur_index): + def _unflatten(flat_iter, cur_type): if isinstance(cur_type, _ty.TensorType): - return flattened[cur_index], cur_index + 1 + return next(flat_iter) if isinstance(cur_type, _ty.TupleType): fields = [] for field_type in cur_type.fields: - field, cur_index = _unflatten(flattened, field_type, cur_index) + field = _unflatten(flat_iter, field_type) fields.append(field) - return fields, cur_index + return fields raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) def _graph_wrapper(*args, **kwargs): @@ -415,7 +415,7 @@ def _graph_wrapper(*args, **kwargs): flattened = [] for i in range(gmodule.get_num_outputs()): flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0))) - unflattened, _ = _unflatten(flattened, ret_type, 0) + unflattened = _unflatten(iter(flattened), ret_type) return unflattened return _graph_wrapper