diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0c2733ecae92..08b5aa3ce8e4 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -96,6 +96,9 @@ def compute_conv2d(attrs, inputs, out_type, target): get_const_int(inputs[1].shape[3]) == 1: out = topi.nn.depthwise_conv2d_nhwc( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + elif layout in ['NCHW', 'NCHW4c']: + out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, + out_dtype=out_dtype) else: raise ValueError("not support arbitrary group number for now") return [out] @@ -120,6 +123,8 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + if layout == "NCHW4c": + return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule")