From 9b9a6ed1ab08861921ebc3c2f9c81baf98079205 Mon Sep 17 00:00:00 2001
From: Mikael Sevenier
Date: Wed, 18 Mar 2020 19:30:05 -0700
Subject: [PATCH 1/9] fix conv transpose import from TF

---
 python/tvm/relay/frontend/tensorflow.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index bed32b7274af..24a1af19a40e 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -270,7 +270,7 @@ def _impl(inputs, attr, params):
                 attr['strides'][3], attr['strides'][1], attr['strides'][2]
             attr['data_format'] = 'NCHW'
 
-        if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+        if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0 and attr['_output_shapes'][0] is not None:
             tmp_shape = attr['_output_shapes'][0]
             tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             attr['_output_shapes'][0] = tmp_shape
@@ -355,7 +355,7 @@ def _impl(inputs, attr, params):
             kernel_h, kernel_w = attr['kernel_shape']
 
         pdata_shape = input_shape
-        if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+        if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0 and attr['_output_shapes'][0] is not None:
             pdata_shape = attr['_output_shapes'][0]
 
     if attr['data_format'] == 'NHWC':
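A note on the patch above: the substance of the fix is the added None guard. TensorFlow graphs can record an `_output_shapes` entry of None when shape inference fails, and the old code then crashed while permuting the shape from NHWC to NCHW. The guarded permutation is easy to see in isolation; the following is a minimal plain-Python sketch with illustrative names, not the actual TVM code path:

    def nhwc_to_nchw(output_shapes):
        # Same reorder as the (0, 3, 1, 2) index permutation above, but
        # skipping unknown (None) shape annotations instead of crashing.
        if len(output_shapes) > 0 and output_shapes[0] is not None:
            n, h, w, c = output_shapes[0]
            output_shapes[0] = [n, c, h, w]
        return output_shapes

    assert nhwc_to_nchw([[1, 10, 10, 3]]) == [[1, 3, 10, 10]]
    assert nhwc_to_nchw([None]) == [None]  # previously a TypeError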
From 1372b43fbfcc1b840091cdbcad80f398b8dbf1e4 Mon Sep 17 00:00:00 2001
From: Mikael Sevenier
Date: Tue, 3 Nov 2020 20:53:15 -0800
Subject: [PATCH 2/9] fix String::fromwe() to String::from()

---
 rust/tvm/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index 7e0682b86b33..27e794984094 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -53,7 +53,7 @@ macro_rules! export {
     ($($fn_name:expr),*) => {
         pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
             $(
-                let name = String::fromwe(ns) + ::std::stringify!($fn_name);
+                let name = String::from(ns) + ::std::stringify!($fn_name);
                 tvm::runtime::function::register_override($fn_name, name, true)?;
             )*
             Ok(())

From 7f0452ede256366f3b7840e85b4aea5a8faa0420 Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Wed, 24 Mar 2021 21:26:05 -0700
Subject: [PATCH 3/9] * fixing pytorch converter to take into account the
 output_padding parameter for conv transpose operations * updating pytorch
 converter to correctly convert conv1d to conv1d in tvm instead of a
 flattened conv2d unless under circumstances of grouped convolution *
 updating pytorch converter to correctly convert conv1d transpose to conv1d
 transpose in tvm instead of a flattened conv2d transpose * added tests to
 cover these latest additions

---
 python/tvm/relay/frontend/pytorch.py          | 45 +++++++++++----
 tests/python/frontend/pytorch/test_forward.py | 57 ++++++++++++++++---
 2 files changed, 82 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb9ea6a043f4..7a0dc9bf5b75 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -970,32 +970,52 @@ def convolution(self, inputs, input_types):
         kernel_size = weight_shape[2:]
         use_bias = isinstance(bias, _expr.Expr)
 
-        if len(kernel_size) == 1:
-            strides = (1,) + strides
-            padding = (0,) + padding
-            dilation = (1,) + dilation
+        # We are trying to invoke various relay operations through a single conv_op variable. However the function
+        # signatures for some operations have additional attributes so we pass these in along with the standard ones.
+        additional_arguments = dict()
 
         if use_transpose:
             if len(kernel_size) == 3:
                 conv_op = _op.nn.conv3d_transpose
-            else:
+            elif len(kernel_size) == 2:
                 conv_op = _op.nn.conv2d_transpose
+            else:
+                conv_op = _op.nn.conv1d_transpose
+            output_padding = tuple(inputs[7])
+            additional_arguments['output_padding'] = output_padding
+
         else:
             if len(kernel_size) == 3:
                 conv_op = _op.nn.conv3d
-            else:
+            elif len(kernel_size) == 2:
                 conv_op = _op.nn.conv2d
+            else:
+                conv_op = _op.nn.conv1d
 
         if len(kernel_size) == 3:
             data_layout = "NCDHW"
             kernel_layout = "OIDHW"
-        else:
+        elif len(kernel_size) == 2:
             data_layout = "NCHW"
             kernel_layout = "OIHW"
-
-        if len(kernel_size) == 1:
+        else:
+            data_layout = "NCW"
+            kernel_layout = "OIW"
+
+        # Conv1d does not currently support grouped convolution so we convert it to conv2d
+        is_grouped_conv1d = False
+        if groups > 1 and len(kernel_size) == 1 and not use_transpose:
+            is_grouped_conv1d = True
+            conv_op = _op.nn.conv2d
+            kernel_size = [1] + kernel_size
+            strides = (1,) + strides
+            padding = (0,) + padding
+            dilation = (1,) + dilation
             data = _op.expand_dims(data, axis=2)
             weight = _op.expand_dims(weight, axis=2)
+            data_layout = "NCHW"
+            kernel_layout = "OIHW"
+
 
         conv_out = conv_op(
             data,
@@ -1005,17 +1025,20 @@ def convolution(self, inputs, input_types):
             dilation=dilation,
             groups=groups,
             channels=channels,
-            kernel_size=[1] + kernel_size if len(kernel_size) == 1 else kernel_size,
+            kernel_size=kernel_size,
             data_layout=data_layout,
             kernel_layout=kernel_layout,
             out_layout="",
             out_dtype="",
+            **additional_arguments
         )
         if use_bias:
             res = _op.nn.bias_add(conv_out, bias)
         else:
             res = conv_out
 
-        if len(kernel_size) == 1:
+        if is_grouped_conv1d:
+            # Because we conducted grouped conv1d convolution through conv2d we must squeeze the output to get the
+            # correct result.
             res = _op.squeeze(res, axis=[2])
         return res
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 572aa472c540..9fd9a817a641 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -31,6 +31,7 @@
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
 from packaging import version as package_version
+import pytest
 
 sys.setrecursionlimit(10000)
 
@@ -207,6 +208,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
 
     with tvm.transform.PassContext(opt_level=3):
         for target, ctx in tvm.testing.enabled_targets():
+            print(target, ctx)
+            print(tvm.testing.enabled_targets())
             relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
             relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
             relay_model.set_input(**relay_params)
@@ -945,17 +948,53 @@ def forward(self, *args):
 
 
 @tvm.testing.uses_gpu
-def test_forward_conv_transpose():
-    torch.set_grad_enabled(False)
-    conv2d_input_shape = [1, 3, 10, 10]
+@pytest.mark.parametrize("in_channels", [3], ids=lambda x: 'in_channels=' + str(x))
+@pytest.mark.parametrize("out_channels", [5], ids=lambda x: 'out_channels=' + str(x))
+@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: 'kernel_size=' + str(x))
+@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: 'output_padding=' + str(x))
+@pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
+@pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
+def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_padding, bias, groups):
+    # Note we do not test withg roups > 1 because that is not supported in tvm for conv transpose operations
+
+    # output padding must be smaller than either stride or dilation so we opt to make the stride 1 + output padding
+    stride = output_padding + 1
+
+    #Conv 3D Transpose Tests
+    conv3d_input_shape = [1, in_channels, 16, 16, 16]
+    conv3d_input_data = torch.rand(conv3d_input_shape).float()
+    conv3d_transpose = torch.nn.ConvTranspose3d(in_channels=in_channels,
+                                                out_channels=out_channels,
+                                                kernel_size=kernel_size,
+                                                stride=stride,
+                                                output_padding=output_padding,
+                                                groups=groups,
+                                                bias=bias).eval()
+    verify_model(conv3d_transpose, conv3d_input_data)
+
+    # Conv 2D Transpose Tests
+    conv2d_input_shape = [1, in_channels, 128, 256]
     conv2d_input_data = torch.rand(conv2d_input_shape).float()
-    verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
-    verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)
-
-    conv1d_input_shape = [1, 3, 10]
+    conv2d_transpose = torch.nn.ConvTranspose2d(in_channels=in_channels,
+                                                out_channels=out_channels,
+                                                kernel_size=kernel_size,
+                                                stride=stride,
+                                                output_padding=output_padding,
+                                                groups=groups,
+                                                bias=bias).eval()
+    verify_model(conv2d_transpose, conv2d_input_data)
+
+    # # Conv 1D Transpose Tests
+    conv1d_input_shape = [1, in_channels, 10]
     conv1d_input_data = torch.rand(conv1d_input_shape).float()
-    verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
-    verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
+    conv1d_transpose = torch.nn.ConvTranspose1d(in_channels=in_channels,
+                                                out_channels=out_channels,
+                                                kernel_size=kernel_size,
+                                                stride=stride,
+                                                output_padding=output_padding,
+                                                groups=groups,
+                                                bias=bias).eval()
+    verify_model(conv1d_transpose, conv1d_input_data)
 
 
 def test_forward_deform_conv():
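Two notes on the converter change above. First, `inputs[7]` is the output_padding slot of the TorchScript `aten::_convolution` call that this converter handles. Second, because Relay's conv1d lacks grouped support at this point, the converter falls back to a conv2d over a dummy height-1 axis and squeezes that axis back out (`axis=[2]`). The equivalence of that fallback is easy to check directly in PyTorch, independent of TVM; shapes and values below are arbitrary test data:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(1, 4, 10)  # NCW input
    conv1d = torch.nn.Conv1d(4, 8, kernel_size=3, groups=2, bias=False)

    # NCW -> NC1W and OIW -> OI1W, run a 2D convolution over the dummy
    # height-1 axis, then squeeze that axis back out.
    out2d = F.conv2d(x.unsqueeze(2), conv1d.weight.unsqueeze(2), groups=2)
    assert torch.allclose(conv1d(x), out2d.squeeze(2), atol=1e-5)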
From 8b47c034cfbbf0fd942f3278798c0455a04ef998 Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Thu, 25 Mar 2021 03:25:18 -0700
Subject: [PATCH 4/9] * removing print statements used for debugging

---
 tests/python/frontend/pytorch/test_forward.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9fd9a817a641..2d72808d4075 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -208,8 +208,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
 
     with tvm.transform.PassContext(opt_level=3):
         for target, ctx in tvm.testing.enabled_targets():
-            print(target, ctx)
-            print(tvm.testing.enabled_targets())
             relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
             relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
             relay_model.set_input(**relay_params)

From 3e68e6aa78b798100a7a359ada6daa4cc845eb0a Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Thu, 25 Mar 2021 13:37:42 -0700
Subject: [PATCH 5/9] * fixing typos and formatting

---
 tests/python/frontend/pytorch/test_forward.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2d72808d4075..3f0029165250 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -953,7 +953,7 @@ def forward(self, *args):
 @pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
 @pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
 def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_padding, bias, groups):
-    # Note we do not test withg roups > 1 because that is not supported in tvm for conv transpose operations
+    # Note we do not test within groups > 1 because that is not supported in tvm for conv transpose operations
 
     # output padding must be smaller than either stride or dilation so we opt to make the stride 1 + output padding
     stride = output_padding + 1
@@ -967,7 +967,8 @@ def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_p
                                                 stride=stride,
                                                 output_padding=output_padding,
                                                 groups=groups,
-                                                bias=bias).eval()
+                                                bias=bias,
+    ).eval()
     verify_model(conv3d_transpose, conv3d_input_data)
 
     # Conv 2D Transpose Tests
@@ -979,7 +980,8 @@ def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_p
                                                 stride=stride,
                                                 output_padding=output_padding,
                                                 groups=groups,
-                                                bias=bias).eval()
+                                                bias=bias,
+    ).eval()
     verify_model(conv2d_transpose, conv2d_input_data)
 
     # # Conv 1D Transpose Tests
@@ -991,7 +993,8 @@ def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_p
                                                 stride=stride,
                                                 output_padding=output_padding,
                                                 groups=groups,
-                                                bias=bias).eval()
+                                                bias=bias,
+    ).eval()
     verify_model(conv1d_transpose, conv1d_input_data)
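On the stride choice in the tests above: PyTorch rejects an output_padding value unless it is smaller than the stride or the dilation, so stride = output_padding + 1 is the smallest stride that keeps every parametrized case legal. The resulting length follows the standard transposed-convolution formula, which a quick standalone check confirms (illustrative numbers only):

    import torch

    # L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel - 1)
    #         + output_padding + 1
    for output_padding in (0, 1, 2):
        stride = output_padding + 1  # keeps output_padding < stride
        op = torch.nn.ConvTranspose1d(3, 5, kernel_size=3, stride=stride,
                                      output_padding=output_padding)
        out_len = op(torch.randn(1, 3, 10)).shape[-1]
        assert out_len == (10 - 1) * stride + (3 - 1) + output_padding + 1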
From 60c4d757e2a6a44fb2ccbefcba89b346b39c1fb5 Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Thu, 25 Mar 2021 13:41:55 -0700
Subject: [PATCH 6/9] * fixing formatting

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7a0dc9bf5b75..44e677775ca3 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1030,7 +1030,7 @@ def convolution(self, inputs, input_types):
             kernel_layout=kernel_layout,
             out_layout="",
             out_dtype="",
-            **additional_arguments
+            **additional_arguments,
         )
         if use_bias:
             res = _op.nn.bias_add(conv_out, bias)

From b09aff420de2badf0e09811ad993146eed70458c Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Thu, 25 Mar 2021 13:44:13 -0700
Subject: [PATCH 7/9] * fixing grammar

---
 tests/python/frontend/pytorch/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3f0029165250..cce656cc895e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -953,7 +953,7 @@ def forward(self, *args):
 @pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
 @pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
 def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_padding, bias, groups):
-    # Note we do not test within groups > 1 because that is not supported in tvm for conv transpose operations
+    # Note we do not test with groups > 1 because that is not supported in tvm for conv transpose operations
 
     # output padding must be smaller than either stride or dilation so we opt to make the stride 1 + output padding
     stride = output_padding + 1
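PATCH 6 only adds a trailing comma after the `**additional_arguments` splat, which is valid in a Python 3 call and matches what black produces. The underlying pattern, one `conv_op` callable plus a dict of operator-specific keyword arguments shared across five Relay operators, can be shown in miniature; the stand-in functions below are hypothetical, for illustration only:

    def conv(data, kernel_size):
        return ("conv", data, kernel_size)

    def conv_transpose(data, kernel_size, output_padding=0):
        return ("conv_transpose", data, kernel_size, output_padding)

    def build(data, kernel_size, use_transpose):
        additional_arguments = dict()
        conv_op = conv
        if use_transpose:
            conv_op = conv_transpose
            additional_arguments["output_padding"] = 1
        # One shared call site; extra kwargs only reach the ops whose
        # signatures accept them.
        return conv_op(
            data,
            kernel_size,
            **additional_arguments,
        )

    assert build("x", 3, True) == ("conv_transpose", "x", 3, 1)
    assert build("x", 3, False) == ("conv", "x", 3)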
From 17ab1c2b550b26b9e645375c82e11353f1eb0b81 Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Tue, 13 Apr 2021 17:26:55 -0700
Subject: [PATCH 8/9] * formatting fixes

---
 python/tvm/relay/frontend/pytorch.py          | 10 +++++-----
 tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 44e677775ca3..8834770f7cf1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -970,8 +970,9 @@ def convolution(self, inputs, input_types):
         kernel_size = weight_shape[2:]
         use_bias = isinstance(bias, _expr.Expr)
 
-        # We are trying to invoke various relay operations through a single conv_op variable. However the function
-        # signatures for some operations have additional attributes so we pass these in along with the standard ones.
+        # We are trying to invoke various relay operations through a single conv_op variable.
+        # However the function signatures for some operations have additional attributes so we
+        # pass these in along with the standard ones.
         additional_arguments = dict()
 
         if use_transpose:
@@ -1016,7 +1017,6 @@ def convolution(self, inputs, input_types):
             data_layout = "NCHW"
             kernel_layout = "OIHW"
-
 
         conv_out = conv_op(
             data,
             weight,
@@ -1037,8 +1037,8 @@ def convolution(self, inputs, input_types):
         else:
             res = conv_out
 
         if is_grouped_conv1d:
-            # Because we conducted grouped conv1d convolution through conv2d we must squeeze the output to get the
-            # correct result.
+            # Because we conducted grouped conv1d convolution through conv2d we must
+            # squeeze the output to get the correct result.
             res = _op.squeeze(res, axis=[2])
         return res
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index cce656cc895e..2e0afaa73a30 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -952,10 +952,17 @@ def forward(self, *args):
 @pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: 'output_padding=' + str(x))
 @pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
 @pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
-def test_forward_conv_transpose(in_channels, out_channels, kernel_size, output_padding, bias, groups):
-    # Note we do not test with groups > 1 because that is not supported in tvm for conv transpose operations
-
-    # output padding must be smaller than either stride or dilation so we opt to make the stride 1 + output padding
+def test_forward_conv_transpose(in_channels,
+                                out_channels,
+                                kernel_size,
+                                output_padding,
+                                bias,
+                                groups):
+    # Note we do not test with groups > 1 because that is not supported
+    # in tvm for conv transpose operations
+
+    # Output padding must be smaller than either stride or dilation so we
+    # opt to make the stride 1 + output padding
     stride = output_padding + 1
 
     #Conv 3D Transpose Tests
From b298cad8264700e5afd69258615326b0e3eeaa32 Mon Sep 17 00:00:00 2001
From: Jeffrey Spitz
Date: Tue, 27 Apr 2021 09:12:02 -0700
Subject: [PATCH 9/9] * updated formatting after running pylint and
 python_format checks

---
 python/tvm/relay/frontend/pytorch.py          |  2 +-
 tests/python/frontend/pytorch/test_forward.py | 74 +++++++++----------
 2 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 11fcd5d00e9f..1ce0f78c6419 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -983,7 +983,7 @@ def convolution(self, inputs, input_types):
             else:
                 conv_op = _op.nn.conv1d_transpose
             output_padding = tuple(inputs[7])
-            additional_arguments['output_padding'] = output_padding
+            additional_arguments["output_padding"] = output_padding
 
         else:
             if len(kernel_size) == 3:
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 71fcb0a2cfda..971539564b9c 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -962,18 +962,15 @@ def forward(self, *args):
 
 
 @tvm.testing.uses_gpu
-@pytest.mark.parametrize("in_channels", [3], ids=lambda x: 'in_channels=' + str(x))
-@pytest.mark.parametrize("out_channels", [5], ids=lambda x: 'out_channels=' + str(x))
-@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: 'kernel_size=' + str(x))
-@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: 'output_padding=' + str(x))
-@pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
-@pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
-def test_forward_conv_transpose(in_channels,
-                                out_channels,
-                                kernel_size,
-                                output_padding,
-                                bias,
-                                groups):
+@pytest.mark.parametrize("in_channels", [3], ids=lambda x: "in_channels=" + str(x))
+@pytest.mark.parametrize("out_channels", [5], ids=lambda x: "out_channels=" + str(x))
+@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: "kernel_size=" + str(x))
+@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: "output_padding=" + str(x))
+@pytest.mark.parametrize("groups", [1], ids=lambda x: "groups=" + str(x))
+@pytest.mark.parametrize("bias", [True, False], ids=lambda x: "bias=" + str(x))
+def test_forward_conv_transpose(
+    in_channels, out_channels, kernel_size, output_padding, bias, groups
+):
     # Note we do not test with groups > 1 because that is not supported
     # in tvm for conv transpose operations
 
@@ -981,43 +978,46 @@ def test_forward_conv_transpose(in_channels,
     # opt to make the stride 1 + output padding
     stride = output_padding + 1
 
-    #Conv 3D Transpose Tests
+    # Conv 3D Transpose Tests
     conv3d_input_shape = [1, in_channels, 16, 16, 16]
     conv3d_input_data = torch.rand(conv3d_input_shape).float()
-    conv3d_transpose = torch.nn.ConvTranspose3d(in_channels=in_channels,
-                                                out_channels=out_channels,
-                                                kernel_size=kernel_size,
-                                                stride=stride,
-                                                output_padding=output_padding,
-                                                groups=groups,
-                                                bias=bias,
-    ).eval()
+    conv3d_transpose = torch.nn.ConvTranspose3d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
     verify_model(conv3d_transpose, conv3d_input_data)
 
     # Conv 2D Transpose Tests
     conv2d_input_shape = [1, in_channels, 128, 256]
     conv2d_input_data = torch.rand(conv2d_input_shape).float()
-    conv2d_transpose = torch.nn.ConvTranspose2d(in_channels=in_channels,
-                                                out_channels=out_channels,
-                                                kernel_size=kernel_size,
-                                                stride=stride,
-                                                output_padding=output_padding,
-                                                groups=groups,
-                                                bias=bias,
-    ).eval()
+    conv2d_transpose = torch.nn.ConvTranspose2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
     verify_model(conv2d_transpose, conv2d_input_data)
 
     # # Conv 1D Transpose Tests
     conv1d_input_shape = [1, in_channels, 10]
     conv1d_input_data = torch.rand(conv1d_input_shape).float()
-    conv1d_transpose = torch.nn.ConvTranspose1d(in_channels=in_channels,
-                                                out_channels=out_channels,
-                                                kernel_size=kernel_size,
-                                                stride=stride,
-                                                output_padding=output_padding,
-                                                groups=groups,
-                                                bias=bias,
-    ).eval()
+    conv1d_transpose = torch.nn.ConvTranspose1d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
     verify_model(conv1d_transpose, conv1d_input_data)
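A closing note on the parametrized tests: the `ids=` callbacks only control how each generated test case is labelled, not what runs. A minimal standalone example of the same hook, runnable under pytest (the test name here is hypothetical):

    import pytest

    @pytest.mark.parametrize("output_padding", [0, 1, 2],
                             ids=lambda x: "output_padding=" + str(x))
    def test_stride_stays_legal(output_padding):
        stride = output_padding + 1
        assert output_padding < stride  # the constraint the real test relies on

    # Collecting this with pytest --collect-only -q would list readable
    # node IDs such as:
    #   test_stride_stays_legal[output_padding=0]
    #   test_stride_stays_legal[output_padding=1]
    #   test_stride_stays_legal[output_padding=2]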