def group_conv2d_transpose_nchw(
    data,
    kernel,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
    out_dtype=None,
):
    """Grouped transposed 2D convolution in NCHW layout.

    The operator is expressed as an ordinary stride-1 grouped convolution:
    the input is dilated by the stride and zero-padded, the kernel is flipped
    by 180 degrees with its two leading axes swapped, and the two are then
    convolved group by group.

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.te.Tensor
        4-D with shape [in_channel, out_channel // groups, filter_height, filter_width]

    stride : int or pair of ints, optional
        Spatial stride; normalised with ``_pair``.

    padding : int or pair of ints, optional
        Padding specification; normalised with ``_pair`` and then resolved
        by ``get_pad_tuple``.

    output_padding : int or pair of ints, optional
        Extra size added to the bottom and right padding of the dilated
        input. Each entry must be smaller than the corresponding stride.

    groups : int, optional
        Number of channel groups. ``in_channel`` must be divisible by it.

    dilation : int or pair of ints, optional
        Only a dilation of 1 is implemented; any other value is rejected.

    out_dtype : str, optional
        Accumulation/output dtype. Defaults to ``data.dtype``.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width], where
        ``out_channel`` is ``kernel.shape[1] * groups``.
    """
    # Pre-processing and preliminary checks.
    if out_dtype is None:
        out_dtype = data.dtype

    stride_h, stride_w = _pair(stride)
    padding = _pair(padding)
    opad_h, opad_w = _pair(output_padding)

    # The computation below never applies a kernel dilation, so refuse
    # anything but 1 instead of silently returning an undilated result.
    assert all(
        int(d) == 1 for d in _pair(dilation)
    ), "group_conv2d_transpose_nchw does not support dilation other than 1"

    _, in_channels, _, _ = get_const_tuple(data.shape)
    # kernel is laid out as IOHW and its second axis already holds the
    # per-group output channel count, so it needs no divisibility check.
    _, out_c, filter_h, filter_w = kernel.shape
    assert in_channels % groups == 0, "input channels must be divisible by groups"
    assert (
        opad_h < stride_h and opad_w < stride_w
    ), "output_padding must be smaller than stride"

    # Dilate the input by the stride (inserts zeros between input pixels).
    data_dilate = dilate(data, [1, 1, stride_h, stride_w], name="data_dilate")

    # Pad the dilated input so that a stride-1 convolution yields the
    # transposed-convolution output size.
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right + opad_w
    data_pad = pad(
        data_dilate, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right], name="data_pad"
    )

    # Transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees.
    kernel_transform = te.compute(
        (out_c, in_channels, filter_h, filter_w),
        lambda o, i, h, w: kernel[i][o][filter_h - 1 - h][filter_w - 1 - w],
        name="kernel_transform",
    )

    # Convolution stage.
    batch, _, padded_h, padded_w = data_pad.shape
    out_c = simplify(out_c)
    out_channels = simplify(out_c * groups)
    out_h = simplify(padded_h - filter_h + 1)
    out_w = simplify(padded_w - filter_w + 1)

    in_c_per_group = in_channels // groups
    out_c_per_group = out_channels // groups

    dc = te.reduce_axis((0, in_c_per_group), name="dc")
    dh = te.reduce_axis((0, filter_h), name="dh")
    dw = te.reduce_axis((0, filter_w), name="dw")

    # data_pad:         [batch, in_channels, padded_h, padded_w]
    # kernel_transform: [out_channels // groups, in_channels, filter_h, filter_w]
    # Output channel c belongs to group c // out_c_per_group; that group
    # reads input channels starting at group * in_c_per_group.
    return te.compute(
        (batch, out_channels, out_h, out_w),
        lambda b, c, h, w: te.sum(
            data_pad[
                b,
                c // out_c_per_group * in_c_per_group + dc,
                h + dh,
                w + dw,
            ].astype(out_dtype)
            * kernel_transform[
                c % out_c_per_group,
                c // out_c_per_group * in_c_per_group + dc,
                dh,
                dw,
            ].astype(out_dtype),
            axis=[dc, dh, dw],
        ),
        tag="conv2d_transpose_nchw",
    )