From d008622788e7ba6c282471308ac5b4dcfcad87d3 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Mon, 22 Apr 2019 17:20:09 +0100 Subject: [PATCH 01/15] Fixed issue #3069 by adding in_channels --- include/tvm/relay/attrs/nn.h | 6 ++++++ python/tvm/relay/frontend/onnx.py | 4 +++- python/tvm/relay/op/nn/_nn.py | 11 ++++++++--- python/tvm/relay/op/nn/nn.py | 6 +++++- src/relay/op/nn/convolution.cc | 1 + 5 files changed, 23 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 431b6032c8cd..cb769ca58a94 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -54,6 +54,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { Array dilation; int groups; IndexExpr channels; + IndexExpr in_channels; Array kernel_size; std::string data_layout; std::string kernel_layout; @@ -78,6 +79,11 @@ struct Conv2DAttrs : public tvm::AttrsNode { .describe("The number of output channels in the convolution." " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); + TVM_ATTR_FIELD(in_channels) + .describe("The number of input channels in the convolution." + " Its value won't affect the behaviour of standard conv2d and depthwise conv2d," + " but it is necessary for group conv2d.") + .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ebedc20375e5..b6fc33bce5c1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -169,7 +169,9 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # get number of channels + # get number of input channels + attr['in_channels'] = infer_shape(inputs[0])[1] + out = AttrCvt(op_name=dimension_picker('conv'), transforms={ 'kernel_shape': 'kernel_size', diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e60c01cfb3ff..aa06a92dff0e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -125,6 +125,9 @@ def schedule_conv2d(attrs, outs, target): groups = attrs.groups layout = attrs.data_layout kernel_layout = attrs.kernel_layout + in_channels = attrs.in_channels + out_channels = outs[0].shape[1] + with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) @@ -133,12 +136,14 @@ def schedule_conv2d(attrs, outs, target): if groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) if groups != 1: - if layout == "NCHW": + if layout == "NCHW" and in_channels == groups and \ + in_channels == out_channels: # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. 
return topi.generic.schedule_depthwise_conv2d_nchw(outs) - if layout == "NHWC" and kernel_layout == "HWOI": + if layout == "NHWC" and kernel_layout == "HWOI" and \ + in_channels == groups and in_channels == out_channels: return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - if layout == "NCHW4c": + if layout in ["NCHW", "NCHW4c"]: return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 2d13f53f17fd..0c211383423c 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -28,6 +28,7 @@ def conv2d(data, dilation=(1, 1), groups=1, channels=None, + in_channels=None, kernel_size=None, data_layout="NCHW", kernel_layout="OIHW", @@ -81,6 +82,9 @@ def conv2d(data, channels : int, optional Number of output channels of this convolution. + in_channels : int, optional + Number of input channels of this convolution. + kernel_size : tuple of int, optional The spatial of the convolution kernel. @@ -102,7 +106,7 @@ def conv2d(data, The computed result. """ return _make.conv2d(data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, + groups, channels, in_channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 97cba7964000..996ee3e935af 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -148,6 +148,7 @@ Expr MakeConv2D(Expr data, Array dilation, int groups, IndexExpr channels, + IndexExpr in_channels, Array kernel_size, std::string data_layout, std::string kernel_layout, From c1d47dc556af2abc062e175fa00792703e44a3ca Mon Sep 17 00:00:00 2001 From: Ruizhe Date: Tue, 23 Apr 2019 16:03:39 +0100 Subject: [PATCH 02/15] Registerd group_conv2d_nchw as topi compute --- include/tvm/relay/attrs/nn.h | 6 --- python/tvm/relay/frontend/onnx.py | 3 -- python/tvm/relay/op/nn/_nn.py | 68 +++++++++++++++++++++++++------ python/tvm/relay/op/nn/nn.py | 6 +-- python/tvm/target.py | 2 +- src/relay/op/nn/convolution.cc | 1 - topi/python/topi/generic/nn.py | 2 +- topi/python/topi/nn/conv2d.py | 5 ++- topi/python/topi/x86/conv2d.py | 14 ++++++- 9 files changed, 75 insertions(+), 32 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index cb769ca58a94..431b6032c8cd 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -54,7 +54,6 @@ struct Conv2DAttrs : public tvm::AttrsNode { Array dilation; int groups; IndexExpr channels; - IndexExpr in_channels; Array kernel_size; std::string data_layout; std::string kernel_layout; @@ -79,11 +78,6 @@ struct Conv2DAttrs : public tvm::AttrsNode { .describe("The number of output channels in the convolution." " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); - TVM_ATTR_FIELD(in_channels) - .describe("The number of input channels in the convolution." 
- " Its value won't affect the behaviour of standard conv2d and depthwise conv2d," - " but it is necessary for group conv2d.") - .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b6fc33bce5c1..53f104ce48cf 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -169,9 +169,6 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # get number of input channels - attr['in_channels'] = infer_shape(inputs[0])[1] - out = AttrCvt(op_name=dimension_picker('conv'), transforms={ 'kernel_shape': 'kernel_size', diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index aa06a92dff0e..017cdb2664de 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -#pylint: disable=invalid-name, unused-argument +# pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" from __future__ import absolute_import @@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target): with target: return topi.generic.schedule_softmax(outputs) + reg.register_pattern("nn.softmax", OpPattern.OPAQUE) schedule_broadcast = schedule_injective + @reg.register_schedule("nn.log_softmax") def schedule_log_softmax(_, outputs, target): """Schedule definition of log_softmax""" with target: return topi.generic.schedule_softmax(outputs) + reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) @@ -53,12 +56,14 @@ def compute_dense(attrs, inputs, out_type, target): """Compute definition of dense""" return [topi.nn.dense(inputs[0], inputs[1])] + @reg.register_schedule("nn.dense") def schedule_dense(attrs, outputs, target): """Schedule definition of dense""" with target: return topi.generic.schedule_dense(outputs) + reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -68,12 +73,14 @@ def compute_batch_matmul(attrs, inputs, out_type, target): """Compute definition of batch_matmul""" return [topi.nn.batch_matmul(inputs[0], inputs[1])] + @reg.register_schedule("nn.batch_matmul") def schedule_batch_matmul(attrs, outputs, target): """Schedule definition of batch_matmul""" with target: return topi.generic.schedule_batch_matmul(outputs) + reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -101,14 +108,14 @@ def compute_conv2d(attrs, inputs, out_type, target): inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype) elif layout == "NCHW" and \ - get_const_int(inputs[1].shape[0]) == groups and \ - get_const_int(inputs[1].shape[1]) == 1: + get_const_int(inputs[1].shape[0]) == groups and \ + get_const_int(inputs[1].shape[1]) == 1: out = topi.nn.depthwise_conv2d_nchw( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) elif layout == "NHWC" and \ - kernel_layout == "HWOI" and\ - get_const_int(inputs[1].shape[2]) == groups and \ - get_const_int(inputs[1].shape[3]) == 1: + kernel_layout == "HWOI" and\ + get_const_int(inputs[1].shape[2]) == groups and \ + get_const_int(inputs[1].shape[3]) == 1: out = topi.nn.depthwise_conv2d_nhwc( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) elif layout in ['NCHW', 'NCHW4c']: @@ -125,8 +132,6 @@ def schedule_conv2d(attrs, outs, 
target): groups = attrs.groups layout = attrs.data_layout kernel_layout = attrs.kernel_layout - in_channels = attrs.in_channels - out_channels = outs[0].shape[1] with target: if groups == 1 and layout == "NCHW": @@ -136,14 +141,18 @@ def schedule_conv2d(attrs, outs, target): if groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) if groups != 1: - if layout == "NCHW" and in_channels == groups and \ - in_channels == out_channels: + # collect in_channels to distinguish depthwise and group conv2d + wkl = outs[0].op.attrs['workload'] + in_channels = wkl[1][1] + out_channels = outs[0].shape[1] + + if layout == "NCHW" and in_channels == groups and in_channels == out_channels: # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI" and \ in_channels == groups and in_channels == out_channels: return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - if layout in ["NCHW", "NCHW4c"]: + if layout == "NCHW": return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") @@ -154,6 +163,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): from ... import op return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) + reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -172,18 +182,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): assert layout == "NCHW", "only support nchw for now" assert dilation == (1, 1), "not support dilate now" assert groups == 1, "only support groups == 1 for now" - out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype) + out = topi.nn.conv2d_transpose_nchw( + inputs[0], inputs[1], strides, padding, out_dtype) output_padding = get_const_tuple(attrs.output_padding) out = topi.nn.pad(out, [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) return [out] + @reg.register_schedule("nn.conv2d_transpose") def schedule_conv2d_transpose(attrs, outs, target): """Schedule definition of conv2d_transpose""" with target: return topi.generic.schedule_conv2d_transpose_nchw(outs) + reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) # bias_add @@ -199,6 +212,7 @@ def schedule_max_pool2d(attrs, outs, target): with target: return topi.generic.schedule_pool(outs, layout) + reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -210,6 +224,7 @@ def schedule_avg_pool2d(attrs, outs, target): with target: return topi.generic.schedule_pool(outs, layout) + reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -220,6 +235,7 @@ def schedule_global_max_pool2d(_, outs, target): with target: return topi.generic.schedule_global_pool(outs) + reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -230,6 +246,7 @@ def schedule_global_avg_pool2d(_, outs, target): with target: return topi.generic.schedule_global_pool(outs) + reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) # leaky_relu @@ -253,12 +270,14 @@ def compute_lrn(attrs, inputs, out_dtype, target): return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, attrs.alpha, attrs.beta, attrs.bias)] + @reg.register_schedule("nn.lrn") def schedule_lrn(attrs, outs, target): """Schedule definition of lrn""" with target: return topi.generic.schedule_lrn(outs) + reg.register_pattern("nn.lrn", OpPattern.OPAQUE) @@ -268,20 +287,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target): """Compute 
definition of l2 normalize""" return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)] + @reg.register_schedule("nn.l2_normalize") def schedule_l2_normalize(attrs, outs, target): """Schedule definition of l2 normalize""" with target: return topi.generic.schedule_l2_normalize(outs) + reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) # upsampling reg.register_schedule("nn.upsampling", reg.schedule_injective) + + def schedule_upsampling(_, outs, target): """Schedule definition of upsampling""" with target: return topi.generic.schedule_injective(outs) + + # pad reg.register_schedule("nn.pad", schedule_broadcast) @@ -307,12 +332,14 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_ return [out] + @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform") def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target): """Schedule definition of conv2d_winograd_without_weight_transform""" with target: return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs) + reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -320,15 +347,18 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, targe @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target): """Compute definition of contrib_conv2d_winograd_weight_transform""" - out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size')) + out = topi.nn.conv2d_winograd_weight_transform( + inputs[0], attrs.get_int('tile_size')) return [out] + @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform") def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target): """Schedule definition of contrib_conv2d_winograd_weight_transform""" with target: return topi.generic.schedule_conv2d_winograd_weight_transform(outs) + reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -356,12 +386,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform( return [out] + @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target): """Schedule definition of conv2d_winograd_nnpack_without_weight_transform""" with target: return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs) + reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform", OpPattern.OPAQUE) @@ -374,12 +406,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d inputs[0], convolution_algorithm, out_dtype) return [out] + @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform") def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform""" with target: return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs) + reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE) @@ -400,15 +434,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target): data_layout, out_layout, out_dtype) return [out] + @reg.register_schedule("nn.contrib_conv2d_NCHWc") def schedule_contrib_conv2d_NCHWc(attrs, outs, target): """Schedule definition of contrib_conv2d_NCHWc""" 
with target: return topi.generic.schedule_conv2d_NCHWc(outs) + reg.register_pattern("nn.contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc") def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): """Compute definition of depthwise conv2d NCHWc""" @@ -425,15 +462,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): data_layout, out_layout, out_dtype) return [out] + @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc") def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target): """Schedule definition of contrib_conv2d_NCHWc""" with target: return topi.generic.schedule_depthwise_conv2d_NCHWc(outs) + reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_compute("nn.deformable_conv2d") def compute_deformable_conv2d(attrs, inputs, out_dtype, target): """Compute definition of deformable_conv2d""" @@ -449,10 +489,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target): dilation, deformable_groups, groups, out_dtype) return [out] + @reg.register_schedule("nn.deformable_conv2d") def schedule_deformable_conv2d(attrs, outs, target): """Schedule definition of deformable_conv2d""" with target: return topi.generic.schedule_deformable_conv2d_nchw(outs) + reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 0c211383423c..2d13f53f17fd 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -28,7 +28,6 @@ def conv2d(data, dilation=(1, 1), groups=1, channels=None, - in_channels=None, kernel_size=None, data_layout="NCHW", kernel_layout="OIHW", @@ -82,9 +81,6 @@ def conv2d(data, channels : int, optional Number of output channels of this convolution. - in_channels : int, optional - Number of input channels of this convolution. - kernel_size : tuple of int, optional The spatial of the convolution kernel. @@ -106,7 +102,7 @@ def conv2d(data, The computed result. """ return _make.conv2d(data, weight, strides, padding, dilation, - groups, channels, in_channels, kernel_size, data_layout, + groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) diff --git a/python/tvm/target.py b/python/tvm/target.py index d3df3d705cb8..eff0088b37ce 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -296,7 +296,7 @@ def dispatch_func(func, *args, **kwargs): def generic_func(fdefault): """Wrap a target generic function. - Generic function allows registeration of further functions + Generic function allows registration of further functions that can be dispatched on current target context. If no registered dispatch is matched, the fdefault will be called. 
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 996ee3e935af..97cba7964000 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -148,7 +148,6 @@ Expr MakeConv2D(Expr data, Array dilation, int groups, IndexExpr channels, - IndexExpr in_channels, Array kernel_size, std::string data_layout, std::string kernel_layout, diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 70ce5791d905..7bd95688b75d 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -242,7 +242,7 @@ def schedule_depthwise_conv2d_NCHWc(outs): @tvm.target.generic_func def schedule_group_conv2d_nchw(outs): - """Schedule for conv2d_nchw + """Schedule for group_conv2d_nchw Parameters ---------- diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 49c0bd79eacc..8741441e7c50 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -561,6 +561,9 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ + return _group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=out_dtype) + +def _group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None): if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 @@ -603,4 +606,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(out_dtype) * Filter[ff, rc, ry, rx].astype(out_dtype), - axis=[rc, ry, rx]), tag="conv2d_nchw") + axis=[rc, ry, rx]), tag="group_conv2d_nchw") diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 02f78f8007f9..84570f593df2 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -27,7 +27,7 @@ from .. import generic, tag from .. 
import nn from ..util import get_const_tuple -from ..nn.conv2d import conv2d, conv2d_NCHWc, \ +from ..nn.conv2d import conv2d, conv2d_NCHWc, _group_conv2d_nchw, \ conv2d_alter_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw @@ -523,3 +523,15 @@ def traverse(op): traverse(outs[0].op) return s + +@autotvm.register_topi_compute(nn.group_conv2d_nchw, 'cpu', 'direct') +def _declaration_group_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, groups, + out_dtype): + """ A wrapper for generic group_conv2d_nchw """ + out_dtype = data.dtype if out_dtype is None else out_dtype + padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) + strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) + dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + + return _group_conv2d_nchw(data, kernel, strides, padding, dilation, groups, + out_dtype=out_dtype) From e524d6230cfb93a0f3bd94b2bde8a42085af0fe2 Mon Sep 17 00:00:00 2001 From: Ruizhe Date: Tue, 23 Apr 2019 16:42:54 +0100 Subject: [PATCH 03/15] Improved by checking tag value --- python/tvm/relay/op/nn/_nn.py | 36 +++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 017cdb2664de..8ceae6561b69 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -85,6 +85,17 @@ def schedule_batch_matmul(attrs, outputs, target): # conv2d +def _find_conv2d_op(op): + """Find the op with conv2d in its tag by traversing.""" + if 'conv2d' in op.tag: + return op + for tensor in op.input_tensors: + op_ = _find_conv2d_op(tensor.op) + if op_ is not None: + return op_ + return None + + @reg.register_compute("nn.conv2d") def compute_conv2d(attrs, inputs, out_type, target): """Compute definition of conv2d""" @@ -142,18 +153,19 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_conv2d_nhwc(outs) if groups != 1: # collect in_channels to distinguish depthwise and group conv2d - wkl = outs[0].op.attrs['workload'] - in_channels = wkl[1][1] - out_channels = outs[0].shape[1] - - if layout == "NCHW" and in_channels == groups and in_channels == out_channels: - # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. - return topi.generic.schedule_depthwise_conv2d_nchw(outs) - if layout == "NHWC" and kernel_layout == "HWOI" and \ - in_channels == groups and in_channels == out_channels: - return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - if layout == "NCHW": - return topi.generic.schedule_group_conv2d_nchw(outs) + op = _find_conv2d_op(outs[0].op) + assert op is not None + + is_depthwise = 'depthwise' in op.tag + if is_depthwise: + if layout == "NCHW": + # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. 
+ return topi.generic.schedule_depthwise_conv2d_nchw(outs) + if layout == "NHWC" and kernel_layout == "HWOI": + return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + else: + if layout == "NCHW": + return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") From a56fb77832f6de1887923bb77623f9c54bee885d Mon Sep 17 00:00:00 2001 From: Ruizhe Date: Wed, 24 Apr 2019 11:06:58 +0100 Subject: [PATCH 04/15] Removed group_conv2d_nchw topi registration --- python/tvm/relay/op/nn/_nn.py | 6 +++--- topi/python/topi/x86/conv2d.py | 14 +------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 8ceae6561b69..b7b645e9a8e0 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -163,9 +163,9 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - else: - if layout == "NCHW": - return topi.generic.schedule_group_conv2d_nchw(outs) + + if layout in ["NCHW", "NCHW4c"]: + return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 84570f593df2..02f78f8007f9 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -27,7 +27,7 @@ from .. import generic, tag from .. import nn from ..util import get_const_tuple -from ..nn.conv2d import conv2d, conv2d_NCHWc, _group_conv2d_nchw, \ +from ..nn.conv2d import conv2d, conv2d_NCHWc, \ conv2d_alter_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw @@ -523,15 +523,3 @@ def traverse(op): traverse(outs[0].op) return s - -@autotvm.register_topi_compute(nn.group_conv2d_nchw, 'cpu', 'direct') -def _declaration_group_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, groups, - out_dtype): - """ A wrapper for generic group_conv2d_nchw """ - out_dtype = data.dtype if out_dtype is None else out_dtype - padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) - strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) - dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) - - return _group_conv2d_nchw(data, kernel, strides, padding, dilation, groups, - out_dtype=out_dtype) From a25dfcd7067fef640cac3b120f8bae2f34f4cc0b Mon Sep 17 00:00:00 2001 From: Ruizhe Date: Wed, 24 Apr 2019 16:13:33 +0100 Subject: [PATCH 05/15] Added test for relay group_conv2d_nchw --- python/tvm/relay/op/nn/_nn.py | 6 +++--- tests/python/relay/test_op_level2.py | 13 ++++++++++++- topi/python/topi/nn/conv2d.py | 3 --- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index b7b645e9a8e0..35a3ed2f265a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -163,9 +163,9 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) - - if layout in ["NCHW", "NCHW4c"]: - return topi.generic.schedule_group_conv2d_nchw(outs) + else: + if layout in ["NCHW", "NCHW4c"]: + return 
topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index a6efd8cf0971..635417e2f853 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -100,7 +100,8 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation) if fref is None: ref_res = topi.testing.conv2d_nchw_python( - data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding) + data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, + groups=groups) else: ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) @@ -116,6 +117,16 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), channels=32, groups=32, kernel_size=(3 ,3), fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw( x, w, (1, 1), "SAME")) + # group conv2d + dshape = (1, 32, 18, 18) + kshape = (32, 2, 3, 3) + run_test_conv2d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=16, kernel_size=(3 ,3)) + # also group conv2d + dshape = (1, 32, 18, 18) + kshape = (64, 1, 3, 3) + run_test_conv2d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=64, groups=32, kernel_size=(3 ,3)) # normal conv2d dshape = (1, 3, 224, 224) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 8741441e7c50..1ae42a976a48 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -561,9 +561,6 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return _group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=out_dtype) - -def _group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None): if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 From fd24e4b2f5563cf1901ee2a1074c206fc7ed88ef Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 10:07:27 +0100 Subject: [PATCH 06/15] Added assertions to forbid small group size --- tests/python/relay/test_op_level2.py | 4 ++-- topi/python/topi/cuda/group_conv2d_nchw.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 635417e2f853..27a15b68c2ad 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -119,9 +119,9 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, x, w, (1, 1), "SAME")) # group conv2d dshape = (1, 32, 18, 18) - kshape = (32, 2, 3, 3) + kshape = (32, 4, 3, 3) run_test_conv2d("float32", "float32", 1, dshape, kshape, - padding=(1, 1), channels=32, groups=16, kernel_size=(3 ,3)) + padding=(1, 1), channels=32, groups=8, kernel_size=(3 ,3)) # also group conv2d dshape = (1, 32, 18, 18) kshape = (64, 1, 3, 3) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index 601b9b6e062c..8f9bfe985fbc 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -99,6 +99,13 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple( packed_kernel.shape) + assert groups >= oc_chunk, \ + ('Number of groups {} should not be less 
than ' + 'output channel chunk size {}'.format(groups, oc_chunk)) + assert groups >= ic_chunk, \ + ('Number of groups {} should not be less than ' + 'input channel chunk size {}'.format(groups, ic_chunk)) + if isinstance(stride, int): stride_h = stride_w = stride else: From 4774f28d4aa0f3f536191eeac8ae10281300e7a6 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 10:25:12 +0100 Subject: [PATCH 07/15] Removed hard-coded oc_block_factor --- topi/python/topi/cuda/group_conv2d_nchw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index 8f9bfe985fbc..e5a37c2821b4 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -76,7 +76,7 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, assert out_channels % groups == 0, "output channels must divide group size" assert channels % ic_block_factor == 0, \ "Number of input channels per group must divide {}".format(ic_block_factor) - assert out_channels % 4 == 0, \ + assert out_channels % oc_block_factor == 0, \ "Number of output channels per group must divide {}".format(oc_block_factor) packed_data = tvm.compute((batch, channels // ic_block_factor, height, width, From 626a5e6d9f783c6563fa4d1e10e53c8a89975493 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 10:50:58 +0100 Subject: [PATCH 08/15] Added explanatory comments to group_conv2d_nchw_cuda --- topi/python/topi/cuda/group_conv2d_nchw.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index e5a37c2821b4..7a2007b7c772 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -99,6 +99,10 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple( packed_kernel.shape) + # TODO(kumasento): these assertions ensure that the number of groups + # should be larger or equal to the number of blocks, so that each + # group will have at least one block. + # Shall we pad the channels to avoid raising assertions? assert groups >= oc_chunk, \ ('Number of groups {} should not be less than ' 'output channel chunk size {}'.format(groups, oc_chunk)) @@ -116,9 +120,9 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, else: dilation_h, dilation_w = dilation + # pad the input data pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (kernel_h, kernel_w)) - # compute graph pad_before = [0, 0, pad_top, pad_left, 0] pad_after = [0, 0, pad_down, pad_right, 0] pad_data = pad(packed_data, pad_before, pad_after, name="pad_data") @@ -136,6 +140,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, kh = tvm.reduce_axis((0, kernel_h), name='kh') kw = tvm.reduce_axis((0, kernel_w), name='kw') + # NOTE(kumasento): explanation of this snippet - + # oc_chunk//groups and ic_chunk//groups give you the number of blocks, + # i.e., chunk, per group. + # occ is the ID of the output channel block, so that occ//(oc_chunk//groups) + # produces the ID of the group. + # Multiplying that result with ic_chunk//groups resulting in the ID + # of the beginning block of the corresponding input group. + # Adding the block offset (icc) will give you the exact block ID. 
+ # + # Compared with a normal convolution, group convolution only sums + # input channels from the group that an output channel resides in. conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb: tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc, oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb] @@ -145,8 +160,10 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, .astype('int32'), axis=[icc, kh, kw, icb])) + # Type conversion output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype), tag='group_conv2d_NCHWc_int8') + num_flop = batch * oc_chunk * oc_block * out_height * out_width * \ ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups cfg.add_flop(num_flop) From 8de44b26b2f05c3ab1ec2e3785a18bdbe044716c Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 11:37:03 +0100 Subject: [PATCH 09/15] Updated group_conv2d_nchw_cuda schedule Removed 'direct' CUDA tests --- tests/python/relay/test_op_level2.py | 28 ++++++++++++++++++---- topi/python/topi/cuda/group_conv2d_nchw.py | 8 ++++--- topi/python/topi/nn/conv2d.py | 7 +++++- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 27a15b68c2ad..b112ca81587a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -86,9 +86,13 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, fref=None, groups=1, dilation=(1, 1), + except_targets=None, **attrs): - x = relay.var("x", shape=dshape) - w = relay.var("w") + if except_targets is None: + except_targets = [] + + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", dtype=dtype) y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, @@ -105,7 +109,10 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, else: ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) + for target, ctx in ctx_list(): + if target in except_targets: + continue intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -117,16 +124,21 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), channels=32, groups=32, kernel_size=(3 ,3), fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw( x, w, (1, 1), "SAME")) + + # CUDA is disabled for 'direct' schedule: + # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553 # group conv2d dshape = (1, 32, 18, 18) kshape = (32, 4, 3, 3) run_test_conv2d("float32", "float32", 1, dshape, kshape, - padding=(1, 1), channels=32, groups=8, kernel_size=(3 ,3)) + padding=(1, 1), channels=32, groups=8, kernel_size=(3 ,3), + except_targets=['cuda']) # also group conv2d dshape = (1, 32, 18, 18) kshape = (64, 1, 3, 3) run_test_conv2d("float32", "float32", 1, dshape, kshape, - padding=(1, 1), channels=64, groups=32, kernel_size=(3 ,3)) + padding=(1, 1), channels=64, groups=32, kernel_size=(3 ,3), + except_targets=['cuda']) # normal conv2d dshape = (1, 3, 224, 224) @@ -138,8 +150,14 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3 ,3)) kshape = (10, 3, 1, 3) # mixed precision. 
- run_test_conv2d("int8", "int32", 1, dshape, kshape, + run_test_conv2d("int8", "int8", 1, dshape, kshape, padding=(0, 1), channels=10, kernel_size=(1 ,3)) + # mixed precision group conv2d + # NOTE(kumasento): This test cannot pass + # dshape = (1, 32, 18, 18) + # kshape = (32, 4, 3, 3) + # run_test_conv2d("int8", "int32", 1, dshape, kshape, + # padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3)) # dilated conv2d dshape = (1, 3, 18, 18) kshape = (10, 3, 3, 3) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index 7a2007b7c772..9b8e4e3484df 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -27,10 +27,12 @@ from .. import nn, generic -@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8']) +# autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct', nn.group_conv2d_nchw.fdefault) + +@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8']) def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, out_dtype='float32'): - """Group convolution operator in NCHW layout. + """Group convolution operator for 'group_conv2d_NCHWc_int8'. Parameters ---------- @@ -319,7 +321,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output): @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw, - ["cuda", "gpu"], ["direct", "int8"]) + ["cuda", "gpu"], ["int8"]) def schedule_conv2d_nchw_cuda(cfg, outs): """TOPI schedule callback of group conv2d for cuda gpu diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 1ae42a976a48..315adbe18d08 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -575,6 +575,11 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp else: dilation_h, dilation_w = dilation + if Input.dtype == 'int8': + tag = 'group_conv2d_NCHWc_int8' + else: + tag = 'group_conv2d_nchw' + batch, in_channel, in_height, in_width = get_const_tuple(Input.shape) num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape) @@ -603,4 +608,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(out_dtype) * Filter[ff, rc, ry, rx].astype(out_dtype), - axis=[rc, ry, rx]), tag="group_conv2d_nchw") + axis=[rc, ry, rx]), tag=tag) From 4866a1756ee88c0ba9b3d6b08d260b6809f4d2e6 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 11:39:52 +0100 Subject: [PATCH 10/15] Reverted an accidental change in a conv2d test --- tests/python/relay/test_op_level2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index b112ca81587a..89b9477d4013 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -150,7 +150,7 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3 ,3)) kshape = (10, 3, 1, 3) # mixed precision. 
- run_test_conv2d("int8", "int8", 1, dshape, kshape, + run_test_conv2d("int8", "int32", 1, dshape, kshape, padding=(0, 1), channels=10, kernel_size=(1 ,3)) # mixed precision group conv2d # NOTE(kumasento): This test cannot pass From c42b3bf00fa6a595d07c4dcc1ec0c7940b01c0bb Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 11:50:30 +0100 Subject: [PATCH 11/15] Fixed indentation problems --- topi/python/topi/nn/conv2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 315adbe18d08..3f7bc285d3eb 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -576,9 +576,9 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp dilation_h, dilation_w = dilation if Input.dtype == 'int8': - tag = 'group_conv2d_NCHWc_int8' + tag = 'group_conv2d_NCHWc_int8' else: - tag = 'group_conv2d_nchw' + tag = 'group_conv2d_nchw' batch, in_channel, in_height, in_width = get_const_tuple(Input.shape) num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape) From 2c3f7c580ad099c310b5d60acbc0b188d1b5d9c4 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 11:52:39 +0100 Subject: [PATCH 12/15] Fixed a mis-commented line --- topi/python/topi/cuda/group_conv2d_nchw.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index 9b8e4e3484df..46904c3bf535 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -27,7 +27,8 @@ from .. import nn, generic -# autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct', nn.group_conv2d_nchw.fdefault) +autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct', + nn.group_conv2d_nchw.fdefault) @autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8']) def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, From db8b2a90db88c962c644e289714db8cf1f32d657 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Thu, 25 Apr 2019 16:08:18 +0100 Subject: [PATCH 13/15] Reverted change in group_conv2d_nchw tag --- topi/python/topi/nn/conv2d.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 3f7bc285d3eb..06d4074147c1 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -575,11 +575,6 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp else: dilation_h, dilation_w = dilation - if Input.dtype == 'int8': - tag = 'group_conv2d_NCHWc_int8' - else: - tag = 'group_conv2d_nchw' - batch, in_channel, in_height, in_width = get_const_tuple(Input.shape) num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape) @@ -608,4 +603,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(out_dtype) * Filter[ff, rc, ry, rx].astype(out_dtype), - axis=[rc, ry, rx]), tag=tag) + axis=[rc, ry, rx]), tag='group_conv2d_nchw') From b7ed2473ecf479587062c079a172c865ce1759a0 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Fri, 26 Apr 2019 08:59:45 +0100 Subject: [PATCH 14/15] Removed commented int8 group_conv2d test --- tests/python/relay/test_op_level2.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/python/relay/test_op_level2.py 
b/tests/python/relay/test_op_level2.py index 89b9477d4013..88963a63c770 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -152,12 +152,6 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, # mixed precision. run_test_conv2d("int8", "int32", 1, dshape, kshape, padding=(0, 1), channels=10, kernel_size=(1 ,3)) - # mixed precision group conv2d - # NOTE(kumasento): This test cannot pass - # dshape = (1, 32, 18, 18) - # kshape = (32, 4, 3, 3) - # run_test_conv2d("int8", "int32", 1, dshape, kshape, - # padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3)) # dilated conv2d dshape = (1, 3, 18, 18) kshape = (10, 3, 3, 3) From e901c7cd20987d406a4b14d9d8850665892086fc Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Fri, 26 Apr 2019 10:33:46 +0100 Subject: [PATCH 15/15] Fixed group size assertions in group_conv2d_nchw_cuda --- topi/python/topi/cuda/group_conv2d_nchw.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index 46904c3bf535..be4ae3554e33 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -103,14 +103,14 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups, packed_kernel.shape) # TODO(kumasento): these assertions ensure that the number of groups - # should be larger or equal to the number of blocks, so that each + # should be smaller or equal to the number of blocks, so that each # group will have at least one block. # Shall we pad the channels to avoid raising assertions? - assert groups >= oc_chunk, \ - ('Number of groups {} should not be less than ' + assert groups <= oc_chunk, \ + ('Number of groups {} should be less than ' 'output channel chunk size {}'.format(groups, oc_chunk)) - assert groups >= ic_chunk, \ - ('Number of groups {} should not be less than ' + assert groups <= ic_chunk, \ + ('Number of groups {} should be less than ' 'input channel chunk size {}'.format(groups, ic_chunk)) if isinstance(stride, int):
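
For reference, below is a minimal NumPy sketch of the grouped-convolution semantics that topi.nn.group_conv2d_nchw and the new relay tests exercise. It assumes NCHW data, OIHW weights of shape (out_channels, in_channels // groups, kH, kW), stride 1, no padding and no dilation; the helper name and the small harness are illustrative only and are not code from the patches above.

import numpy as np

def group_conv2d_nchw_ref(data, weight, groups):
    """Reference grouped conv2d: each group of output channels only reads
    its own slice of input channels. Depthwise conv2d is the special case
    groups == in_channels == out_channels (one filter per input channel)."""
    batch, in_c, in_h, in_w = data.shape
    out_c, in_c_per_group, k_h, k_w = weight.shape
    assert in_c % groups == 0 and out_c % groups == 0
    assert in_c_per_group == in_c // groups
    out_c_per_group = out_c // groups
    out_h, out_w = in_h - k_h + 1, in_w - k_w + 1  # stride 1, no padding
    out = np.zeros((batch, out_c, out_h, out_w), dtype=data.dtype)
    for g in range(groups):
        # slice out this group's input channels and its block of filters
        d = data[:, g * in_c_per_group:(g + 1) * in_c_per_group]
        w = weight[g * out_c_per_group:(g + 1) * out_c_per_group]
        for n in range(batch):
            for f in range(out_c_per_group):
                for y in range(out_h):
                    for x in range(out_w):
                        out[n, g * out_c_per_group + f, y, x] = np.sum(
                            d[n, :, y:y + k_h, x:x + k_w] * w[f])
    return out

# Data/kernel shapes match the first group-conv case added to
# test_op_level2.py (dshape=(1, 32, 18, 18), kshape=(32, 4, 3, 3), groups=8);
# that test additionally uses padding=(1, 1), which this sketch omits.
data = np.random.rand(1, 32, 18, 18).astype("float32")
weight = np.random.rand(32, 4, 3, 3).astype("float32")
print(group_conv2d_nchw_ref(data, weight, groups=8).shape)  # (1, 32, 16, 16)

Because depthwise conv2d is just this computation with groups == in_channels and one filter per input channel, the relay schedule dispatch in the later patches distinguishes the two cases by inspecting the tag of the conv2d compute op rather than by re-deriving channel counts from the operator attributes.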