diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 2d9ef99ba8a6..947beb396ae2 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -236,20 +236,24 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 name="depthwise_conv2d_nhwc.arm_cpu",
             )
-            # Optimized special case depthwiseConv2D operation. Requires a 3x3 kernel, a
-            # NHWC layout, a HWOI kernel layout (which we rearrange), no dilation, int8 inputs,
-            # int32 output, the same number of input and output channels, and for that channel
-            # count to be divisible by 4. Additional work could remove these restrictions.
+            # Optimized special case depthwiseConv2D operation. Requires NHWC layout,
+            # a HWOI kernel layout (which we rearrange to a custom layout), no dilation,
+            # int8/int16 inputs, int32 output, and the same number of input and output channels.
+            # The int8 implementation DOES need the DSP unit (for SXTB16), but it is not
+            # possible to use the DSP unit to speed up a NHWC depthwise convolution (though
+            # an NCHW convolution would benefit).
             elif (
-                target.features.has_dsp
-                and kernel.shape[0] == kernel.shape[1] == 3
-                and dilation_w == dilation_h == 1
+                dilation_w == dilation_h == 1
                 and kernel.shape[3] == 1  # channel_multiplier == 1
-                and data.dtype == "int8"
                 and out_type.dtype == "int32"
-                and data.shape[3] % 4 == 0
+                and (
+                    (data.shape[3] % 4 == 0 and data.dtype == "int8" and target.features.has_dsp)
+                    or (data.shape[3] % 2 == 0 and data.dtype == "int16")
+                )
                 and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0)
+                # Ideally we should check that the kernel is a Relay constant, but strategy
+                # functions don't have access to the data needed to check this.
             ):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp),
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 90461f0c1c99..d4878f4b6908 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -19,6 +19,8 @@
 import logging
 
+import numpy as np
+
 import tvm
 from tvm import te
 from tvm import relay
@@ -31,6 +33,7 @@ from .conv2d_int8 import is_int8_hw_support
 from .arm_utils import get_tiling_B_interleaved_t
 from ..generic.conv2d import conv2d_alter_int8_common
+from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word
 
 logger = logging.getLogger("topi")
 
@@ -121,7 +124,40 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     idxd = tvm.tir.indexdiv
 
-    # We don't perform layout alteration for NHWC layout with real data types
+    if topi_tmpl == "depthwise_conv2d_nhwc_dsp.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWOI"
+
+        # We are not able to check if inputs[1] (the kernel) is a constant in the
+        # strategy function, so as a stopgap solution we use an assert here.
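The dtype and channel-count gating in the strategy hunk above boils down to a small predicate. The sketch below is illustrative only, not TVM API; `can_use_dsp_depthwise` and `has_dsp` are hypothetical stand-ins for the checks on `data.dtype`, `data.shape[3]`, and `target.features.has_dsp`:

```python
# Hypothetical helper mirroring the new strategy condition: int8 requires the
# DSP extension (for SXTB16) and channels divisible by 4 (four int8 lanes per
# 32-bit word); int16 requires channels divisible by 2 and works without DSP.
def can_use_dsp_depthwise(dtype: str, channels: int, has_dsp: bool) -> bool:
    if dtype == "int8":
        return has_dsp and channels % 4 == 0
    if dtype == "int16":
        return channels % 2 == 0
    return False


assert can_use_dsp_depthwise("int8", 8, has_dsp=True)
assert not can_use_dsp_depthwise("int8", 8, has_dsp=False)  # int8 needs DSP
assert can_use_dsp_depthwise("int16", 6, has_dsp=False)  # int16 does not
```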
+ assert isinstance( + inputs[1], relay.Constant + ), "depthwise_conv2d_nhwc_dsp.arm_cpu requires kernel be a relay Constant" + + channels = get_const_tuple(data.shape)[3] + KH, KW, _, _ = get_const_tuple(kernel.shape) + simd_lanes = num_simd_lanes_per_word(data.dtype) + + HWOI_kernel_np = inputs[1].data.numpy() + CHWc_kernel_np = np.zeros((channels // simd_lanes, KH, KW, simd_lanes), dtype=kernel.dtype) + for i in range(channels // simd_lanes): + CHWc_kernel_np[i] = HWOI_kernel_np[:, :, simd_lanes * i : simd_lanes * (i + 1), 0] + reshaped_new_kernel = CHWc_kernel_np.reshape((KH, KW, channels, 1)) + + # Store the same config for the altered operator (workload) + new_data = data + new_kernel = te.placeholder((KH, KW, channels, 1), dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, out_dtype], + "depthwise_conv2d_nhwc_dsp.arm_cpu", + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.conv2d( + inputs[0], + relay.Constant(tvm.nd.array(reshaped_new_kernel)), + **new_attrs, + ) + + # Only microTVM does layout alteration for NHWC layout with real data types if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]: return None diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index 162bf65a21f9..b8da15dadf13 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -19,84 +19,15 @@ import random import string -from tvm import te -from tvm.topi.utils import traverse_inline, get_const_tuple +from tvm import te, topi +from tvm.topi.utils import traverse_inline from tvm.topi.nn.pad import pad -from tvm import tir -from .micro_kernel.quad_channel_convolve import ( - intrin_quad_channel_convolve, - quad_channel_convolve_impl, +from .micro_kernel.multi_channel_convolve import ( + intrin_multi_channel_convolve, + multi_channel_convolve_impl, ) - -# For depthwise_conv2d, kernels are normally given in HWOI format, -# which when input_channels = output channels, we will call HWC. -# This is bad, as we want "related" parts of the kernel to be next -# to each other, so we can use __SMLAD later. -# -# Consider a 3x3 int8 kernel with no bias vector, with eight -# channels. Let us specify entries in the kernel as H_W_C - i.e. -# where 0_2_3 represents the rightmost position in the first row -# of channel 4/8 (4 because of zero indexing). Each [ ] represents -# a 32-bit integer. We currently store the kernel as: -# -# 0 ................................31 -# [ 0_0_0 || 0_0_1 || 0_0_2 || 0_0_3 ] [ 0_0_4 || 0_0_5 || 0_0_6 || 0_0_7 ] -# [ 0_1_0 || 0_1_1 || 0_1_2 || 0_1_3 ] [ 0_1_4 || 0_1_5 || 0_1_6 || 0_1_7 ] -# [ 0_2_0 || 0_2_1 || 0_2_2 || 0_2_3 ] [ 0_2_4 || 0_2_5 || 0_2_6 || 0_2_7 ] -# [ 1_0_0 || 1_0_1 || 1_0_2 || 1_0_3 ] [ 1_0_4 || 1_0_5 || 1_0_6 || 1_0_7 ] -# [ 1_1_0 || 1_1_1 || 1_1_2 || 1_1_3 ] [ 1_1_4 || 1_1_5 || 1_1_6 || 1_1_7 ] -# [ 1_2_0 || 1_2_1 || 1_2_2 || 1_2_3 ] [ 1_2_4 || 1_2_5 || 1_2_6 || 1_2_7 ] -# [ 2_0_0 || 2_0_1 || 2_0_2 || 2_0_3 ] [ 2_0_4 || 2_0_5 || 2_0_6 || 2_0_7 ] -# [ 2_1_0 || 2_1_1 || 2_1_2 || 2_1_3 ] [ 2_1_4 || 2_1_5 || 2_1_6 || 2_1_7 ] -# [ 2_2_0 || 2_2_1 || 2_2_2 || 2_2_3 ] [ 2_2_4 || 2_2_5 || 2_2_6 || 2_2_7 ] -# -# Let 0x00 be all zeros. 
We rearrange into: -# -# 0 ................................31 -# [ 0_0_0 || 0_0_1 || 0_1_0 || 0_1_1 ] [ 0_0_2 || 0_0_3 || 0_1_2 || 0_1_3 ] -# [ 0_2_0 || 0_2_1 || 1_0_0 || 1_0_1 ] [ 0_2_2 || 0_2_3 || 1_0_2 || 1_0_3 ] -# [ 1_1_0 || 1_1_1 || 1_2_0 || 1_2_1 ] [ 1_1_2 || 1_1_3 || 1_2_2 || 1_2_3 ] -# [ 2_0_0 || 2_0_1 || 2_1_0 || 2_1_1 ] [ 2_0_2 || 2_0_3 || 2_1_2 || 2_1_3 ] -# [ 2_2_0 || 2_2_1 || 0x000 || 0x000 ] [ 2_2_2 || 2_2_3 || 0x000 || 0x000 ] -# [ 0_0_4 || 0_0_5 || 0_1_4 || 0_1_5 ] [ 0_0_6 || 0_0_7 || 0_1_6 || 0_1_7 ] -# [ 0_2_4 || 0_2_5 || 1_0_4 || 1_0_5 ] [ 0_2_6 || 0_2_7 || 1_0_6 || 1_0_7 ] -# [ 1_1_4 || 1_1_5 || 1_2_4 || 1_2_5 ] [ 1_1_6 || 1_1_7 || 1_2_6 || 1_2_7 ] -# [ 2_0_4 || 2_0_5 || 2_1_4 || 2_1_5 ] [ 2_0_6 || 2_0_7 || 2_1_6 || 2_1_7 ] -# [ 2_2_4 || 2_2_5 || 0x000 || 0x000 ] [ 2_2_6 || 2_2_7 || 0x000 || 0x000 ] -# -# This saves us six operations comapred to the original ordering, as we -# do not need halfword packing instructions. -# -# This kernel re-arranging function will be used for 3x3 kernels (as that -# is all this DSP implementation currently supports) but would work with -# any M*N kernel such that M*N is odd. - - -def _rearrange_kernel(kernel): - # Kernel must be HWC format. - kernel_h, kernel_w, channels, _ = get_const_tuple(kernel.shape) - assert channels % 4 == 0 - - # This restriction could be removed by only using tir.if_then_else to add padding - # zeros if (kernel_w * kernel_h) % 2 == 1, and filling completely otherwise. - assert (kernel_w * kernel_h) % 2 == 1 - - def fcompute(c_o, pos, c_i): - channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o) - true_pos_index = 2 * (pos // 2) + (c_i // 2) - - return tir.if_then_else( - true_pos_index < (kernel_h * kernel_w), - kernel[true_pos_index // kernel_w, true_pos_index % kernel_w, channel, 0], - tir.const(0, "int8"), - ) - - return te.compute( - (channels // 4, kernel_h * kernel_w + 1, 4), - fcompute, - name="packed_kernel", - ) +from .micro_kernel.common import num_simd_lanes_per_word def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype): @@ -120,10 +51,7 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dila batch_size, height, width, channels = data.shape kernel_h, kernel_w, _, _ = kernel.shape - - # We require that the number of channels be divisible by 4. This restriction could - # be removed with strip mining if people cared. - assert channels % 4 == 0 + simd_lanes = num_simd_lanes_per_word(data.dtype) # We don't support different numbers of input and output channels. assert channels == kernel.shape[2] @@ -133,11 +61,6 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dila # round until we compute activations. assert out_dtype == "int32" - # This can pretty easily be generalized in the future. Likely worth doing, and this - # function was written to make doing so easy. Should only require adding more calls - # to QUAD_CHANNEL_REARRANGE_SUM. - assert kernel_w == kernel_h == 3 - # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which # is slow and bad. We should really implement a strip mining # routine to avoid this, but TVM has terrible support for that. 
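To make the kernel relayout concrete: the loop in `_alter_conv2d_layout` groups the channel axis into word-sized slices, and the `topi.reshape` in the next hunk undoes the flat HWOI relabeling. Below is a minimal NumPy sketch of that round trip; the shapes and the names `hwoi` and `chwc` are illustrative assumptions:

```python
import numpy as np

# Illustrative sizes: 3x3 kernel, 8 channels, int8, so 4 SIMD lanes per word.
KH, KW, C, LANES = 3, 3, 8, 4
hwoi = np.arange(KH * KW * C, dtype=np.int8).reshape(KH, KW, C, 1)

# Relayout done at alter-op time: each group of LANES consecutive channels
# becomes contiguous in memory, one 32-bit word per group.
chwc = np.zeros((C // LANES, KH, KW, LANES), dtype=np.int8)
for i in range(C // LANES):
    chwc[i] = hwoi[:, :, LANES * i : LANES * (i + 1), 0]

# Relay still sees an HWOI tensor, so the buffer is relabeled with a flat
# reshape; the compute definition reverses the relabeling and indexes
# [k // LANES, kh, kw, k % LANES], which recovers the original kernel values.
relabeled = chwc.reshape(KH, KW, C, 1)
recovered = relabeled.reshape(C // LANES, KH, KW, LANES)
for k in range(C):
    assert (recovered[k // LANES, :, :, k % LANES] == hwoi[:, :, k, 0]).all()
```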
@@ -188,18 +111,14 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dila raise RuntimeError() _, padded_h, padded_w, _ = padded_data.shape - packed_kernel = _rearrange_kernel(kernel) kh_i = te.reduce_axis((0, kernel_h), name="kh_i") kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + reshaped_kernel = topi.reshape(kernel, (channels // simd_lanes, kernel_h, kernel_w, simd_lanes)) return te.compute( (batch_size, output_h, output_w, channels), lambda h, i, j, k: te.sum( padded_data[h, (i * stride_h) + kh_i, (j * stride_w) + kw_i, k].astype("int32") - * packed_kernel[ - k // 4, - (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), - (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ].astype("int32"), + * reshaped_kernel[k // simd_lanes, kh_i, kw_i, k % simd_lanes].astype("int32"), axis=(kh_i, kw_i), ), name="depthwise_conv2d", @@ -212,33 +131,36 @@ def depthwise_conv2d_nhwc_dsp_schedule(_cfg, outs): """Schedule function for v7e-m DSP instructions of conv2d.""" schedule = te.create_schedule([x.op for x in outs]) - def _callback(op): - if "depthwise_conv2d_nhwc" not in op.tag: + def _callback(operator): + if "depthwise_conv2d_nhwc" not in operator.tag: return # extract tensors - output = op.output(0) + output = operator.output(0) padded_data = output.op.input_tensors[0] - packed_kernel = output.op.input_tensors[1] - kernel = packed_kernel.op.input_tensors[0] + reshaped_kernel = output.op.input_tensors[1] + in_dtype = padded_data.dtype - _, _, padded_w, channels = padded_data.shape - kernel_h, kernel_w, _, _ = kernel.shape + _, padded_h, padded_w, channels = padded_data.shape + _, kernel_h, kernel_w, _ = reshaped_kernel.shape suffix = "".join(random.choices(string.ascii_uppercase, k=8)) b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis ky_ax, kx_ax = schedule[output].op.reduce_axis - c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4) + simd_lanes = num_simd_lanes_per_word(in_dtype) + c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=simd_lanes) schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i) - quad_channel_convolve = intrin_quad_channel_convolve( - padded_w, channels, kernel_h, kernel_w, suffix + multi_channel_convolve = intrin_multi_channel_convolve( + in_dtype, padded_h, padded_w, channels, kernel_h, kernel_w, suffix ) - schedule[output].tensorize(ky_ax, quad_channel_convolve) + schedule[output].tensorize(ky_ax, multi_channel_convolve) schedule[output].pragma( b_ax, "import_c", - quad_channel_convolve_impl(padded_w, channels, kernel_h, kernel_w, suffix), + multi_channel_convolve_impl( + in_dtype, padded_h, padded_w, channels, kernel_h, kernel_w, suffix + ), ) traverse_inline(schedule, outs[-1].op, _callback) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py index df54c101773e..0398844315a7 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py @@ -29,3 +29,18 @@ #include """ + +MICRO_WORD_LENGTH_BITS = 32 + + +def num_simd_lanes_per_word(dtype: str) -> int: + """Takes a dtype, and returns how many of that dtype fit into a single microcontroller word. 
+
+    >>> num_simd_lanes_per_word("int8")
+    4
+    >>> num_simd_lanes_per_word("int16")
+    2
+    """
+    assert dtype.startswith("int")
+    dtype_width = int(dtype[3:])
+    return MICRO_WORD_LENGTH_BITS // dtype_width
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py
new file mode 100644
index 000000000000..992d90578046
--- /dev/null
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions
+(v7e-m). It takes as inputs an int8 or int16 HWC data tensor and a kernel of matching dtype in a
+custom CHWc layout. This intrinsic "lays" the kernel on top of the data tensor starting from a
+given pointer, performs signed sixteen-bit multiplies on each pair of values, and sums all the
+products in an int32 accumulator. This process is repeated once per SIMD lane - four times for
+int8, twice for int16 - giving one int32 output per channel."""
+
+import textwrap
+
+from tvm import te, tir
+from .common import num_simd_lanes_per_word
+
+
+def _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix):
+    """Gets the C function name of the tensorized function."""
+    return f"kernel_convolve_{in_dtype}_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}"
+
+
+def intrin_multi_channel_convolve(
+    in_dtype, _tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix
+):
+    """Defines a v7e-m DSP-accelerated multi-channel convolution. Works on two
+    channels if in_dtype==int16, and four channels if in_dtype==int8."""
+    simd_lanes = num_simd_lanes_per_word(in_dtype)
+
+    overlap_dims = (kernel_h, kernel_w, simd_lanes)
+    data_slice = te.placeholder(overlap_dims, name="data_slice", dtype=in_dtype)
+    kernel_slice = te.placeholder(overlap_dims, name="kernel_slice", dtype=in_dtype)
+
+    kh_i = te.reduce_axis((0, kernel_h), name="kh_i")
+    kw_i = te.reduce_axis((0, kernel_w), name="kw_i")
+
+    output_slice = te.compute(
+        (simd_lanes,),
+        lambda k: te.sum(
+            data_slice[kh_i, kw_i, k].astype("int32") * kernel_slice[kh_i, kw_i, k].astype("int32"),
+            axis=(kh_i, kw_i),
+        ),
+        name="c",
+    )
+
+    data_buf = tir.decl_buffer(
+        data_slice.shape,
+        data_slice.dtype,
+        name="data",
+        offset_factor=1,
+        strides=[tensor_w * channels, channels, 1],
+    )
+    kernel_buf = tir.decl_buffer(
+        kernel_slice.shape,
+        kernel_slice.dtype,
+        name="kernel",
+        offset_factor=1,
+        strides=[kernel_w * simd_lanes, simd_lanes, 1],
+    )
+    output_buf = tir.decl_buffer(
+        output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1]
+    )
+
+    def intrin_func(ins, outs):
+        builder = tir.ir_builder.create()
+        builder.emit(
+            tir.call_extern(
+                "int32",
+                _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix),
+                outs[0].access_ptr("w"),
+                ins[0].access_ptr("r"),
+                ins[1].access_ptr("r"),
+            )
+        )
+        return builder.get()
+
+    return te.decl_tensor_intrin(
+        output_slice.op,
+        intrin_func,
+        binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf},
+    )
+
+
+def multi_channel_convolve_impl(in_dtype, *args) -> str:
+    """Generates C code for a fast multi-channel convolution function for ARM Cortex-M. This is
+    done by calling a sub-function depending on the input data type: since v7e-m has no quad
+    multiply-accumulate instruction, the int8 and int16 cases work differently."""
+    if in_dtype == "int8":
+        return _quad_int8_channel_convolve_impl(*args)
+    if in_dtype == "int16":
+        return _dual_int16_channel_convolve_impl(*args)
+
+    raise NotImplementedError(f"No Cortex-M {in_dtype} depthwise_conv2d implementation exists!")
+
+
+def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
+    return textwrap.dedent(
+        (
+            f"""
+        #include
+        #include
+
+        // __SXTB16(__ROR(X, Y)) is combined into one assembly instruction
+
+        #define TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP( \
+            arranged_kernel, \
+            tensor_c3210, \
+            sum_c0, sum_c1, sum_c2, sum_c3) {{ \
+          \
+          uint32_t kernel_c3210 = *arranged_kernel++; \
+          \
+          uint32_t tensor_c20 = __SXTB16(tensor_c3210); \
+          uint32_t kernel_c20 = __SXTB16(kernel_c3210); \
+          sum_c0 = __builtin_arm_smlabb(tensor_c20, kernel_c20, sum_c0); \
+          sum_c2 = __builtin_arm_smlatt(tensor_c20, kernel_c20, sum_c2); \
+          \
+          uint32_t tensor_c31 = __SXTB16(__ROR(tensor_c3210, 8)); \
+          uint32_t kernel_c31 = __SXTB16(__ROR(kernel_c3210, 8)); \
+          sum_c1 = __builtin_arm_smlabb(tensor_c31, kernel_c31, sum_c1); \
+          sum_c3 = __builtin_arm_smlatt(tensor_c31, kernel_c31, sum_c3); \
+        }}
+
+        /* We do four channels at once to get this speed boost. */
+        #ifdef __cplusplus
+        extern "C"
+        #endif
+        int32_t {_get_func_name("int8", tensor_w, channels, kernel_h, kernel_w, suffix)}(
+            uint32_t *out,
+            uint32_t *tensor,
+            uint32_t *kernel) {{
+
+          uint32_t sum_c0 = 0;
+          uint32_t sum_c1 = 0;
+          uint32_t sum_c2 = 0;
+          uint32_t sum_c3 = 0;
+
+          #pragma GCC unroll 3
+          for (int i = 0; i < {kernel_h}; i++) {{
+            #pragma GCC unroll 3
+            for (int j = 0; j < {kernel_w}; j++) {{
+              TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP(
+                kernel,
+                *(tensor + j * {channels // 4} + i * {tensor_w * (channels // 4)}),
+                sum_c0, sum_c1, sum_c2, sum_c3)
+            }}
+          }}
+
+          out[0] = sum_c0;
+          out[1] = sum_c1;
+          out[2] = sum_c2;
+          out[3] = sum_c3;
+          return 0;
+        }}
+
+        #undef TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP
+        """
+        )
+    )
+
+
+def _dual_int16_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
+    return textwrap.dedent(
+        (
+            f"""
+        #include
+
+        /* We do two channels at once to get this speed boost. */
+        #ifdef __cplusplus
+        extern "C"
+        #endif
+        int32_t {_get_func_name("int16", tensor_w, channels, kernel_h, kernel_w, suffix)}(
+            uint32_t *out,
+            uint32_t *tensor,
+            uint32_t *kernel) {{
+
+          uint32_t sum_c0 = 0;
+          uint32_t sum_c1 = 0;
+
+          #pragma GCC unroll 3
+          for (int i = 0; i < {kernel_h}; i++) {{
+            #pragma GCC unroll 3
+            for (int j = 0; j < {kernel_w}; j++) {{
+              uint32_t tensor_c10 = *(tensor + j * {channels // 2}
+                + i * {tensor_w * (channels // 2)});
+              uint32_t kernel_c10 = *kernel++;
+              sum_c0 = __builtin_arm_smlabb(tensor_c10, kernel_c10, sum_c0);
+              sum_c1 = __builtin_arm_smlatt(tensor_c10, kernel_c10, sum_c1);
+            }}
+          }}
+
+          out[0] = sum_c0;
+          out[1] = sum_c1;
+          return 0;
+        }}
+
+        #undef TVMGEN_DUAL_INT16_CHANNEL_REARRANGE_SUM
+        """
+        )
+    )
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py
deleted file mode 100644
index 960ef8fadc0e..000000000000
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions
-(v7e-m). It takes as inputs an int8 HWC data tensor and an int8 CHWc kernel. This intrinsic "lays"
-the kernel on top of the data tensors starting from a given pointer, performs signed sixteen-bit
-multiplies on each pair of values, and sums all the products in an int32 accumlator. 
This process is -repeated four times giving four int32 outputs - one per channel.""" - -import textwrap - -from tvm import te, tir - - -def intrin_quad_channel_convolve(tensor_w, channels, kernel_h, kernel_w, suffix): - """Defines a v7e-m DSP-accelerated four-channel convolution.""" - data_slice = te.placeholder((kernel_h, kernel_w, 4), name="a", dtype="int8") - - if kernel_h * kernel_w % 2 == 1: - kernel_length = kernel_h * kernel_w + 1 - else: - kernel_length = kernel_h * kernel_w - kernel_slice = te.placeholder((kernel_length, 4), name="b", dtype="int8") - - kh_i = te.reduce_axis((0, kernel_h), name="kh_i") - kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - - output_slice = te.compute( - (4,), - lambda k: te.sum( - data_slice[kh_i, kw_i, k].astype("int32") - * kernel_slice[ - (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), - (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ].astype("int32"), - axis=(kh_i, kw_i), - ), - name="c", - ) - - data_buf = tir.decl_buffer( - data_slice.shape, - data_slice.dtype, - name="data", - offset_factor=1, - strides=[tensor_w * channels, channels, 1], - ) - kernel_buf = tir.decl_buffer( - kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[4, 1] - ) - output_buf = tir.decl_buffer( - output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] - ) - - def intrin_func(ins, outs): - builder = tir.ir_builder.create() - builder.emit( - tir.call_extern( - "int32", - f"kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}", - outs[0].access_ptr("w"), - ins[0].access_ptr("r"), - ins[1].access_ptr("r"), - ) - ) - return builder.get() - - return te.decl_tensor_intrin( - output_slice.op, - intrin_func, - binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, - ) - - -def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): - """Emits C code for quad_channel_convolve. 
Note that while intrin_quad_channel_convolve supports - any kernel size, this function only supports 3x3 kernels (this could be fixed with work).""" - assert kernel_h == kernel_w == 3 - - return textwrap.dedent( - ( - f""" - #include - #include - - // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction - - #define TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( \ - arranged_kernel, \ - tensor_v0_c3210, tensor_v1_c3210, \ - sum0, sum1, sum2, sum3) {{ \ - \ - uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ - uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ - uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ - uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ - \ - uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ - uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ - \ - uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ - uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ - uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ - uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ - \ - uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ - uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ - uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ - uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ - \ - sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ - sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ - sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ - sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ - }} - - /* We do four channels at once to get this speed boost. */ - #ifdef __cplusplus - extern "C" - #endif - int32_t kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}( - uint32_t *out, - uint32_t *tensor, - uint32_t *packed_kernel) {{ - - uint32_t sum_c0 = 0; - uint32_t sum_c1 = 0; - uint32_t sum_c2 = 0; - uint32_t sum_c3 = 0; - - TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( - packed_kernel, - *tensor, - *(tensor + {channels // 4}), - sum_c0, sum_c1, sum_c2, sum_c3) - TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( - packed_kernel, - *(tensor + {(2) * channels // 4}), - *(tensor + {tensor_w * (channels // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( - packed_kernel, - *(tensor + {(tensor_w + 1) * (channels // 4)}), - *(tensor + {(tensor_w + 2) * (channels // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( - packed_kernel, - *(tensor + {(2 * tensor_w) * (channels // 4)}), - *(tensor + {(2 * tensor_w + 1) * (channels // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( - packed_kernel, - *(tensor + {(2 * tensor_w + 2) * (channels // 4)}), - 0, - sum_c0, sum_c1, sum_c2, sum_c3) - - out[0] = sum_c0; - out[1] = sum_c1; - out[2] = sum_c2; - out[3] = sum_c3; - return 0; - }} - - #undef TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP - """ - ) - ) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 18c5082f2a0c..15ea2a31d864 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -150,24 +150,37 @@ class TestDepthwiseConv2d_NHWC_HWOI(BasicDepthwiseConv2dTests): class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): """This test is for depthwise_conv2d_nhwc_dsp.arm_cpu 
schedule.""" - data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( - # The LLVM implementation doesn't support "SAME" and "VALID" padding, - # so padding must be explicitly specified. - # Depthwise_conv2d parameters from MobileNetV1 0.25x - ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), - ((1, 48, 48, 16), (3, 3), 16, (2, 2), (1, 1, 0, 0), 1), - ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1, 1), - ((1, 24, 24, 32), (3, 3), 32, (2, 2), (1, 1, 0, 0), 1), - ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1, 1), - ((1, 12, 12, 64), (3, 3), 64, (2, 2), (1, 1, 0, 0), 1), - ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1, 1), - ((1, 6, 6, 128), (3, 3), 128, (2, 2), (1, 1, 0, 0), 1), - ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1, 1), + # Tests that work with both int8 and int16 data types. Tuple elements are: + # data_shape, kernel_size, num_filter, strides, padding + dtype_parameterized_tests = [ + # Depthwise_conv2d parameters from MobileNetV1 0.25x. The LLVM implementation doesn't + # support "SAME" and "VALID" padding, so padding must be explicitly specified. + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1), + ((1, 48, 48, 16), (3, 3), 16, (2, 2), (1, 1, 0, 0)), + ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1), + ((1, 24, 24, 32), (3, 3), 32, (2, 2), (1, 1, 0, 0)), + ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1), + ((1, 12, 12, 64), (3, 3), 64, (2, 2), (1, 1, 0, 0)), + ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1), + ((1, 6, 6, 128), (3, 3), 128, (2, 2), (1, 1, 0, 0)), + ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1), # Asymmetric height and width - ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1, 1), + ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1), + # Larger kernel + ((1, 24, 24, 8), (5, 5), 8, (1, 1), 1), + # Asymmetric kernel + ((1, 24, 24, 8), (3, 5), 8, (1, 1), 1), + ] + + data_shape, kernel_size, num_filter, strides, padding, dtype = tvm.testing.parameters( + # Make a copy of each parameterized test for int8 and one for int16 + *map(lambda t: t + ("int8",), dtype_parameterized_tests), + *map(lambda t: t + ("int16",), dtype_parameterized_tests), + # Test the int16 implementation with channel numbers not divisible by four + ((1, 48, 48, 6), (3, 3), 6, (1, 1), 1, "int16"), ) + dilation = tvm.testing.parameter(1) data_layout = tvm.testing.parameter("NHWC") - dtype = tvm.testing.parameter("int8") kernel_layout = tvm.testing.parameter("HWOI") schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu")
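As a final illustration, the halfword arithmetic that the int16 micro kernel relies on can be sanity-checked off-device. In the sketch below, `smlabb` and `smlatt` are pure-Python stand-ins for the ARM builtins used in the generated C, assuming little-endian halfword packing as on Cortex-M:

```python
import struct


def smlabb(x: int, y: int, acc: int) -> int:
    """Emulates __builtin_arm_smlabb: multiply-accumulate bottom halfwords."""
    xb = struct.unpack("<hh", struct.pack("<I", x))[0]
    yb = struct.unpack("<hh", struct.pack("<I", y))[0]
    return acc + xb * yb


def smlatt(x: int, y: int, acc: int) -> int:
    """Emulates __builtin_arm_smlatt: multiply-accumulate top halfwords."""
    xt = struct.unpack("<hh", struct.pack("<I", x))[1]
    yt = struct.unpack("<hh", struct.pack("<I", y))[1]
    return acc + xt * yt


# One 32-bit word holds two int16 channels: channel 0 in the bottom halfword,
# channel 1 in the top halfword (little-endian layout).
tensor_c10 = struct.unpack("<I", struct.pack("<hh", 3, -2))[0]
kernel_c10 = struct.unpack("<I", struct.pack("<hh", 5, 7))[0]

assert smlabb(tensor_c10, kernel_c10, 0) == 3 * 5  # channel 0
assert smlatt(tensor_c10, kernel_c10, 0) == -2 * 7  # channel 1
```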