From c5ef5dda33f3c43a8b53ec76533a87c0f51cd8f4 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 13 Feb 2022 15:37:53 +0900
Subject: [PATCH 1/4] [Torch] Fix conv2d transpose with group

---
 python/tvm/relay/frontend/pytorch.py          | 14 ++++++----
 tests/python/frontend/pytorch/test_forward.py | 28 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 61478219908c..d534166481e5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -972,19 +972,21 @@ def convolution(self, inputs, input_types):
             msg = "Data type %s could not be parsed in conv op" % (type(weight))
             raise AssertionError(msg)
 
-        # Transposed convolutions have IOHW layout.
-        if use_transpose:
-            weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
-
-        channels = weight_shape[0]
         groups = int(inputs[8])
 
+        if use_transpose:
+            channels = weight_shape[1] * groups
+            in_channels = weight_shape[0]
+        else:
+            channels = weight_shape[0]
+            in_channels = weight_shape[1]
+
         # Check if this is depth wise convolution
         # We need to reshape weight so that Relay could recognize this is depth wise
         # weight_shape[1] is always in_channels // groups
         # For depthwise, in_channels == groups, so weight_shape[1] == 1
         # If groups > 1 but weight_shape[1] != 1, this is group convolution
-        if groups > 1 and weight_shape[1] == 1:
+        if groups > 1 and in_channels == 1:
             channel_multiplier = channels // groups
             new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
             weight = _op.transform.reshape(weight, new_weight_shape)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 97ef08f7b8a9..7b489a8ff534 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1067,6 +1067,32 @@ def test_forward_conv_transpose(
     verify_model(conv1d_transpose, conv1d_input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_conv2d_transpose_group():
+    # https://github.com/apache/tvm/pull/9465
+
+    class ModulatedConvTranspose2D(torch.nn.Module):
+        def forward(self, x, w, s):
+            B, C, H, W = x.shape
+            I, O, KH, KW = w.shape
+
+            # weight is different for each input in batch (this is why we want grouped conv transpose)
+            w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
+            w = w.reshape(B * I, O, KH, KW)
+            x = x.reshape(1, B * C, H, W)
+            x = torch.nn.functional.conv_transpose2d(
+                x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B
+            )
+            return x.reshape(B, O, H * 2, W * 2)
+
+    b, c, h, w, k = 4, 512, 8, 16, 3
+    inputs = torch.rand(b, c, h, w)
+    weights = torch.rand(c, c // 2, k, k)
+    styles = torch.rand(b)
+
+    verify_model(ModulatedConvTranspose2D().eval(), [inputs, weights, styles])
+
+
 def test_forward_deform_conv():
     torch.set_grad_enabled(False)
 
@@ -2148,7 +2174,7 @@ def test_vgg11_bn():
 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
-        n_channels = 2 * (pool_size ** 2)
+        n_channels = 2 * (pool_size**2)
         x = torch.rand(2, n_channels, 10, 10)
         rois = torch.tensor(
             [

From a7e08fb6bcb5dcdee49a271eef78c9364392f5bf Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 13 Feb 2022 16:07:32 +0900
Subject: [PATCH 2/4] lint

---
 tests/python/frontend/pytorch/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 7b489a8ff534..bc7c224a6cf7 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2174,7 +2174,7 @@ def test_vgg11_bn():
 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
-        n_channels = 2 * (pool_size**2)
+        n_channels = 2 * (pool_size ** 2)
         x = torch.rand(2, n_channels, 10, 10)
         rois = torch.tensor(
             [

From 671b7bdfd96f365ca666fd2f425a4859251b5c3c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 13 Feb 2022 16:19:55 +0900
Subject: [PATCH 3/4] wrong issue number

---
 tests/python/frontend/pytorch/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index bc7c224a6cf7..65dcb774f832 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1069,7 +1069,7 @@ def test_forward_conv_transpose(
 
 @tvm.testing.uses_gpu
 def test_forward_conv2d_transpose_group():
-    # https://github.com/apache/tvm/pull/9465
+    # https://github.com/apache/tvm/issues/10223
 
     class ModulatedConvTranspose2D(torch.nn.Module):
         def forward(self, x, w, s):

From 15ec776958a61838d850abb57b487b6548599af9 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 13 Feb 2022 20:35:35 +0900
Subject: [PATCH 4/4] do not run test on cuda

---
 tests/python/frontend/pytorch/test_forward.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 65dcb774f832..c240a19c9730 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1067,7 +1067,6 @@ def test_forward_conv_transpose(
     verify_model(conv1d_transpose, conv1d_input_data)
 
 
-@tvm.testing.uses_gpu
 def test_forward_conv2d_transpose_group():
     # https://github.com/apache/tvm/issues/10223
 
@@ -1090,7 +1089,8 @@ def forward(self, x, w, s):
     weights = torch.rand(c, c // 2, k, k)
     styles = torch.rand(b)
 
-    verify_model(ModulatedConvTranspose2D().eval(), [inputs, weights, styles])
+    # cuda not supported for group > 1 conv2d_transpose
+    verify_trace_model(ModulatedConvTranspose2D().eval(), [inputs, weights, styles], ["llvm"])
 
 
 def test_forward_deform_conv():
@@ -4141,7 +4141,7 @@ def forward(self, x):
 
     x = torch.rand([4, 4, 16, 32]).float()
     script_module = torch.jit.trace(List_tuple(), x, strict=False).eval()
-    mod, params = relay.frontend.from_pytorch(script_module, [("x", x.shape)])
+    relay.frontend.from_pytorch(script_module, [("x", x.shape)])
 
 
 if __name__ == "__main__":
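
Note on the weight layout that PATCH 1/4 relies on (a minimal sketch, not part of the patches, assuming only a stock torch install): PyTorch stores transposed-convolution kernels as (in_channels, out_channels // groups, kH, kW), so the converter now derives the output-channel count as weight_shape[1] * groups and the input-channel count as weight_shape[0], instead of simply swapping the first two dimensions.

import torch

# ConvTranspose2d weight layout is (in_channels, out_channels // groups, kH, kW).
conv = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3, groups=4)
print(conv.weight.shape)                    # torch.Size([8, 4, 3, 3])
# Recover out_channels the way the fixed converter does: weight_shape[1] * groups.
print(conv.weight.shape[1] * conv.groups)   # 16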