diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5a47b1d42ed3..5bdd289a9cba 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -125,6 +125,7 @@ def schedule_conv2d(attrs, outs, target): groups = attrs.groups layout = attrs.data_layout kernel_layout = attrs.kernel_layout + channels = attrs.channels #TODO should be input channels with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) @@ -132,13 +133,14 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_conv2d_nchw(outs) if groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) - if groups != 1: + if groups == channels: if layout == "NCHW": # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - if layout == "NCHW4c": + if groups > 1: + if layout == "NCHW": return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule")