From b9433ad7734166213e31fd35ffe2dac291cdfd66 Mon Sep 17 00:00:00 2001 From: Wheest Date: Sat, 13 Apr 2019 11:48:56 +0100 Subject: [PATCH] Fixed relay grouped convolution selection --- python/tvm/relay/op/nn/_nn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5a47b1d42ed3..5bdd289a9cba 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -125,6 +125,7 @@ def schedule_conv2d(attrs, outs, target): groups = attrs.groups layout = attrs.data_layout kernel_layout = attrs.kernel_layout + channels = attrs.channels #TODO should be input channels with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) @@ -132,13 +133,14 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_conv2d_nchw(outs) if groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) - if groups != 1: + if groups == channels: if layout == "NCHW": # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - if layout == "NCHW4c": + if groups > 1: + if layout == "NCHW": return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule")