Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions python/tvm/topi/nn/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,81 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype,
return Output


def group_conv2d_transpose_nchw(
    data, kernel, stride=1, padding=0, output_padding=0, groups=1, dilation=1, out_dtype=None
):
    """Grouped transposed 2D convolution (deconvolution) in NCHW layout.

    Implemented as: dilate the input by the stride, apply the "backward"
    padding, then run an ordinary stride-1 grouped convolution with the
    kernel rotated by 180 degrees.

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D input with shape [batch, in_channels, in_height, in_width].
    kernel : tvm.te.Tensor
        4-D kernel in IOHW layout with shape
        [in_channels, out_channels // groups, filter_height, filter_width].
    stride : int or tuple of two ints
        Stride of the corresponding forward convolution.
    padding : int, str, or tuple
        Padding of the corresponding forward convolution, interpreted by
        ``get_pad_tuple``.
    output_padding : int or tuple of two ints
        Extra zero rows/cols appended at the bottom/right of the output.
        Must be strictly smaller than the stride in each dimension.
    groups : int
        Number of channel groups; ``in_channels`` must be divisible by it.
    dilation : int or tuple of two ints
        Accepted for API compatibility. NOTE(review): the value is
        normalized below but never applied to the kernel, so only
        ``dilation == 1`` is effectively supported — confirm before
        relying on other values.
    out_dtype : str, optional
        Output dtype; defaults to the dtype of ``data``.

    Returns
    -------
    output : tvm.te.Tensor
        4-D tensor with shape [batch, out_channels, out_height, out_width].
    """
    if out_dtype is None:
        out_dtype = data.dtype

    strides = _pair(stride)
    padding = _pair(padding)
    output_padding = _pair(output_padding)
    dilation = _pair(dilation)  # normalized but unused below (see docstring)

    # Read the (constant) input/kernel shapes once.
    batch, in_channels, in_height, in_width = get_const_tuple(data.shape)
    _, out_c, filter_h, filter_w = kernel.shape
    assert in_channels % groups == 0, "input channels must divide group size"

    stride_h, stride_w = strides
    opad_h, opad_w = output_padding
    # output_padding only widens the zero border; it must stay below the stride,
    # otherwise the result is not a valid transposed convolution.
    assert opad_h < stride_h and opad_w < stride_w

    # Step 1: dilate the input by the stride (insert stride-1 zeros between
    # elements along H and W).
    data_dilate = dilate(data, [1, 1, stride_h, stride_w], name="data_dilate")

    # Step 2: pad with the "backward" padding — the forward convolution's
    # padding is inverted here, and output_padding extends bottom/right.
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right + opad_w
    data_pad = pad(
        data_dilate, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right], name="data_pad"
    )

    # Step 3: transform the kernel layout from IOHW to OIHW and rotate it by
    # 180 degrees (index reversal along both spatial axes).
    kernel_transform = te.compute(
        (out_c, in_channels, filter_h, filter_w),
        lambda o, i, h, w: kernel[i][o][filter_h - 1 - h][filter_w - 1 - w],
        name="kernel_transform",
    )

    # Re-read shapes after padding/transform: spatial dims changed.
    batch, in_c, in_h, in_w = data_pad.shape
    out_c, _, filter_h, filter_w = kernel_transform.shape

    # Step 4: ordinary stride-1 grouped convolution over the prepared input.
    out_c = simplify(out_c)
    out_channels = simplify(out_c * groups)
    out_h = simplify(in_h - filter_h + 1)
    out_w = simplify(in_w - filter_w + 1)

    # Per-group channel counts, hoisted out of the compute lambda.
    out_c_per_group = out_channels // groups
    in_c_per_group = in_channels // groups

    dc = te.reduce_axis((0, in_c // groups), name="dc")
    dh = te.reduce_axis((0, filter_h), name="dh")
    dw = te.reduce_axis((0, filter_w), name="dw")

    # Output channel c belongs to group (c // out_c_per_group); it reduces
    # over that group's slice of the input channels only.
    return te.compute(
        (batch, out_channels, out_h, out_w),
        lambda b, c, h, w: te.sum(
            data_pad[
                b,
                c // out_c_per_group * in_c_per_group + dc,
                h + dh,
                w + dw,
            ].astype(out_dtype)
            * kernel_transform[
                c % out_c_per_group,
                c // out_c_per_group * in_c_per_group + dc,
                dh,
                dw,
            ].astype(out_dtype),
            axis=[dc, dh, dw],
        ),
        tag="conv2d_transpose_nchw",
    )



@tvm.target.generic_func
def conv2d_transpose_legalize(attrs, inputs, types):
"""Legalizes Transposed 2D convolution op.
Expand Down