diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ebedc20375e5..53f104ce48cf 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -169,7 +169,6 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # get number of channels
         out = AttrCvt(op_name=dimension_picker('conv'),
                       transforms={
                           'kernel_shape': 'kernel_size',
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index e60c01cfb3ff..35a3ed2f265a 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target):
     with target:
         return topi.generic.schedule_softmax(outputs)
 
+
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 schedule_broadcast = schedule_injective
 
+
 @reg.register_schedule("nn.log_softmax")
 def schedule_log_softmax(_, outputs, target):
     """Schedule definition of log_softmax"""
     with target:
         return topi.generic.schedule_softmax(outputs)
 
+
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
@@ -53,12 +56,14 @@ def compute_dense(attrs, inputs, out_type, target):
     """Compute definition of dense"""
     return [topi.nn.dense(inputs[0], inputs[1])]
 
+
 @reg.register_schedule("nn.dense")
 def schedule_dense(attrs, outputs, target):
     """Schedule definition of dense"""
     with target:
         return topi.generic.schedule_dense(outputs)
 
+
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -68,16 +73,29 @@ def compute_batch_matmul(attrs, inputs, out_type, target):
     """Compute definition of batch_matmul"""
     return [topi.nn.batch_matmul(inputs[0], inputs[1])]
 
+
 @reg.register_schedule("nn.batch_matmul")
 def schedule_batch_matmul(attrs, outputs, target):
     """Schedule definition of batch_matmul"""
     with target:
         return topi.generic.schedule_batch_matmul(outputs)
 
+
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # conv2d
+def _find_conv2d_op(op):
+    """Find the op with conv2d in its tag by traversing."""
+    if 'conv2d' in op.tag:
+        return op
+    for tensor in op.input_tensors:
+        op_ = _find_conv2d_op(tensor.op)
+        if op_ is not None:
+            return op_
+    return None
+
+
 @reg.register_compute("nn.conv2d")
 def compute_conv2d(attrs, inputs, out_type, target):
     """Compute definition of conv2d"""
@@ -101,14 +119,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
             inputs[0], inputs[1], strides, padding,
             dilation, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
-        get_const_int(inputs[1].shape[0]) == groups and \
-        get_const_int(inputs[1].shape[1]) == 1:
+            get_const_int(inputs[1].shape[0]) == groups and \
+            get_const_int(inputs[1].shape[1]) == 1:
         out = topi.nn.depthwise_conv2d_nchw(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     elif layout == "NHWC" and \
-        kernel_layout == "HWOI" and\
-        get_const_int(inputs[1].shape[2]) == groups and \
-        get_const_int(inputs[1].shape[3]) == 1:
+            kernel_layout == "HWOI" and\
+            get_const_int(inputs[1].shape[2]) == groups and \
+            get_const_int(inputs[1].shape[3]) == 1:
         out = topi.nn.depthwise_conv2d_nhwc(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     elif layout in ['NCHW', 'NCHW4c']:
@@ -125,6 +143,7 @@ def schedule_conv2d(attrs, outs, target):
     groups = attrs.groups
     layout = attrs.data_layout
     kernel_layout = attrs.kernel_layout
+
     with target:
         if groups == 1 and layout == "NCHW":
@@ -133,13 +152,20 @@ def schedule_conv2d(attrs, outs, target):
         if groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
         if groups != 1:
-            if layout == "NCHW":
-                # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
-                return topi.generic.schedule_depthwise_conv2d_nchw(outs)
-            if layout == "NHWC" and kernel_layout == "HWOI":
-                return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-            if layout == "NCHW4c":
-                return topi.generic.schedule_group_conv2d_nchw(outs)
+            # Inspect the op tag to distinguish depthwise conv2d from group conv2d.
+            op = _find_conv2d_op(outs[0].op)
+            assert op is not None
+
+            is_depthwise = 'depthwise' in op.tag
+            if is_depthwise:
+                if layout == "NCHW":
+                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
+                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+                if layout == "NHWC" and kernel_layout == "HWOI":
+                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+            else:
+                if layout in ["NCHW", "NCHW4c"]:
+                    return topi.generic.schedule_group_conv2d_nchw(outs)
     raise ValueError("No compatible schedule")
 
 
@@ -149,6 +175,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     from ... import op
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
+
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -167,18 +194,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
     assert groups == 1, "only support groups == 1 for now"
-    out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype)
+    out = topi.nn.conv2d_transpose_nchw(
+        inputs[0], inputs[1], strides, padding, out_dtype)
     output_padding = get_const_tuple(attrs.output_padding)
     out = topi.nn.pad(out,
                       [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
     return [out]
 
+
 @reg.register_schedule("nn.conv2d_transpose")
 def schedule_conv2d_transpose(attrs, outs, target):
     """Schedule definition of conv2d_transpose"""
     with target:
         return topi.generic.schedule_conv2d_transpose_nchw(outs)
 
+
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
@@ -194,6 +224,7 @@ def schedule_max_pool2d(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, layout)
 
+
 reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -205,6 +236,7 @@ def schedule_avg_pool2d(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, layout)
 
+
 reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -215,6 +247,7 @@ def schedule_global_max_pool2d(_, outs, target):
     with target:
         return topi.generic.schedule_global_pool(outs)
 
+
 reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -225,6 +258,7 @@ def schedule_global_avg_pool2d(_, outs, target):
     with target:
         return topi.generic.schedule_global_pool(outs)
 
+
 reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # leaky_relu
@@ -248,12 +282,14 @@ def compute_lrn(attrs, inputs, out_dtype, target):
     return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, attrs.alpha,
                         attrs.beta, attrs.bias)]
 
+
 @reg.register_schedule("nn.lrn")
 def schedule_lrn(attrs, outs, target):
     """Schedule definition of lrn"""
     with target:
         return topi.generic.schedule_lrn(outs)
 
+
 reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
 
 
@@ -263,20 +299,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target):
     """Compute definition of l2 normalize"""
     return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]
 
+
 @reg.register_schedule("nn.l2_normalize")
 def schedule_l2_normalize(attrs, outs, target):
     """Schedule definition of l2 normalize"""
     with target:
         return topi.generic.schedule_l2_normalize(outs)
 
+
 reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # upsampling
 reg.register_schedule("nn.upsampling", reg.schedule_injective)
+
+
 def schedule_upsampling(_, outs, target):
     """Schedule definition of upsampling"""
     with target:
         return topi.generic.schedule_injective(outs)
+
+
 # pad
 reg.register_schedule("nn.pad", schedule_broadcast)
@@ -302,12 +344,14 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
 def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
     """Schedule definition of conv2d_winograd_without_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -315,15 +359,18 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, targe
 @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
 def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
     """Compute definition of contrib_conv2d_winograd_weight_transform"""
-    out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv2d_winograd_weight_transform(
+        inputs[0], attrs.get_int('tile_size'))
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
 def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
     """Schedule definition of contrib_conv2d_winograd_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -351,12 +398,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
 def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
     """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                      OpPattern.OPAQUE)
 
 
@@ -369,12 +418,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
         inputs[0], convolution_algorithm, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
 def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
     """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                      OpPattern.OPAQUE)
 
 
@@ -395,15 +446,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                data_layout, out_layout,
                                out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_NCHWc")
 def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
     """Schedule definition of contrib_conv2d_NCHWc"""
     with target:
         return topi.generic.schedule_conv2d_NCHWc(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
 def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
     """Compute definition of depthwise conv2d NCHWc"""
@@ -420,15 +474,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                          data_layout, out_layout,
                                          out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
 def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
     """Schedule definition of contrib_conv2d_NCHWc"""
     with target:
         return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
 
+
 reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_compute("nn.deformable_conv2d")
 def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
     """Compute definition of deformable_conv2d"""
@@ -444,10 +501,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
                                        dilation, deformable_groups, groups, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.deformable_conv2d")
 def schedule_deformable_conv2d(attrs, outs, target):
     """Schedule definition of deformable_conv2d"""
     with target:
         return topi.generic.schedule_deformable_conv2d_nchw(outs)
 
+
 reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/python/tvm/target.py b/python/tvm/target.py
index d3df3d705cb8..eff0088b37ce 100644
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -296,7 +296,7 @@ def dispatch_func(func, *args, **kwargs):
 def generic_func(fdefault):
     """Wrap a target generic function.
 
-    Generic function allows registeration of further functions
+    Generic function allows registration of further functions
     that can be dispatched on current target context.
     If no registered dispatch is matched, the fdefault will be called.
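The schedule dispatch added to _nn.py above hinges on _find_conv2d_op: it walks from the output op through op.input_tensors until it finds a compute whose tag contains 'conv2d', and schedule_conv2d then checks whether that tag also contains 'depthwise'. The following is a minimal, self-contained sketch of that traversal and check; FakeOp and FakeTensor are toy stand-ins invented for illustration, not TVM classes.

class FakeTensor:
    """Toy tensor wrapper exposing only the .op attribute used by the traversal."""
    def __init__(self, op):
        self.op = op


class FakeOp:
    """Toy op with a tag and a list of input tensors, mimicking the fields used above."""
    def __init__(self, tag, inputs=()):
        self.tag = tag
        self.input_tensors = [FakeTensor(op) for op in inputs]


def find_conv2d_op(op):
    """Return the first op whose tag mentions conv2d, searching inputs recursively."""
    if 'conv2d' in op.tag:
        return op
    for tensor in op.input_tensors:
        found = find_conv2d_op(tensor.op)
        if found is not None:
            return found
    return None


# An elementwise op fused on top of a depthwise conv2d: the traversal digs through
# the output op until it reaches the tagged conv2d compute, and the tag then decides
# between the depthwise and the group conv2d schedule.
root = FakeOp("elemwise", inputs=[FakeOp("depthwise_conv2d_nchw")])
conv = find_conv2d_op(root)
print(conv.tag, 'depthwise' in conv.tag)   # depthwise_conv2d_nchw True

The same traversal run on an op graph whose conv2d compute is tagged, say, group_conv2d_NCHWc_int8 would report False and fall through to the group conv2d schedule branch.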
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index a6efd8cf0971..88963a63c770 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -86,9 +86,13 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
                         fref=None,
                         groups=1,
                         dilation=(1, 1),
+                        except_targets=None,
                         **attrs):
-        x = relay.var("x", shape=dshape)
-        w = relay.var("w")
+        if except_targets is None:
+            except_targets = []
+
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", dtype=dtype)
         y = relay.nn.conv2d(x, w,
                             padding=padding,
                             dilation=dilation,
@@ -100,11 +104,15 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
         dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
         if fref is None:
             ref_res = topi.testing.conv2d_nchw_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
+                groups=groups)
         else:
             ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
 
+
         for target, ctx in ctx_list():
+            if target in except_targets:
+                continue
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data, kernel)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
@@ -117,6 +125,21 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
                     fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw(
                         x, w, (1, 1), "SAME"))
 
+    # CUDA is disabled for 'direct' schedule:
+    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+    # group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 4, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+    # also group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (64, 1, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+
     # normal conv2d
     dshape = (1, 3, 224, 224)
     kshape = (10, 3, 3, 3)
diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py
index 601b9b6e062c..be4ae3554e33 100644
--- a/topi/python/topi/cuda/group_conv2d_nchw.py
+++ b/topi/python/topi/cuda/group_conv2d_nchw.py
@@ -27,10 +27,13 @@
 from .. import nn, generic
 
 
-@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8'])
+autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct',
+                              nn.group_conv2d_nchw.fdefault)
+
+@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8'])
 def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                            out_dtype='float32'):
-    """Group convolution operator in NCHW layout.
+    """Group convolution operator for 'group_conv2d_NCHWc_int8'.
 
     Parameters
     ----------
@@ -76,7 +79,7 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     assert out_channels % groups == 0, "output channels must divide group size"
     assert channels % ic_block_factor == 0, \
         "Number of input channels per group must divide {}".format(ic_block_factor)
-    assert out_channels % 4 == 0, \
+    assert out_channels % oc_block_factor == 0, \
         "Number of output channels per group must divide {}".format(oc_block_factor)
 
     packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
@@ -99,6 +102,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
         packed_kernel.shape)
 
+    # TODO(kumasento): these assertions ensure that the number of groups is
+    # smaller than or equal to the number of blocks, so that each group has
+    # at least one block.
+    # Shall we pad the channels to avoid raising assertions?
+    assert groups <= oc_chunk, \
+        ('Number of groups {} should be no larger than '
+         'the output channel chunk size {}'.format(groups, oc_chunk))
+    assert groups <= ic_chunk, \
+        ('Number of groups {} should be no larger than '
+         'the input channel chunk size {}'.format(groups, ic_chunk))
+
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -109,9 +123,9 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     else:
         dilation_h, dilation_w = dilation
 
+    # pad the input data
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
         padding, (kernel_h, kernel_w))
-    # compute graph
     pad_before = [0, 0, pad_top, pad_left, 0]
     pad_after = [0, 0, pad_down, pad_right, 0]
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
@@ -129,6 +143,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     kh = tvm.reduce_axis((0, kernel_h), name='kh')
     kw = tvm.reduce_axis((0, kernel_w), name='kw')
 
+    # NOTE(kumasento): explanation of this snippet -
+    # oc_chunk//groups and ic_chunk//groups give the number of blocks
+    # (chunks) per group.
+    # occ is the ID of the output channel block, so occ//(oc_chunk//groups)
+    # produces the ID of the group it belongs to.
+    # Multiplying that result by ic_chunk//groups gives the ID of the first
+    # block of the corresponding input group.
+    # Adding the block offset (icc) then gives the exact input block ID.
+    #
+    # Compared with a normal convolution, group convolution only sums over
+    # the input channels of the group that an output channel resides in.
     conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
                        tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
                                         oh*stride_h+kh*dilation_h,
                                         ow*stride_w+kw*dilation_w, icb]
@@ -138,8 +163,10 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                                .astype('int32'),
                                axis=[icc, kh, kw, icb]))
 
+    # Type conversion
     output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
                          tag='group_conv2d_NCHWc_int8')
+
     num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
         ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
     cfg.add_flop(num_flop)
@@ -295,7 +322,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
 
 
 @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["direct", "int8"])
+                                ["cuda", "gpu"], ["int8"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of group conv2d for cuda gpu
 
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 70ce5791d905..7bd95688b75d 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -242,7 +242,7 @@ def schedule_depthwise_conv2d_NCHWc(outs):
 
 @tvm.target.generic_func
 def schedule_group_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw
+    """Schedule for group_conv2d_nchw
 
     Parameters
     ----------
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 49c0bd79eacc..06d4074147c1 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -603,4 +603,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
                    yy * stride_h + ry * dilation_h,
                    xx * stride_w + rx * dilation_w].astype(out_dtype) *
             Filter[ff, rc, ry, rx].astype(out_dtype),
-            axis=[rc, ry, rx]), tag="conv2d_nchw")
+            axis=[rc, ry, rx]), tag='group_conv2d_nchw')
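To make the NOTE(kumasento) block-index arithmetic concrete, here is a small plain-Python sketch (no TVM involved; the block counts are made-up illustrative values) that maps each output-channel block occ to the input-channel blocks its group reduces over, mirroring the pad_data index occ//(oc_chunk//groups)*(ic_chunk//groups)+icc used in the int8 compute above.

# Illustrative block counts only: 8 output-channel blocks, 4 input-channel blocks, 2 groups.
oc_chunk, ic_chunk, groups = 8, 4, 2

for occ in range(oc_chunk):
    group_id = occ // (oc_chunk // groups)       # which group this output block belongs to
    ic_start = group_id * (ic_chunk // groups)   # first input block of that group
    # blocks actually reduced over: ic_start + icc for icc in [0, ic_chunk // groups)
    ic_blocks = [ic_start + icc for icc in range(ic_chunk // groups)]
    print("occ=%d -> group=%d, input blocks=%s" % (occ, group_id, ic_blocks))

With these numbers, output blocks 0-3 read input blocks [0, 1] and output blocks 4-7 read input blocks [2, 3], which is exactly the per-group slice that the reduction axis icc walks over in the compute definition.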