From 3706e8a6e5e8a81b46088ebd30800ef34b667262 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 21 Mar 2019 14:26:38 +0800 Subject: [PATCH 1/2] [Relay][Op] Add group conv2d dispatch to topi function --- python/tvm/relay/op/nn/_nn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0c2733ecae92..08b5aa3ce8e4 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -96,6 +96,9 @@ def compute_conv2d(attrs, inputs, out_type, target): get_const_int(inputs[1].shape[3]) == 1: out = topi.nn.depthwise_conv2d_nhwc( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + elif layout in ['NCHW', 'NCHW4c']: + out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, + out_dtype=out_dtype) else: raise ValueError("not support arbitrary group number for now") return [out] @@ -120,6 +123,8 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + if layout == "NCHW4c": + return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") From d0e913e243a9a96864b5cd90563dea9044e24173 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 24 Mar 2019 23:29:42 +0800 Subject: [PATCH 2/2] Rerun tests