From afb6909c2872383922d0fb142304541dca70fcbf Mon Sep 17 00:00:00 2001
From: LiangW <732811423@qq.com>
Date: Mon, 9 Jan 2023 09:17:23 +0000
Subject: [PATCH 1/2] [TOPI][OP] Support grouped conv2d_NCHWc

---
 python/tvm/relay/op/strategy/x86.py           |  3 ++
 python/tvm/topi/nn/conv2d.py                  | 30 +++++++++++++++----
 .../topi/python/test_topi_conv2d_NCHWc.py     | 19 ++++++++++--
 3 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 4585809f63e1..18bfd173e90e 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -256,6 +256,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
                 name="group_conv2d_nhwc.generic",
             )
+        elif _NCHWc_matcher.match(layout):  # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout)  # check if kernel is OIHWio
+            return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         else:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 92b5a90e5b11..cd5de9538d8b 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -415,26 +415,44 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     else:
         data_pad = data
 
-    ic = te.reduce_axis((0, in_channel), name="ic")
     kh = te.reduce_axis((0, kernel_height), name="kh")
     kw = te.reduce_axis((0, kernel_width), name="kw")
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
+    if groups == 1:
+        ic = te.reduce_axis((0, in_channel), name="ic")
+        return te.compute(
+            oshape,
+            lambda n, oc_chunk, oh, ow, oc_block: te.sum(
+                data_pad[
+                    n,
+                    idxdiv(ic, ic_bn),
+                    oh * HSTR + kh * dilation_h,
+                    ow * WSTR + kw * dilation_w,
+                    idxmod(ic, ic_bn),
+                ].astype(out_dtype)
+                * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
+                    out_dtype
+                ),
+                axis=[ic, kh, kw],
+            ),
+            name="conv2d_NCHWc",
+            tag="conv2d_NCHWc",
+        )
+    ic = te.reduce_axis((0, in_channel // groups), name="ic")
     return te.compute(
         oshape,
-        lambda n, oc_chunk, oh, ow, oc_block: te.sum(
+        lambda n, occ, oh, ow, oc_block: te.sum(
             data_pad[
                 n,
-                idxdiv(ic, ic_bn),
+                (occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn),
                 oh * HSTR + kh * dilation_h,
                 ow * WSTR + kw * dilation_w,
                 idxmod(ic, ic_bn),
             ].astype(out_dtype)
-            * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
-                out_dtype
-            ),
+            * kernel[occ, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(out_dtype),
             axis=[ic, kh, kw],
         ),
         name="conv2d_NCHWc",
diff --git a/tests/python/topi/python/test_topi_conv2d_NCHWc.py b/tests/python/topi/python/test_topi_conv2d_NCHWc.py
index 2298816d373a..007f2a5c6a16 100644
--- a/tests/python/topi/python/test_topi_conv2d_NCHWc.py
+++ b/tests/python/topi/python/test_topi_conv2d_NCHWc.py
@@ -63,6 +63,7 @@ def verify_conv2d_NCHWc(
     dilation=1,
     add_bias=False,
     add_relu=False,
+    groups=1,
     dtype="float32",
 ):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
@@ -90,7 +91,14 @@ def verify_conv2d_NCHWc(
 
     A = te.placeholder((batch, in_channel // ic_block, in_height, in_width, ic_block), name="A")
     W = te.placeholder(
-        (num_filter // oc_block, in_channel // ic_block, kernel, kernel, ic_block, oc_block),
+        (
+            num_filter // oc_block,
+            in_channel // ic_block // groups,
+            kernel,
+            kernel,
+            ic_block,
+            oc_block,
+        ),
         name="W",
     )
     bias = te.placeholder((num_filter // oc_block, 1, 1, oc_block), name="bias")
@@ -98,10 +106,12 @@ def verify_conv2d_NCHWc(
     @memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
     def get_ref_data():
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
-        w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
+        w_np = np.random.uniform(size=(num_filter, in_channel // groups, kernel, kernel)).astype(
+            dtype
+        )
         b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
         dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
+        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups)
         if add_bias:
             c_np += b_np
         if add_relu:
@@ -195,6 +205,9 @@ def test_conv2d_NCHWc():
     verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
     verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1)
 
+    # groups
+    verify_conv2d_NCHWc(1, 2048, 10, 2048, 3, 1, 1, groups=128)
+
     # weird workloads
     verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2)
     verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3)

From fffa91de14476ddc1a3b06ddd79808edb6df9dbb Mon Sep 17 00:00:00 2001
From: LiangW <732811423@qq.com>
Date: Wed, 11 Jan 2023 03:47:31 +0000
Subject: [PATCH 2/2] Fix CI tests

---
 python/tvm/topi/nn/conv2d.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index cd5de9538d8b..0485a17e98f5 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -388,9 +388,11 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     target = tvm.target.Target.current(allow_none=False)
-    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+    oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = get_const_tuple(
+        kernel.shape
+    )
     num_filter = oc_chunk * oc_bn
-    groups = ic_chunk // ic_chunk_group
+    groups = in_channel // (ic_chunk_group * kernel_ic_bn)
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 