diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index ce82c6bd6fd2..617b00a87355 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -518,7 +518,7 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
 # conv3d_transpose
-def wrap_compute_conv3d_transpose(topi_compute):
+def wrap_compute_conv3d_transpose(topi_compute, has_groups=False):
     """wrap conv3d_transpose topi compute"""

     def compute_conv3d_transpose(attrs, inputs, out_dtype):
@@ -528,7 +528,10 @@ def compute_conv3d_transpose(attrs, inputs, out_dtype):
         output_padding = get_const_tuple(attrs.output_padding)
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
+        args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
+        if has_groups:
+            args.append(attrs.groups)
+        out = topi_compute(*args)
         return [out]

     return compute_conv3d_transpose
@@ -543,13 +546,20 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCDHW", "only support ncdhw for now"
     assert dilation == (1, 1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
-        wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
-        name="conv3d_transpose_ncdhw.generic",
-    )
+    if groups == 1:
+        strategy.add_implementation(
+            wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
+            wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
+            name="conv3d_transpose_ncdhw.generic",
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_conv3d_transpose(topi.nn.group_conv3d_transpose_ncdhw, has_groups=True),
+            wrap_topi_schedule(topi.generic.schedule_group_conv3d_transpose_ncdhw),
+            name="group_conv3d_transpose_ncdhw.generic",
+        )
     return strategy
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 80ea00ab0153..a21d5592e447 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -479,6 +479,23 @@ def schedule_group_conv2d_transpose_nchw(outs):
     return _default_schedule(outs, False)

+def schedule_group_conv3d_transpose_ncdhw(outs):
+    """Schedule for group_conv3d_transpose_ncdhw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv3d_transpose_ncdhw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_group_conv2d_nhwc(outs):
     """Schedule for group_conv2d_nhwc
diff --git a/python/tvm/topi/nn/conv3d_transpose.py b/python/tvm/topi/nn/conv3d_transpose.py
index 2d048f432f1b..602da272980c 100644
--- a/python/tvm/topi/nn/conv3d_transpose.py
+++ b/python/tvm/topi/nn/conv3d_transpose.py
@@ -125,6 +125,81 @@ def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype,
     return Output

+def group_conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding, groups):
+    """Transposed group 3D convolution ncdhw forward operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    kernel : tvm.te.Tensor
+        5-D with shape [in_channel, num_filter // groups, filter_depth, filter_height, filter_width]
+
+    strides : int or a list/tuple of three ints
+        The spatial stride along depth, height and width
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    out_dtype : str
+        The output data type. This is used for mixed precision.
+
+    output_padding : tuple of ints
+        Used to get the right output shape for gradients
+
+    groups : int
+        number of groups
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    if not isinstance(strides, (tuple, list)):
+        strides = (strides, strides, strides)
+
+    if groups == 1:
+        return conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding)
+
+    data_pad, kernel_transform = conv3d_transpose_ncdhw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
+    batch, in_c, in_d, in_h, in_w = data_pad.shape
+    out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape
+    assert in_c % groups == 0, f"input channels {in_c} must be divisible by groups {groups}"
+
+    # convolution stage
+    out_c = simplify(out_c * groups)
+    out_d = simplify(in_d - filter_d + 1)
+    out_h = simplify(in_h - filter_h + 1)
+    out_w = simplify(in_w - filter_w + 1)
+    dc = te.reduce_axis((0, in_c // groups), name="dc")
+    dd = te.reduce_axis((0, filter_d), name="dd")
+    dh = te.reduce_axis((0, filter_h), name="dh")
+    dw = te.reduce_axis((0, filter_w), name="dw")
+
+    # data_pad: [batch, in_channels, padded_d, padded_h, padded_w]
+    # kernel_transform: [out_channels // groups, in_channels, filter_d, filter_h, filter_w]
+    return te.compute(
+        (batch, out_c, out_d, out_h, out_w),
+        lambda b, c, d, h, w: te.sum(
+            data_pad[
+                b, c // (out_c // groups) * (in_c // groups) + dc, d + dd, h + dh, w + dw
+            ].astype(out_dtype)
+            * kernel_transform[
+                c % (out_c // groups),
+                c // (out_c // groups) * (in_c // groups) + dc,
+                dd,
+                dh,
+                dw,
+            ].astype(out_dtype),
+            axis=[dc, dd, dh, dw],
+        ),
+        tag="group_conv3d_transpose_ncdhw",
+    )
+
+
 @tvm.target.generic_func
 def conv3d_transpose_legalize(attrs, inputs, types):
     """Legalizes Transposed 3D convolution op.
