From f79145620135ad876c144b679ea3ca1820b4eef9 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 20:34:11 +0000 Subject: [PATCH 1/5] [Torch] Remove unnecessary reshapes for batch_matmul --- python/tvm/relay/frontend/pytorch.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c709e2b4e7bd..924be31ec172 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1094,8 +1094,7 @@ def instance_norm(self, inputs, input_types): data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale ) - @staticmethod - def get_dims(data): + def get_dims(self, data): import torch if isinstance(data, _expr.Expr): @@ -1576,14 +1575,28 @@ def matmul(self, inputs, input_types): # When performing a batch matmul, we need to properly handle N-dim shapes. if len(a_shape) > 2 or len(b_shape) > 2: # Convert a and b into 3 dimensional tensors. - a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) - b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) + need_reshape = False + if len(a_shape) != 3: + a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) + need_reshape = True + else: + a = inputs_0 + + if len(b_shape) != 3: + b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) + need_reshape = True + else: + b = inputs_1 + # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. output = _op.nn.batch_matmul(a, b) # Reshape output to original dimensions. - return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + if need_reshape: + desired_shape = [*a_shape[:-2], a_shape[-2], b_shape[-1]] + return _op.reshape(output, desired_shape) + return output # Otherwise a simple dense op will get the job done. 
if len(b_shape) == 1: From f8402a1ab812ad22457e51b7e292247889060e6a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 20:34:56 +0000 Subject: [PATCH 2/5] lint --- python/tvm/relay/frontend/pytorch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 924be31ec172..559e39249a85 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1594,8 +1594,7 @@ def matmul(self, inputs, input_types): output = _op.nn.batch_matmul(a, b) # Reshape output to original dimensions. if need_reshape: - desired_shape = [*a_shape[:-2], a_shape[-2], b_shape[-1]] - return _op.reshape(output, desired_shape) + return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) return output # Otherwise a simple dense op will get the job done. From 33f4bfefc277006d8810157a8bd29be75de676c5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 21:04:04 +0000 Subject: [PATCH 3/5] fix --- python/tvm/relay/frontend/pytorch.py | 1 - tests/python/frontend/pytorch/test_forward.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 559e39249a85..672f2138d728 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1584,7 +1584,6 @@ def matmul(self, inputs, input_types): if len(b_shape) != 3: b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) - need_reshape = True else: b = inputs_1 diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 83c1698799c7..b23cf672292e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -201,6 +201,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in 
enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + print(relay.transform.InferType()(mod)["main"]) for arg in mod["main"].params[: len(input_names)]: assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) From 3104fd26dc4a9347c61c9f69833bcefb35820ef4 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 21:30:04 +0000 Subject: [PATCH 4/5] reorder --- python/tvm/relay/frontend/pytorch.py | 22 +++++++++++-------- tests/python/frontend/pytorch/test_forward.py | 1 - 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 672f2138d728..fe5ad93ad7dc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1574,25 +1574,29 @@ def matmul(self, inputs, input_types): # When performing a batch matmul, we need to properly handle N-dim shapes. if len(a_shape) > 2 or len(b_shape) > 2: - # Convert a and b into 3 dimensional tensors. - need_reshape = False + # Convert a into a 3 dimensional tensor. + need_reshape_output = False if len(a_shape) != 3: a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) - need_reshape = True + need_reshape_output = True else: a = inputs_0 + # Transpose matrix dimensions of b. + trans_axes = list(range(len(b_shape))) + trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2] + b = _op.transpose(inputs_1, trans_axes) + + # Convert b into a 3 dimensional tensor. Note that the last two dimensions + # are transposed. if len(b_shape) != 3: - b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) - else: - b = inputs_1 + b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]]) - # Transpose matrix dimensions of b. - b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. 
output = _op.nn.batch_matmul(a, b) + # Reshape output to original dimensions. - if need_reshape: + if need_reshape_output: return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) return output diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b23cf672292e..83c1698799c7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -201,7 +201,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - print(relay.transform.InferType()(mod)["main"]) for arg in mod["main"].params[: len(input_names)]: assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) From 219dd76fba1548642a260730e4c5742bc0131fce Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Mar 2021 22:05:47 +0000 Subject: [PATCH 5/5] lint --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fe5ad93ad7dc..fd0a07e35c15 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1581,7 +1581,7 @@ def matmul(self, inputs, input_types): need_reshape_output = True else: a = inputs_0 - + # Transpose matrix dimensions of b. trans_axes = list(range(len(b_shape))) trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]