diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index ba28b6c7c31c..2d9ef99ba8a6 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -235,6 +235,28 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.arm_cpu", ) + + # Optimized special case depthwiseConv2D operation. Requires a 3x3 kernel, a + # NHWC layout, a HWOI kernel layout (which we rearrange), no dilation, int8 inputs, + # int32 output, the same number of input and output channels, and for that channel + # count to be divisible by 4. Additional work could remove these restrictions. + + elif ( + target.features.has_dsp + and kernel.shape[0] == kernel.shape[1] == 3 + and dilation_w == dilation_h == 1 + and kernel.shape[3] == 1 # channel_multiplier == 1 + and data.dtype == "int8" + and out_type.dtype == "int32" + and data.shape[3] % 4 == 0 + and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0) + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc_dsp), + name="depthwise_conv2d_nhwc_dsp.arm_cpu", + ) + else: logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index c21480724ae4..333db3d5e014 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -28,6 +28,11 @@ from .tensor_intrin import smlal_int16_int32 from .arm_utils import is_aarch64_arm +from .mprofile.dsp.depthwise_conv2d import ( + depthwise_conv2d_nhwc_dsp_compute, + depthwise_conv2d_nhwc_dsp_schedule, +) + @autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu") def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype): @@ -699,3 +704,17 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last): s[kernel_vec].parallel(co) return s + + +@autotvm.register_topi_compute("depthwise_conv2d_nhwc_dsp.arm_cpu") +def depthwise_conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d_nhwc with v7e-m DSP instructions.""" + return depthwise_conv2d_nhwc_dsp_compute( + cfg, data, kernel, strides, padding, dilation, out_dtype + ) + + +@autotvm.register_topi_schedule("depthwise_conv2d_nhwc_dsp.arm_cpu") +def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): + """Create schedule for conv2d_nhwc_dsp""" + return depthwise_conv2d_nhwc_dsp_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py new file mode 100644 index 000000000000..162bf65a21f9 --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ARM Cortex-M DSP schedule for depthwise_conv2d""" + +import random +import string + +from tvm import te +from tvm.topi.utils import traverse_inline, get_const_tuple +from tvm.topi.nn.pad import pad +from tvm import tir + +from .micro_kernel.quad_channel_convolve import ( + intrin_quad_channel_convolve, + quad_channel_convolve_impl, +) + +# For depthwise_conv2d, kernels are normally given in HWOI format, +# which when input_channels = output channels, we will call HWC. +# This is bad, as we want "related" parts of the kernel to be next +# to each other, so we can use __SMLAD later. +# +# Consider a 3x3 int8 kernel with no bias vector, with eight +# channels. Let us specify entries in the kernel as H_W_C - i.e. +# where 0_2_3 represents the rightmost position in the first row +# of channel 4/8 (4 because of zero indexing). Each [ ] represents +# a 32-bit integer. We currently store the kernel as: +# +# 0 ................................31 +# [ 0_0_0 || 0_0_1 || 0_0_2 || 0_0_3 ] [ 0_0_4 || 0_0_5 || 0_0_6 || 0_0_7 ] +# [ 0_1_0 || 0_1_1 || 0_1_2 || 0_1_3 ] [ 0_1_4 || 0_1_5 || 0_1_6 || 0_1_7 ] +# [ 0_2_0 || 0_2_1 || 0_2_2 || 0_2_3 ] [ 0_2_4 || 0_2_5 || 0_2_6 || 0_2_7 ] +# [ 1_0_0 || 1_0_1 || 1_0_2 || 1_0_3 ] [ 1_0_4 || 1_0_5 || 1_0_6 || 1_0_7 ] +# [ 1_1_0 || 1_1_1 || 1_1_2 || 1_1_3 ] [ 1_1_4 || 1_1_5 || 1_1_6 || 1_1_7 ] +# [ 1_2_0 || 1_2_1 || 1_2_2 || 1_2_3 ] [ 1_2_4 || 1_2_5 || 1_2_6 || 1_2_7 ] +# [ 2_0_0 || 2_0_1 || 2_0_2 || 2_0_3 ] [ 2_0_4 || 2_0_5 || 2_0_6 || 2_0_7 ] +# [ 2_1_0 || 2_1_1 || 2_1_2 || 2_1_3 ] [ 2_1_4 || 2_1_5 || 2_1_6 || 2_1_7 ] +# [ 2_2_0 || 2_2_1 || 2_2_2 || 2_2_3 ] [ 2_2_4 || 2_2_5 || 2_2_6 || 2_2_7 ] +# +# Let 0x00 be all zeros. We rearrange into: +# +# 0 ................................31 +# [ 0_0_0 || 0_0_1 || 0_1_0 || 0_1_1 ] [ 0_0_2 || 0_0_3 || 0_1_2 || 0_1_3 ] +# [ 0_2_0 || 0_2_1 || 1_0_0 || 1_0_1 ] [ 0_2_2 || 0_2_3 || 1_0_2 || 1_0_3 ] +# [ 1_1_0 || 1_1_1 || 1_2_0 || 1_2_1 ] [ 1_1_2 || 1_1_3 || 1_2_2 || 1_2_3 ] +# [ 2_0_0 || 2_0_1 || 2_1_0 || 2_1_1 ] [ 2_0_2 || 2_0_3 || 2_1_2 || 2_1_3 ] +# [ 2_2_0 || 2_2_1 || 0x000 || 0x000 ] [ 2_2_2 || 2_2_3 || 0x000 || 0x000 ] +# [ 0_0_4 || 0_0_5 || 0_1_4 || 0_1_5 ] [ 0_0_6 || 0_0_7 || 0_1_6 || 0_1_7 ] +# [ 0_2_4 || 0_2_5 || 1_0_4 || 1_0_5 ] [ 0_2_6 || 0_2_7 || 1_0_6 || 1_0_7 ] +# [ 1_1_4 || 1_1_5 || 1_2_4 || 1_2_5 ] [ 1_1_6 || 1_1_7 || 1_2_6 || 1_2_7 ] +# [ 2_0_4 || 2_0_5 || 2_1_4 || 2_1_5 ] [ 2_0_6 || 2_0_7 || 2_1_6 || 2_1_7 ] +# [ 2_2_4 || 2_2_5 || 0x000 || 0x000 ] [ 2_2_6 || 2_2_7 || 0x000 || 0x000 ] +# +# This saves us six operations comapred to the original ordering, as we +# do not need halfword packing instructions. +# +# This kernel re-arranging function will be used for 3x3 kernels (as that +# is all this DSP implementation currently supports) but would work with +# any M*N kernel such that M*N is odd. + + +def _rearrange_kernel(kernel): + # Kernel must be HWC format. + kernel_h, kernel_w, channels, _ = get_const_tuple(kernel.shape) + assert channels % 4 == 0 + + # This restriction could be removed by only using tir.if_then_else to add padding + # zeros if (kernel_w * kernel_h) % 2 == 1, and filling completely otherwise. + assert (kernel_w * kernel_h) % 2 == 1 + + def fcompute(c_o, pos, c_i): + channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o) + true_pos_index = 2 * (pos // 2) + (c_i // 2) + + return tir.if_then_else( + true_pos_index < (kernel_h * kernel_w), + kernel[true_pos_index // kernel_w, true_pos_index % kernel_w, channel, 0], + tir.const(0, "int8"), + ) + + return te.compute( + (channels // 4, kernel_h * kernel_w + 1, 4), + fcompute, + name="packed_kernel", + ) + + +def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute function for v7e-m DSP instructions of DepthwiseConv2D. Has a lot of requirements + for use - if not all apply, the fallback implementation will be used instead.""" + assert isinstance(strides, int) or len(strides) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(strides, int): + stride_h = stride_w = strides + else: + stride_h, stride_w = strides + + # We do not support dilation currently. It would be possible, but it would require + # modifying the way the kernel is packed. Gnarly. + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + assert dilation_h == dilation_w == 1 + + batch_size, height, width, channels = data.shape + kernel_h, kernel_w, _, _ = kernel.shape + + # We require that the number of channels be divisible by 4. This restriction could + # be removed with strip mining if people cared. + assert channels % 4 == 0 + + # We don't support different numbers of input and output channels. + assert channels == kernel.shape[2] + assert kernel.shape[3] == 1 + + # We take in int8 as our dtype, but we spit out int32. This is because we cannot + # round until we compute activations. + assert out_dtype == "int32" + + # This can pretty easily be generalized in the future. Likely worth doing, and this + # function was written to make doing so easy. Should only require adding more calls + # to QUAD_CHANNEL_REARRANGE_SUM. + assert kernel_w == kernel_h == 3 + + # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which + # is slow and bad. We should really implement a strip mining + # routine to avoid this, but TVM has terrible support for that. + + if padding == "SAME": + # This assumption makes the logic easier. Could be removed with work. + assert height % stride_h == width % stride_w == 0 + + output_h = height // stride_h + output_w = width // stride_w + + # This padding behavior is consistent with other TVM depthwise_conv2d schedules. However it + # differs from the TensorFlow, which only pads the bottom right if stride > 1. This probably + # brings down accuracy slightly for models imported from TFLite. + pad_down = 1 if stride_h == 1 else 0 + pad_right = 1 if stride_w == 1 else 0 + + padded_data = pad( + data, + [0, kernel_h // 2, kernel_w // 2, 0], + [0, pad_down, pad_right, 0], + name="padded_data", + ) + + elif padding == "VALID": + assert height > kernel_h and width > kernel_w + output_h = (height - kernel_h) // stride_h + 1 + output_w = (width - kernel_w) // stride_w + 1 + padded_data = data + + elif isinstance(padding, tuple): + if len(padding) == 2: + pad_up, pad_down = padding[0] + pad_left, pad_right = padding[1] + else: + pad_up, pad_left, pad_down, pad_right = padding + + output_h = (height - kernel_h + pad_up + pad_down) // stride_h + 1 + output_w = (width - kernel_w + pad_left + pad_right) // stride_w + 1 + padded_data = pad( + data, + [0, pad_up, pad_left, 0], + [0, pad_down, pad_right, 0], + name="padded_data", + ) + + else: + raise RuntimeError() + _, padded_h, padded_w, _ = padded_data.shape + + packed_kernel = _rearrange_kernel(kernel) + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + return te.compute( + (batch_size, output_h, output_w, channels), + lambda h, i, j, k: te.sum( + padded_data[h, (i * stride_h) + kh_i, (j * stride_w) + kw_i, k].astype("int32") + * packed_kernel[ + k // 4, + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ].astype("int32"), + axis=(kh_i, kw_i), + ), + name="depthwise_conv2d", + tag=f"depthwise_conv2d_nhwc_{padded_h}_{padded_w}_dsp", + ) + + +def depthwise_conv2d_nhwc_dsp_schedule(_cfg, outs): + + """Schedule function for v7e-m DSP instructions of conv2d.""" + schedule = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "depthwise_conv2d_nhwc" not in op.tag: + return + + # extract tensors + output = op.output(0) + padded_data = output.op.input_tensors[0] + packed_kernel = output.op.input_tensors[1] + kernel = packed_kernel.op.input_tensors[0] + + _, _, padded_w, channels = padded_data.shape + kernel_h, kernel_w, _, _ = kernel.shape + suffix = "".join(random.choices(string.ascii_uppercase, k=8)) + + b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis + ky_ax, kx_ax = schedule[output].op.reduce_axis + c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4) + schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i) + + quad_channel_convolve = intrin_quad_channel_convolve( + padded_w, channels, kernel_h, kernel_w, suffix + ) + schedule[output].tensorize(ky_ax, quad_channel_convolve) + schedule[output].pragma( + b_ax, + "import_c", + quad_channel_convolve_impl(padded_w, channels, kernel_h, kernel_w, suffix), + ) + + traverse_inline(schedule, outs[-1].op, _callback) + return schedule diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py new file mode 100644 index 000000000000..960ef8fadc0e --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions +(v7e-m). It takes as inputs an int8 HWC data tensor and an int8 CHWc kernel. This intrinsic "lays" +the kernel on top of the data tensors starting from a given pointer, performs signed sixteen-bit +multiplies on each pair of values, and sums all the products in an int32 accumlator. This process is +repeated four times giving four int32 outputs - one per channel.""" + +import textwrap + +from tvm import te, tir + + +def intrin_quad_channel_convolve(tensor_w, channels, kernel_h, kernel_w, suffix): + """Defines a v7e-m DSP-accelerated four-channel convolution.""" + data_slice = te.placeholder((kernel_h, kernel_w, 4), name="a", dtype="int8") + + if kernel_h * kernel_w % 2 == 1: + kernel_length = kernel_h * kernel_w + 1 + else: + kernel_length = kernel_h * kernel_w + kernel_slice = te.placeholder((kernel_length, 4), name="b", dtype="int8") + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + output_slice = te.compute( + (4,), + lambda k: te.sum( + data_slice[kh_i, kw_i, k].astype("int32") + * kernel_slice[ + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ].astype("int32"), + axis=(kh_i, kw_i), + ), + name="c", + ) + + data_buf = tir.decl_buffer( + data_slice.shape, + data_slice.dtype, + name="data", + offset_factor=1, + strides=[tensor_w * channels, channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[4, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + f"kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}", + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + output_slice.op, + intrin_func, + binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + ) + + +def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): + """Emits C code for quad_channel_convolve. Note that while intrin_quad_channel_convolve supports + any kernel size, this function only supports 3x3 kernels (this could be fixed with work).""" + assert kernel_h == kernel_w == 3 + + return textwrap.dedent( + ( + f""" + #include + #include + + // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction + + #define TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( \ + arranged_kernel, \ + tensor_v0_c3210, tensor_v1_c3210, \ + sum0, sum1, sum2, sum3) {{ \ + \ + uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ + uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ + uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ + uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ + \ + uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ + uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ + \ + uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ + uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ + uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ + uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ + \ + uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ + uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ + uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ + uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ + \ + sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ + sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ + sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ + sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ + }} + + /* We do four channels at once to get this speed boost. */ + #ifdef __cplusplus + extern "C" + #endif + int32_t kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}( + uint32_t *out, + uint32_t *tensor, + uint32_t *packed_kernel) {{ + + uint32_t sum_c0 = 0; + uint32_t sum_c1 = 0; + uint32_t sum_c2 = 0; + uint32_t sum_c3 = 0; + + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *tensor, + *(tensor + {channels // 4}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2) * channels // 4}), + *(tensor + {tensor_w * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(tensor_w + 1) * (channels // 4)}), + *(tensor + {(tensor_w + 2) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w) * (channels // 4)}), + *(tensor + {(2 * tensor_w + 1) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w + 2) * (channels // 4)}), + 0, + sum_c0, sum_c1, sum_c2, sum_c3) + + out[0] = sum_c0; + out[1] = sum_c1; + out[2] = sum_c2; + out[3] = sum_c3; + return 0; + }} + + #undef TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP + """ + ) + ) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index ee0d51c321f7..18c5082f2a0c 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -147,5 +147,30 @@ class TestDepthwiseConv2d_NHWC_HWOI(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc.generic") +class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): + """This test is for depthwise_conv2d_nhwc_dsp.arm_cpu schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + # The LLVM implementation doesn't support "SAME" and "VALID" padding, + # so padding must be explicitly specified. + # Depthwise_conv2d parameters from MobileNetV1 0.25x + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), + ((1, 48, 48, 16), (3, 3), 16, (2, 2), (1, 1, 0, 0), 1), + ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (2, 2), (1, 1, 0, 0), 1), + ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (2, 2), (1, 1, 0, 0), 1), + ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (2, 2), (1, 1, 0, 0), 1), + ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1, 1), + # Asymmetric height and width + ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1, 1), + ) + data_layout = tvm.testing.parameter("NHWC") + dtype = tvm.testing.parameter("int8") + kernel_layout = tvm.testing.parameter("HWOI") + schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main()