26 changes: 18 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -518,7 +518,7 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
 
 
 # conv3d_transpose
-def wrap_compute_conv3d_transpose(topi_compute):
+def wrap_compute_conv3d_transpose(topi_compute, has_groups=False):
     """wrap conv3d_transpose topi compute"""
 
     def compute_conv3d_transpose(attrs, inputs, out_dtype):
@@ -528,7 +528,10 @@ def compute_conv3d_transpose(attrs, inputs, out_dtype):
         output_padding = get_const_tuple(attrs.output_padding)
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
+        args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
+        if has_groups:
+            args.append(attrs.groups)
+        out = topi_compute(*args)
         return [out]
 
     return compute_conv3d_transpose
@@ -543,13 +546,20 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCDHW", "only support ncdhw for now"
     assert dilation == (1, 1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
 
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
-        wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
-        name="conv3d_transpose_ncdhw.generic",
-    )
+    if groups == 1:
+        strategy.add_implementation(
+            wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
+            wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
+            name="conv3d_transpose_ncdhw.generic",
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_conv3d_transpose(topi.nn.group_conv3d_transpose_ncdhw, has_groups=True),
+            wrap_topi_schedule(topi.generic.schedule_group_conv3d_transpose_ncdhw),
+            name="group_conv3d_transpose_ncdhw.generic",
+        )
     return strategy


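With the strategy change above, a Relay graph that sets groups on nn.conv3d_transpose now lowers through the grouped TOPI compute instead of tripping the old groups == 1 assertion. A minimal sketch of such a graph (shapes and the build call here are illustrative, not taken from this diff):

import tvm
from tvm import relay

# Weight layout for conv3d_transpose is IODHW, so with 4 input channels and
# groups=2 the second axis holds out_channels // groups.
data = relay.var("data", shape=(1, 4, 8, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(4, 2, 3, 3, 3), dtype="float32")
out = relay.nn.conv3d_transpose(
    data, weight, strides=(1, 1, 1), padding=(0, 0, 0), kernel_size=(3, 3, 3), groups=2
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
lib = relay.build(mod, target="llvm")  # groups != 1 selects group_conv3d_transpose_ncdhw.generic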
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -479,6 +479,23 @@ def schedule_group_conv2d_transpose_nchw(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_group_conv3d_transpose_ncdhw(outs):
+    """Schedule for group_conv3d_transpose_ncdhw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv3d_transpose_ncdhw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_group_conv2d_nhwc(outs):
     """Schedule for group_conv2d_nhwc
 
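Like the neighboring group_conv2d entries in this file, the new schedule simply defers to _default_schedule, so grouped 3D transposed convolution gets a plain, unoptimized schedule on generic targets. As a rough sketch of what that fallback amounts to (assuming the standard TOPI helper in this same module; simplified, e.g. it omits the auto-inline handling):

from tvm import te

def _default_schedule_sketch(outs):
    # Accept a single tensor or a list, then build a trivial schedule over the
    # output ops: no tiling, fusion, or parallelization is applied.
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    return te.create_schedule([x.op for x in outs])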
75 changes: 75 additions & 0 deletions python/tvm/topi/nn/conv3d_transpose.py
@@ -125,6 +125,81 @@ def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype,
     return Output
 
 
+def group_conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding, groups):
+    """Transposed group 3D convolution ncdhw forward operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    kernel : tvm.te.Tensor
+        5-D with shape [in_channel, num_filter // groups, filter_depth, filter_height, filter_width]
+
+    strides : int or a list/tuple of three ints
+        The spatial stride along depth, height, and width
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    out_dtype : str
+        The output data type. This is used for mixed precision.
+
+    output_padding : tuple of ints
+        Used to get the right output shape for gradients
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    if not isinstance(strides, (tuple, list)):
+        strides = (strides, strides, strides)
+
+    if groups == 1:
+        return conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding)
+
+    data_pad, kernel_transform = conv3d_transpose_ncdhw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
+    batch, in_c, in_d, in_h, in_w = data_pad.shape
+    out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape
+    assert in_c % groups == 0, f"input channels {in_c} must be divisible by groups {groups}"
+
+    # convolution stage
+    out_c = simplify(out_c * groups)
+    out_d = simplify(in_d - filter_d + 1)
+    out_h = simplify(in_h - filter_h + 1)
+    out_w = simplify(in_w - filter_w + 1)
+    dc = te.reduce_axis((0, in_c // groups), name="dc")
+    dd = te.reduce_axis((0, filter_d), name="dd")
+    dh = te.reduce_axis((0, filter_h), name="dh")
+    dw = te.reduce_axis((0, filter_w), name="dw")
+
+    # data_pad: [batch, in_c, padded_d, padded_h, padded_w]
+    # kernel_transform: [out_c // groups, in_c, filter_d, filter_h, filter_w]
+    return te.compute(
+        (batch, out_c, out_d, out_h, out_w),
+        lambda b, c, d, h, w: te.sum(
+            data_pad[
+                b, c // (out_c // groups) * (in_c // groups) + dc, d + dd, h + dh, w + dw
+            ].astype(out_dtype)
+            * kernel_transform[
+                c % (out_c // groups),
+                c // (out_c // groups) * (in_c // groups) + dc,
+                dd,
+                dh,
+                dw,
+            ].astype(out_dtype),
+            axis=[dc, dd, dh, dw],
+        ),
+        tag="group_conv3d_transpose_ncdhw",
+    )
+
+
 @tvm.target.generic_func
 def conv3d_transpose_legalize(attrs, inputs, types):
     """Legalizes Transposed 3D convolution op.
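The channel arithmetic in the compute above carries the grouping: output channel c belongs to group g = c // (out_c // groups), and the reduction only touches that group's slice of input channels. A small standalone check of the mapping (illustrative numbers, not from this diff):

# Reproduce the index expressions for in_c = 4, out_c = 6, groups = 2.
in_c, out_c, groups = 4, 6, 2
for c in range(out_c):
    g = c // (out_c // groups)  # group that owns output channel c
    in_channels = [g * (in_c // groups) + dc for dc in range(in_c // groups)]
    print(c, g, in_channels)
# Output channels 0-2 read input channels [0, 1]; channels 3-5 read [2, 3].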
40 changes: 39 additions & 1 deletion python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
@@ -21,7 +21,7 @@
 from tvm.topi.nn.utils import get_pad_tuple3d
 
 
-def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
+def _conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 3d convolution operator in NCDHW layout.
 
     Parameters
@@ -102,3 +102,41 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     )
 
     return b_np
+
+
+def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding, groups=1):
+    """Transposed 3d convolution operator in NCDHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    w_np : numpy.ndarray
+        5-D with shape [in_channel, num_filter // groups, filter_depth, filter_height, filter_width]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str
+        Padding size
+
+    output_padding : int or list/tuple of three ints
+        Used to disambiguate output shape.
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [
+        _conv3d_transpose_ncdhw_python(a_slice, w_slice, stride, padding, output_padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
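For a quick sanity check of the grouped reference (standalone, with illustrative shapes), recall the usual transposed-conv size formula out = (in - 1) * stride - 2 * pad + kernel + output_padding:

import numpy as np
from tvm.topi.testing import conv3d_transpose_ncdhw_python

a = np.random.uniform(size=(1, 4, 5, 5, 5)).astype("float32")  # 4 input channels
w = np.random.uniform(size=(4, 3, 3, 3, 3)).astype("float32")  # 2 input channels per group, 3 filters each
b = conv3d_transpose_ncdhw_python(a, w, stride=1, padding=0, output_padding=(0, 0, 0), groups=2)
print(b.shape)  # (1, 6, 7, 7, 7): (5 - 1) * 1 - 0 + 3 + 0 = 7 along each spatial dim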
109 changes: 109 additions & 0 deletions tests/python/topi/test_topi_group_conv3d_transpose_ncdhw.py
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for group transposed 3d convolution ncdhw."""

import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing

from tvm import te, topi
from tvm.topi.utils import get_const_tuple

_group_conv3d_transpose_ncdhw_implement = {
    "generic": (
        topi.nn.group_conv3d_transpose_ncdhw,
        topi.generic.schedule_group_conv3d_transpose_ncdhw,
    ),
}


(
    batch,
    in_channel,
    in_size,
    num_filter,
    kernel,
    stride,
    padding,
    output_padding,
    groups,
) = tvm.testing.parameters(
    (1, 4, (32, 32, 32), 32, (5, 5, 5), 1, 0, (0, 0, 0), 4),
    (1, 8, (32, 32, 32), 32, (7, 7, 7), 1, 2, (0, 0, 0), 4),
    (1, 8, (32, 32, 32), 32, (5, 5, 5), 2, 1, (0, 0, 0), 2),
    (1, 4, (32, 32, 32), 4, (5, 5, 5), 2, 1, (1, 1, 1), 4),
    (1, 3, (64, 64, 64), 15, (5, 5, 5), 2, 0, (0, 0, 0), 3),
    (1, 32, (16, 16, 16), 128, (5, 5, 5), 1, 0, (0, 0, 0), 32),
    (1, 32, (16, 16, 16), 128, (5, 5, 5), 2, 1, (0, 0, 0), 16),
)

dtype = tvm.testing.parameter("float32")


@tvm.testing.fixture(cache_return_value=True)
def ref_data(
    batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding, groups
):
    dtype = "float32"
    in_d, in_h, in_w = in_size
    k_d, k_h, k_w = kernel
    a_shape = (batch, in_channel, in_d, in_h, in_w)
    w_shape = (in_channel, num_filter // groups, k_d, k_h, k_w)

    a_np = np.random.uniform(size=a_shape).astype(dtype)
    w_np = np.random.uniform(size=w_shape).astype(dtype)
    b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(
        a_np, w_np, stride, padding, output_padding, groups
    )
    c_np = np.maximum(b_np, 0)
    return a_np, w_np, b_np, c_np


@tvm.testing.parametrize_targets("llvm")
def test_group_conv3d_transpose_ncdhw(
    target, dev, ref_data, dtype, stride, padding, output_padding, groups
):
    a_np, w_np, b_np, c_np = ref_data
    A = te.placeholder(a_np.shape, name="A", dtype=dtype)
    W = te.placeholder(w_np.shape, name="W", dtype=dtype)

    with tvm.target.Target(target):
        fcompute, fschedule = tvm.topi.testing.dispatch(
            target, _group_conv3d_transpose_ncdhw_implement
        )
        B = fcompute(A, W, stride, padding, A.dtype, output_padding, groups)
        C = topi.nn.relu(B)
        s1 = fschedule([B])
        s2 = fschedule([C])
    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

    func1 = tvm.build(s1, [A, W, B], target)
    func2 = tvm.build(s2, [A, W, C], target)
    func1(a, w, b)
    func2(a, w, c)
    tvm.testing.assert_allclose(b.numpy(), b_np, atol=1e-5, rtol=1e-5)
    tvm.testing.assert_allclose(c.numpy(), c_np, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    tvm.testing.main()
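Since the file ends with tvm.testing.main(), the new test can be run standalone with python tests/python/topi/test_topi_group_conv3d_transpose_ncdhw.py, or through pytest like the rest of the TOPI test suite; the parameter tuples above then expand into one test case per (shape, stride, padding, groups) combination.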