From bab6a315eef04c737eed1f8c52536113281c04cb Mon Sep 17 00:00:00 2001 From: jonghewk Date: Fri, 15 Dec 2023 16:39:29 +0900 Subject: [PATCH 1/5] Merge commit '803c4ad0847f492491d8714c7ab6f52c679e6431' --- python/tvm/relay/op/strategy/generic.py | 26 +++-- python/tvm/topi/generic/nn.py | 16 +++ python/tvm/topi/nn/conv1d_transpose.py | 95 +++++++++++++++ python/tvm/topi/testing/__init__.py | 5 +- .../testing/conv1d_transpose_ncw_python.py | 12 ++ .../test_topi_group_conv1d_transpose_ncw.py | 109 ++++++++++++++++++ 6 files changed, 254 insertions(+), 9 deletions(-) create mode 100644 tests/python/topi/test_topi_group_conv1d_transpose_ncw.py diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index ce82c6bd6fd2..5020d135a7a9 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -713,7 +713,7 @@ def group_conv1d_strategy(attrs, inputs, out_type, target): # conv1d_transpose -def wrap_compute_conv1d_transpose(topi_compute): +def wrap_compute_conv1d_transpose(topi_compute, has_groups=False): """wrap conv1d_transpose topi compute""" def _compute_conv1d_tranpsoe(attrs, inputs, out_type): @@ -722,7 +722,11 @@ def _compute_conv1d_tranpsoe(attrs, inputs, out_type): out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype output_padding = get_const_tuple(attrs.output_padding) - out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding) + args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding] + if has_groups: + args.append(attrs.groups) + + out = topi_compute(*args) return [out] return _compute_conv1d_tranpsoe @@ -738,12 +742,18 @@ def conv1d_transpose_strategy(attrs, inputs, out_type, target): groups = attrs.groups assert layout == "NCW", "conv1d_transpose ncw only supported" assert dilation == (1,), "conv1d_transpose dilation is not supported" - assert groups == 1, "conv1d_transpose groups == 1 only supported" - strategy.add_implementation( - wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw), - wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw), - name="conv1d_transpose_ncw.generic", - ) + if groups == 1: + strategy.add_implementation( + wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw), + wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw), + name="conv1d_transpose_ncw.generic", + ) + else: # group_conv1d_transpose + strategy.add_implementation( + wrap_compute_conv1d_transpose(topi.nn.group_conv1d_transpose_ncw, has_groups=True), + wrap_topi_schedule(topi.generic.schedule_group_conv1d_transpose_ncw), + name="group_conv1d_transpose_ncw.generic", + ) return strategy diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 80ea00ab0153..386545f27e82 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -395,6 +395,22 @@ def schedule_conv1d_transpose_ncw(outs): """ return _default_schedule(outs, False) +def schedule_group_conv1d_transpose_ncw(outs): + """Schedule for group_conv1d_transpose_ncw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of group conv1d_transpose_ncw + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
+    """
+    return _default_schedule(outs, False)
+
 
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d_nchw
diff --git a/python/tvm/topi/nn/conv1d_transpose.py b/python/tvm/topi/nn/conv1d_transpose.py
index 6f040409f47c..de7cd80e9f3a 100644
--- a/python/tvm/topi/nn/conv1d_transpose.py
+++ b/python/tvm/topi/nn/conv1d_transpose.py
@@ -91,3 +91,98 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_paddin
     )
 
     return output
+
+
+def group_conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding, groups):
+    """Transposed 1D group convolution ncw forward operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        3-D with shape [batch, in_channel, in_width]
+
+    kernel : tvm.te.Tensor
+        3-D with shape [in_channel, num_filter, filter_width]
+
+    stride : ints
+        The spatial stride along width
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    out_dtype : str
+        The output data type. This is used for mixed precision.
+
+    output_padding : ints
+        Used to recover the actual output shape in case there are more
+        than one possible shape. Must be smaller than stride.
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, out_channel, out_width]
+
+    """
+    if groups == 1:
+        return conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding)
+
+    # some pre-processing and preliminary checks
+    if out_dtype is None:
+        out_dtype = data.dtype
+
+    # dilate and pad
+    if isinstance(stride, (tuple, list)):
+        stride = stride[0]
+    if isinstance(output_padding, (tuple, list)):
+        output_padding = output_padding[0]
+
+    batch, in_channels, in_w = data.shape
+    _, out_c, filter_w = kernel.shape
+    assert in_channels % groups == 0, f"input channels {in_channels} must be divisible by groups {groups}"
+
+    batch, channels_in, data_width = data.shape
+    _, channels_out, kernel_width = kernel.shape
+    assert output_padding < stride
+    channels_out = simplify(channels_out)
+    data_dilate = dilate(data, [1, 1, stride], name="data_dilate")
+    pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
+    pad_left = kernel_width - 1 - pad_left
+    pad_right = kernel_width - 1 - pad_right + output_padding
+    data_pad = pad(data_dilate, [0, 0, pad_left], [0, 0, pad_right], name="data_pad")
+
+    # transform kernel layout from IOW to OIW, and rotate kernel by 180 degrees
+    kernel = te.compute(
+        (channels_out, channels_in, kernel_width),
+        lambda o, i, w: kernel[i][o][kernel_width - 1 - w],
+        name="kernel",
+    )
+
+    batch, in_channels, in_w = data_pad.shape
+    out_c, _, filter_w = kernel.shape
+
+    # convolution stage
+    out_channels = simplify(out_c * groups)
+    out_w = simplify(in_w - filter_w + 1)
+    dc = te.reduce_axis((0, in_channels // groups), name="dc")
+    dw = te.reduce_axis((0, filter_w), name="dw")
+
+    # data: batch, in_channels, out_w
+    # weight: out_channels // G, in_channels, out_w
+    return te.compute(
+        (batch, out_channels, out_w),
+        lambda b, c, w: te.sum(
+            data_pad[b, c // (out_channels // groups) * (in_channels // groups) + dc, w + dw].astype(
+                out_dtype
+            )
+            * kernel[
+                c % (out_channels // groups),
+                c // (out_channels // groups) * (in_channels // groups) + dc,
+                dw,
+            ].astype(out_dtype),
+            axis=[dc, dw],
+        ),
+        tag="group_conv1d_transpose_ncw",
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 093f84d99bd3..72a7cedc491c 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -29,7 +29,10 @@ from .conv3d_ndhwc_python import conv3d_ndhwc_python
 from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python
 from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
-from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
+from .conv1d_transpose_ncw_python import (
+    conv1d_transpose_ncw_python,
+    group_conv1d_transpose_ncw_python,
+)
 from .correlation_nchw_python import correlation_nchw_python
 from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python
 from .depthwise_conv2d_python import (
diff --git a/python/tvm/topi/testing/conv1d_transpose_ncw_python.py b/python/tvm/topi/testing/conv1d_transpose_ncw_python.py
index 85e1410c0cd8..b10c90aeff79 100644
--- a/python/tvm/topi/testing/conv1d_transpose_ncw_python.py
+++ b/python/tvm/topi/testing/conv1d_transpose_ncw_python.py
@@ -22,6 +22,18 @@ from tvm.topi.nn.utils import get_pad_tuple1d
 
 
+def group_conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding, groups=1):
+    "Grouped version of `conv1d_transpose_ncw_python`; see that function for documentation."
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [
+        conv1d_transpose_ncw_python(a_slice, w_slice, stride, padding, output_padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
+
+
 def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 1D convolution operator in NCW layout.
diff --git a/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py b/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py
new file mode 100644
index 000000000000..7c42426d58cf
--- /dev/null
+++ b/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for grouped transposed 1D convolution."""
+
+import itertools
+import os
+
+import numpy as np
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+
+from tvm import te, topi
+from tvm.topi.utils import get_const_tuple
+
+_group_conv1d_transpose_ncw_implement = {
+    "generic": (
+        topi.nn.group_conv1d_transpose_ncw,
+        topi.generic.schedule_group_conv1d_transpose_ncw,
+    ),
+}
+
+
+(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    output_padding,
+    groups,
+) = tvm.testing.parameters(
+    (1, 4, 224, 32, 5, 1, 0, (0,), 4),
+    (1, 8, 224, 32, 7, 1, 2, (0,), 4),
+    (1, 8, 224, 32, 5, 2, 1, (0,), 2),
+    (1, 4, 224, 4, 5, 2, 1, (1,), 4),
+    (1, 3, 224, 15, 5, 2, 0, (0,), 3),
+    (1, 32, 32, 128, 5, 1, 0, (0,), 32),
+    (1, 32, 32, 128, 5, 2, 1, (0,), 16),
+)
+
+dtype = tvm.testing.parameter("float32")
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def ref_data(
+    dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding, groups
+):
+    dtype = "float32"
+    a_shape = (batch, in_channel, in_size)
+    w_shape = (in_channel, num_filter, kernel)
+
+    a_np = np.random.uniform(size=a_shape).astype(dtype)
+    w_np = np.random.uniform(size=w_shape).astype(dtype)
+    b_np = tvm.topi.testing.group_conv1d_transpose_ncw_python(
+        a_np, w_np, stride, padding, output_padding, groups
+    )
+    c_np = np.maximum(b_np, 0)
+    return a_np, w_np, b_np, c_np
+
+
+def test_group_conv1d_transpose_ncw(
+    target, dev, ref_data, dtype, stride, padding, output_padding, groups
+):
+    a_np, w_np, b_np, c_np = ref_data
+
+    A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+    W = te.placeholder(w_np.shape, name="W", dtype=dtype)
+
+    with tvm.target.Target(target):
+        fcompute, fschedule = tvm.topi.testing.dispatch(
+            target, _group_conv1d_transpose_ncw_implement
+        )
+        B = fcompute(A, W, stride, padding, A.dtype, output_padding, groups)
+        C = topi.nn.relu(B)
+        s1 = fschedule([B])
+        s2 = fschedule([C])
+    a = tvm.nd.array(a_np, dev)
+    w = tvm.nd.array(w_np, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+
+    func1 = tvm.build(s1, [A, W, B], target)
+    func2 = tvm.build(s2, [A, W, C], target)
+    func1(a, w, b)
+    func2(a, w, c)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

From c1f85626890248fa39a25ca16fa5d1b2b56538f5 Mon Sep 17 00:00:00 2001
From: jonghewk
Date: Fri, 15 Dec 2023 16:52:46 +0900
Subject: [PATCH 2/5] apply black format

---
 python/tvm/topi/generic/nn.py          |  1 +
 python/tvm/topi/nn/conv1d_transpose.py | 10 ++++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 386545f27e82..ca90d2b7175d 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -395,6 +395,7 @@ def schedule_conv1d_transpose_ncw(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_group_conv1d_transpose_ncw(outs):
     """Schedule for group_conv1d_transpose_ncw
 
diff --git a/python/tvm/topi/nn/conv1d_transpose.py b/python/tvm/topi/nn/conv1d_transpose.py
index de7cd80e9f3a..a14254de8fc7 100644
--- a/python/tvm/topi/nn/conv1d_transpose.py
+++ b/python/tvm/topi/nn/conv1d_transpose.py
@@ -141,7 +141,9 @@ def group_conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_
 
     batch, in_channels, in_w = data.shape
     _, out_c, filter_w = kernel.shape
-    assert in_channels % groups == 0, f"input channels {in_channels} must be divisible by groups {groups}"
+    assert (
+        in_channels % groups == 0
+    ), f"input channels {in_channels} must be divisible by groups {groups}"
 
     batch, channels_in, data_width = data.shape
     _, channels_out, kernel_width = kernel.shape
@@ -174,9 +176,9 @@ def group_conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_
     return te.compute(
         (batch, out_channels, out_w),
         lambda b, c, w: te.sum(
-            data_pad[b, c // (out_channels // groups) * (in_channels // groups) + dc, w + dw].astype(
-                out_dtype
-            )
+            data_pad[
+                b, c // (out_channels // groups) * (in_channels // groups) + dc, w + dw
+            ].astype(out_dtype)
             * kernel[

From 91b02b2f809288f1b2003f30fdf1a880646a729e Mon Sep 17 00:00:00 2001
From: jonghewk
Date: Sat, 16 Dec 2023 15:25:04 +0900
Subject: [PATCH 3/5] skip test for cuda (unimplemented)

---
 tests/python/topi/test_topi_group_conv1d_transpose_ncw.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py b/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py
index 7c42426d58cf..b612c13f9b59 100644
--- a/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py
+++ b/tests/python/topi/test_topi_group_conv1d_transpose_ncw.py
@@ -76,6 +76,7 @@ def ref_data(
     return a_np, w_np, b_np, c_np
 
 
+@tvm.testing.known_failing_targets("cuda", "vulkan")
 def test_group_conv1d_transpose_ncw(
     target, dev, ref_data, dtype, stride, padding, output_padding, groups
 ):

From 16032c260f98924cc8115179943dfc35bee537d5 Mon Sep 17 00:00:00 2001
From: jonghewk
Date: Mon, 18 Dec 2023 11:10:50 +0900
Subject: [PATCH 4/5] avoid code duplication for conv1d_transpose_ncw

---
 python/tvm/topi/nn/conv1d_transpose.py | 103 ++++++++++++++++---------
 1 file changed, 65 insertions(+), 38 deletions(-)

diff --git a/python/tvm/topi/nn/conv1d_transpose.py b/python/tvm/topi/nn/conv1d_transpose.py
index a14254de8fc7..1fe25610edd0 100644
--- a/python/tvm/topi/nn/conv1d_transpose.py
+++ b/python/tvm/topi/nn/conv1d_transpose.py
@@ -23,8 +23,8 @@ from .utils import get_pad_tuple1d
 
 
-def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding):
-    """Transposed 1D convolution ncw forward operator.
+def _conv1d_transpose_ncw_prepare(data, kernel, stride, padding, out_dtype, output_padding):
+    """Prepare for transposed 1D convolution ncw forward operator.
 
     Parameters
     ----------
@@ -49,42 +49,89 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_paddin
 
     Returns
     -------
-    output : tvm.te.Tensor
-        3-D with shape [batch, out_channel, out_width]
+    data_pad : tvm.te.Tensor
+        Padded input data. 3-D with shape [batch, in_channel, in_width]
+    kernel : tvm.te.Tensor
+        Transformed kernel. 3-D with shape [num_filter, in_channel, filter_width]
 
     """
+    # some pre-processing and preliminary checks
+    if out_dtype is None:
+        out_dtype = data.dtype
 
     # dilate and pad
     if isinstance(stride, (tuple, list)):
         stride = stride[0]
     if isinstance(output_padding, (tuple, list)):
         output_padding = output_padding[0]
-    batch, channels_in, data_width = data.shape
+
+    _, channels_in, _ = data.shape
     _, channels_out, kernel_width = kernel.shape
     assert output_padding < stride
     channels_out = simplify(channels_out)
-    data = dilate(data, [1, 1, stride], name="data_dilate")
+    data_dilate = dilate(data, [1, 1, stride], name="data_dilate")
     pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right + output_padding
-    data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name="data_pad")
+    data_pad = pad(data_dilate, [0, 0, pad_left], [0, 0, pad_right], name="data_pad")
 
-    # transpose kernel, switch kernel layout to IOW
+    # transform kernel layout from IOW to OIW, and rotate kernel by 180 degrees
     kernel = te.compute(
         (channels_out, channels_in, kernel_width),
         lambda o, i, w: kernel[i][o][kernel_width - 1 - w],
         name="kernel",
     )
+    return data_pad, kernel
+
+
+def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding):
+    """Transposed 1D convolution ncw forward operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        3-D with shape [batch, in_channel, in_width]
+
+    kernel : tvm.te.Tensor
+        3-D with shape [in_channel, num_filter, filter_width]
+
+    stride : ints
+        The spatial stride along width
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    out_dtype : str
+        The output data type. This is used for mixed precision.
+
+    output_padding : ints
+        Used to recover the actual output shape in case there are more
+        than one possible shape. Must be smaller than stride.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, out_channel, out_width]
+
+    """
+
+    batch, channels_in, _ = data.shape
+    _, channels_out, kernel_width = kernel.shape
+
+    data_pad, transformed_kernel = _conv1d_transpose_ncw_prepare(
+        data, kernel, stride, padding, out_dtype, output_padding
+    )
 
     # convolution
-    _, _, data_width = data.shape
+    _, _, data_width = data_pad.shape
     out_w = simplify(data_width - kernel_width + 1)
     dc = te.reduce_axis((0, channels_in), name="dc")
     dw = te.reduce_axis((0, kernel_width), name="dw")
     output = te.compute(
         (batch, channels_out, out_w),
         lambda b, c, w: te.sum(
-            data[b, dc, w + dw].astype(out_dtype) * kernel[c, dc, dw].astype(out_dtype),
+            data_pad[b, dc, w + dw].astype(out_dtype)
+            * transformed_kernel[c, dc, dw].astype(out_dtype),
             axis=[dc, dw],
         ),
         tag="conv1d_transpose_ncw",
@@ -129,41 +176,21 @@ def group_conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_
     if groups == 1:
         return conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding)
 
-    # some pre-processing and preliminary checks
-    if out_dtype is None:
-        out_dtype = data.dtype
-
-    # dilate and pad
-    if isinstance(stride, (tuple, list)):
-        stride = stride[0]
-    if isinstance(output_padding, (tuple, list)):
-        output_padding = output_padding[0]
-
-    batch, in_channels, in_w = data.shape
-    _, out_c, filter_w = kernel.shape
+    _, in_channels, _ = data.shape
     assert (
         in_channels % groups == 0
     ), f"input channels {in_channels} must be divisible by groups {groups}"
 
-    batch, channels_in, data_width = data.shape
-    _, channels_out, kernel_width = kernel.shape
-    assert output_padding < stride
-    channels_out = simplify(channels_out)
-    data_dilate = dilate(data, [1, 1, stride], name="data_dilate")
-    pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
-    pad_left = kernel_width - 1 - pad_left
-    pad_right = kernel_width - 1 - pad_right + output_padding
-    data_pad = pad(data_dilate, [0, 0, pad_left], [0, 0, pad_right], name="data_pad")
+    assert (
+        in_channels % groups == 0
+    ), f"input channels {in_channels} must be divisible by groups {groups}"
 
-    # transform kernel layout from IOW to OIW, and rotate kernel by 180 degrees
-    kernel = te.compute(
-        (channels_out, channels_in, kernel_width),
-        lambda o, i, w: kernel[i][o][kernel_width - 1 - w],
-        name="kernel",
+    data_pad, transformed_kernel = _conv1d_transpose_ncw_prepare(
+        data, kernel, stride, padding, out_dtype, output_padding
     )
 
     batch, in_channels, in_w = data_pad.shape
-    out_c, _, filter_w = kernel.shape
+    out_c, _, filter_w = transformed_kernel.shape
 

From 58e77eefc828fde89a5cefdfc96bd7a356eda306 Mon Sep 17 00:00:00 2001
From: jonghewk
Date: Mon, 18 Dec 2023 11:44:06 +0900
Subject: [PATCH 5/5] rename func & cleanup

---
 python/tvm/topi/nn/conv1d_transpose.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/python/tvm/topi/nn/conv1d_transpose.py b/python/tvm/topi/nn/conv1d_transpose.py
index 1fe25610edd0..19872c544731 100644
--- a/python/tvm/topi/nn/conv1d_transpose.py
+++ b/python/tvm/topi/nn/conv1d_transpose.py
@@ -23,8 +23,9 @@ from .utils import get_pad_tuple1d
 
 
-def _conv1d_transpose_ncw_prepare(data, kernel, stride, padding, out_dtype, output_padding):
-    """Prepare for transposed 1D convolution ncw forward operator.
+def _conv1d_transpose_ncw_preprocess(data, kernel, stride, padding, out_dtype, output_padding):
+    """Preprocess data and kernel to make the compute pattern
+    of conv1d_transpose the same as conv1d.
 
     Parameters
     ----------
@@ -118,7 +119,7 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_paddin
     batch, channels_in, _ = data.shape
     _, channels_out, kernel_width = kernel.shape
 
-    data_pad, transformed_kernel = _conv1d_transpose_ncw_prepare(
+    data_pad, transformed_kernel = _conv1d_transpose_ncw_preprocess(
         data, kernel, stride, padding, out_dtype, output_padding
     )
@@ -177,15 +178,12 @@ def group_conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_
     if groups == 1:
         return conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding)
 
     _, in_channels, _ = data.shape
-    assert (
-        in_channels % groups == 0
-    ), f"input channels {in_channels} must be divisible by groups {groups}"
 
     assert (
         in_channels % groups == 0
    ), f"input channels {in_channels} must be divisible by groups {groups}"
 
-    data_pad, transformed_kernel = _conv1d_transpose_ncw_prepare(
+    data_pad, transformed_kernel = _conv1d_transpose_ncw_preprocess(
         data, kernel, stride, padding, out_dtype, output_padding
     )
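
Reviewer aid (illustrative, not part of the patch series): the grouped indexing in `group_conv1d_transpose_ncw` is easiest to sanity-check against a scatter-style reference. The sketch below is ours, under assumptions the TOPI code does not make: `padding` is a plain symmetric integer (the operator also accepts 'VALID'/'SAME' via `get_pad_tuple1d`), and the helper name `group_conv1d_transpose_ncw_ref` is hypothetical. Data is NCW; the kernel uses the same IOW layout as the patch, shape [in_channel, num_filter, filter_width], so the total output channel count is groups * num_filter. For integer padding it should agree with the `group_conv1d_transpose_ncw_python` testing helper added in patch 1.

import numpy as np


def group_conv1d_transpose_ncw_ref(data, kernel, stride, padding, output_padding, groups):
    """Grouped 1-D transposed convolution (NCW data, IOW kernel), scatter formulation."""
    batch, in_c, in_w = data.shape
    _, out_c_per_group, k_w = kernel.shape
    assert in_c % groups == 0, f"groups {groups} must evenly divide input channels {in_c}"
    in_c_per_group = in_c // groups
    out_c = out_c_per_group * groups
    # same output width as the TOPI compute after dilation and re-padding
    out_w = (in_w - 1) * stride - 2 * padding + k_w + output_padding

    out = np.zeros((batch, out_c, out_w), dtype=data.dtype)
    for b in range(batch):
        for g in range(groups):
            for ic in range(in_c_per_group):
                i = g * in_c_per_group + ic  # absolute input channel
                for oc in range(out_c_per_group):
                    o = g * out_c_per_group + oc  # absolute output channel
                    for x in range(in_w):
                        # each input element scatters a scaled copy of its kernel row
                        for k in range(k_w):
                            y = x * stride + k - padding
                            if 0 <= y < out_w:
                                out[b, o, y] += data[b, i, x] * kernel[i, oc, k]
    return out


# First parameterization from the new test file: batch 1, 4 input channels,
# width 224, 32 filters per group, kernel 5, stride 1, padding 0, groups 4.
a = np.random.uniform(size=(1, 4, 224)).astype("float32")
w = np.random.uniform(size=(4, 32, 5)).astype("float32")
out = group_conv1d_transpose_ncw_ref(a, w, stride=1, padding=0, output_padding=0, groups=4)
print(out.shape)  # (1, 128, 228): out_w = (224 - 1) * 1 - 0 + 5 + 0 = 228

Note the output channels come out group-major (all of group 0's filters first), which matches both the axis-1 concatenation in `group_conv1d_transpose_ncw_python` and the `c // (out_channels // groups)` group selection in the TE compute.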