diff --git a/python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py b/python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
index 38b8bc51bc70..949e513c027c 100644
--- a/python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
+++ b/python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
@@ -21,7 +21,7 @@ from tvm.topi.nn.utils import get_pad_tuple3d

-def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
+def _conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 3d convolution operator in NCDHW layout.

     Parameters
@@ -102,3 +102,41 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     )

     return b_np
+
+
+def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding, groups=1):
+    """Transposed 3d convolution operator in NCDHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    w_np : numpy.ndarray
+        5-D with shape [in_channel, num_filter // groups, filter_depth, filter_height, filter_width]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str
+        Padding size
+
+    output_padding : int or list/tuple of three ints
+        Used to disambiguate output shape.
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [
+        _conv3d_transpose_ncdhw_python(a_slice, w_slice, stride, padding, output_padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
diff --git a/tests/python/topi/test_topi_group_conv3d_transpose_ncdhw.py b/tests/python/topi/test_topi_group_conv3d_transpose_ncdhw.py
new file mode 100644
index 000000000000..14de236e3d28
--- /dev/null
+++ b/tests/python/topi/test_topi_group_conv3d_transpose_ncdhw.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for group transposed 3d convolution ncdhw.""" + +import numpy as np + +import tvm +import tvm.testing +import tvm.topi.testing + +from tvm import te, topi +from tvm.topi.utils import get_const_tuple + +_group_conv3d_transpose_ncdhw_implement = { + "generic": ( + topi.nn.group_conv3d_transpose_ncdhw, + topi.generic.schedule_group_conv3d_transpose_ncdhw, + ), +} + + +( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + output_padding, + groups, +) = tvm.testing.parameters( + (1, 4, (32, 32, 32), 32, (5, 5, 5), 1, 0, (0, 0, 0), 4), + (1, 8, (32, 32, 32), 32, (7, 7, 7), 1, 2, (0, 0, 0), 4), + (1, 8, (32, 32, 32), 32, (5, 5, 5), 2, 1, (0, 0, 0), 2), + (1, 4, (32, 32, 32), 4, (5, 5, 5), 2, 1, (1, 1, 1), 4), + (1, 3, (64, 64, 64), 15, (5, 5, 5), 2, 0, (0, 0, 0), 3), + (1, 32, (16, 16, 16), 128, (5, 5, 5), 1, 0, (0, 0, 0), 32), + (1, 32, (16, 16, 16), 128, (5, 5, 5), 2, 1, (0, 0, 0), 16), +) + +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding, groups +): + dtype = "float32" + in_d, in_h, in_w = in_size + k_d, k_h, k_w = kernel + a_shape = (batch, in_channel, in_d, in_h, in_w) + w_shape = (in_channel, num_filter // groups, k_d, k_h, k_w) + + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python( + a_np, w_np, stride, padding, output_padding, groups + ) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np + + +@tvm.testing.parametrize_targets("llvm") +def test_group_conv3d_transpose_ncdhw( + target, dev, ref_data, dtype, stride, padding, output_padding, groups +): + a_np, w_np, b_np, c_np = ref_data + print("shapes : ", a_np.shape, w_np.shape, b_np.shape, c_np.shape) + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch( + target, _group_conv3d_transpose_ncdhw_implement + ) + B = fcompute(A, W, stride, padding, A.dtype, output_padding, groups) + C = topi.nn.relu(B) + s1 = fschedule([B]) + s2 = fschedule([C]) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + + func1 = tvm.build(s1, [A, W, B], target) + func2 = tvm.build(s2, [A, W, C], target) + func1(a, w, b) + func2(a, w, c) + tvm.testing.assert_allclose(b.numpy(), b_np, atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(c.numpy(), c_np, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main()