Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
name="group_conv2d_nhwc.generic",
)
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
Expand Down
36 changes: 28 additions & 8 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,11 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
target = tvm.target.Target.current(allow_none=False)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = get_const_tuple(
kernel.shape
)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
groups = in_channel // (ic_chunk_group * kernel_ic_bn)

dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
Expand All @@ -415,26 +417,44 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
else:
data_pad = data

ic = te.reduce_axis((0, in_channel), name="ic")
kh = te.reduce_axis((0, kernel_height), name="kh")
kw = te.reduce_axis((0, kernel_width), name="kw")

idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod

if groups == 1:
ic = te.reduce_axis((0, in_channel), name="ic")
return te.compute(
oshape,
lambda n, oc_chunk, oh, ow, oc_block: te.sum(
data_pad[
n,
idxdiv(ic, ic_bn),
oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w,
idxmod(ic, ic_bn),
].astype(out_dtype)
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
out_dtype
),
axis=[ic, kh, kw],
),
name="conv2d_NCHWc",
tag="conv2d_NCHWc",
)
ic = te.reduce_axis((0, in_channel // groups), name="ic")
return te.compute(
oshape,
lambda n, oc_chunk, oh, ow, oc_block: te.sum(
lambda n, occ, oh, ow, oc_block: te.sum(
data_pad[
n,
idxdiv(ic, ic_bn),
(occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn),
oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w,
idxmod(ic, ic_bn),
].astype(out_dtype)
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
out_dtype
),
* kernel[occ, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(out_dtype),
axis=[ic, kh, kw],
),
name="conv2d_NCHWc",
Expand Down
19 changes: 16 additions & 3 deletions tests/python/topi/python/test_topi_conv2d_NCHWc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def verify_conv2d_NCHWc(
dilation=1,
add_bias=False,
add_relu=False,
groups=1,
dtype="float32",
):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
Expand Down Expand Up @@ -90,18 +91,27 @@ def verify_conv2d_NCHWc(

A = te.placeholder((batch, in_channel // ic_block, in_height, in_width, ic_block), name="A")
W = te.placeholder(
(num_filter // oc_block, in_channel // ic_block, kernel, kernel, ic_block, oc_block),
(
num_filter // oc_block,
in_channel // ic_block // groups,
kernel,
kernel,
ic_block,
oc_block,
),
name="W",
)
bias = te.placeholder((num_filter // oc_block, 1, 1, oc_block), name="bias")

@memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
w_np = np.random.uniform(size=(num_filter, in_channel // groups, kernel, kernel)).astype(
dtype
)
b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups)
if add_bias:
c_np += b_np
if add_relu:
Expand Down Expand Up @@ -195,6 +205,9 @@ def test_conv2d_NCHWc():
verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1)

# groups
verify_conv2d_NCHWc(1, 2048, 10, 2048, 3, 1, 1, groups=128)

# weird workloads
verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2)
verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3)
Expand Down