diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 947beb396ae2..e56e7ba12e94 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Definition of ARM CPU operator strategy.""" +from functools import reduce import logging # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import @@ -71,6 +72,32 @@ def schedule_pool_arm_cpu(attrs, outs, target): return topi.generic.schedule_pool(outs, layout) +def _get_padding_width(padding): + assert isinstance(padding, tuple) + if len(padding) == 2: + _, (pad_left, pad_right) = padding + else: + _, pad_left, _, pad_right = padding + return pad_left + pad_right + + +def _is_simd_aligned(dtype, dimensions, padding=None): + if padding: + assert len(dimensions) == len(padding) + padded_dims = (sum(x) for x in zip(dimensions, padding)) + else: + padded_dims = dimensions + + # Multiply all elements of padded_dims together. We can't use math.prod, as it + # does not exist in Python 3.7. 
+ size = reduce(lambda x, y: x * y, padded_dims) + return ( + (dtype == "int8" and size % 4 == 0) + or (dtype == "int16" and size % 2 == 0) + or (dtype == "int32") + ) + + @conv2d_strategy.register("arm_cpu") def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" @@ -159,7 +186,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_hwcn.generic", ) elif layout == "NHWC": - if target.features.has_dsp and kernel_layout == "HWOI": + data_width_padding = _get_padding_width(padding) + if ( + target.features.has_dsp + and dilation_w == dilation_h == 1 + and kernel_layout == "OHWI" + # Check SIMD alignment + and _is_simd_aligned(data.dtype, data.shape[2:], padding=(data_width_padding, 0)) + and _is_simd_aligned(kernel.dtype, kernel.shape[2:]) + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_ohwi_dsp), + name="conv2d_nhwc_ohwi_dsp.arm_cpu", + ) + elif target.features.has_dsp and kernel_layout == "HWOI": strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp), @@ -199,13 +240,25 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): if layout == "NCHW": assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout) - # ARM conv2d depthwise schedule if kernel_layout == "OIHW": - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.arm_cpu", - ) + data_width_padding = _get_padding_width(padding) + if ( + target.features.has_dsp + and dilation_w == dilation_h == 1 + and _is_simd_aligned(data.dtype, data.shape[3:], padding=(data_width_padding,)) + and _is_simd_aligned(kernel.dtype, kernel.shape[3:]) + ): + 
strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_oihw_dsp), + name="depthwise_conv2d_nchw_oihw_dsp.arm_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw), + name="depthwise_conv2d_nchw.arm_cpu", + ) # TODO: # This schedule has incorrect result on some hardware platforms (like NV Jetson TX2) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index ab489161a8fa..bb29de8fa27b 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -37,6 +37,10 @@ conv2d_nhwc_dsp_compute, conv2d_nhwc_dsp_schedule, ) +from .mprofile.dsp.tensordot_conv2ds import ( + conv2d_nhwc_ohwi_dsp_compute, + tensordot_conv2ds_schedule, +) @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu") @@ -518,3 +522,15 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): def schedule_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return conv2d_nhwc_dsp_schedule(cfg, outs) + + +@autotvm.register_topi_compute("conv2d_nhwc_ohwi_dsp.arm_cpu") +def conv2d_nhwc_ohwi_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d_nhwc_ohwi with v7e-m DSP instructions and the tensordot kernel.""" + return conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_nhwc_ohwi_dsp.arm_cpu") +def schedule_conv2d_nhwc_ohwi_dsp(cfg, outs): + """Create schedule for conv2d_nhwc_ohwi.""" + return tensordot_conv2ds_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 333db3d5e014..58cd11e8cc09 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ 
-27,11 +27,14 @@ from ..nn.utils import get_pad_tuple from .tensor_intrin import smlal_int16_int32 from .arm_utils import is_aarch64_arm - from .mprofile.dsp.depthwise_conv2d import ( depthwise_conv2d_nhwc_dsp_compute, depthwise_conv2d_nhwc_dsp_schedule, ) +from .mprofile.dsp.tensordot_conv2ds import ( + depthwise_conv2d_nchw_oihw_dsp_compute, + tensordot_conv2ds_schedule, +) @autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu") @@ -718,3 +721,17 @@ def depthwise_conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return depthwise_conv2d_nhwc_dsp_schedule(cfg, outs) + + +@autotvm.register_topi_compute("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") +def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute depthwise_conv2d_nchw_oihw with v7e-m DSP instructions and the tensordot kernel.""" + return depthwise_conv2d_nchw_oihw_dsp_compute( + cfg, data, kernel, strides, padding, dilation, out_dtype + ) + + +@autotvm.register_topi_schedule("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") +def schedule_depthwise_conv2d_nchw_oihw_dsp(cfg, outs): + """Create schedule for depthwise_conv2d_nchw_oihw.""" + return tensordot_conv2ds_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py new file mode 100644 index 000000000000..0fdffc06cf4f --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators +including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts +are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When +this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW.""" + +import textwrap + +from tvm import te, tir + +from .common import num_simd_lanes_per_word + + +def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix): + """Gets the C function name of the tensordot function.""" + return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}" + + +def make_intrin_tensordot(slices, strides, tensordot_params): + """Helper function for constructing tensordot intrinsic. 
We can't construct the whole thing here + (as multiple schedules use tensordot and each must build the intrinsic differently) but we can + build part here to simplify the code.""" + + # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params + data, kernel, output = slices + data_strides, kernel_strides = strides + + data_buf = tir.decl_buffer( + data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides + ) + kernel_buf = tir.decl_buffer( + kernel.shape, + kernel.dtype, + name="kernel", + offset_factor=1, + strides=kernel_strides, + ) + output_buf = tir.decl_buffer( + output.shape, output.dtype, name="output", offset_factor=1, strides=[1] + ) + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + _get_func_name(*tensordot_params), + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + output.op, + intrin_func, + binds={data: data_buf, kernel: kernel_buf, output: output_buf}, + ) + + +def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str: + """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has + a `jump` argument that advances the pointer of one tensor by that many words after each row. The + `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned + memory access is slow on the Cortex-M series. Depending on the input datatype, the code may + contain DSP instructions for Arm v7e-m. 
See + the below pseudocode for reference: + + tensordot(out_ptr, dat_ptr, ker_ptr) { + sum = 0; + for (i = 0; i < tensor_h; i++) { + for (j = 0; j < tensor_w; j++) { + sum += (*dat_ptr++) * (*ker_ptr++); + } + dat_ptr += jump; + } + *out_ptr = sum; + } + """ + + simd_lanes = num_simd_lanes_per_word(in_dtype) + assert tensor_w % simd_lanes == 0 + assert jump % simd_lanes == 0 + + if in_dtype == "int8": + inner_loop = """ + uint32_t tensor_c20 = __SXTB16(tensor_batch); + uint32_t kernel_c20 = __SXTB16(kernel_batch); + sum = __SMLAD(tensor_c20, kernel_c20, sum); + + uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8)); + uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8)); + sum = __SMLAD(tensor_c31, kernel_c31, sum);""" + + elif in_dtype == "int16": + inner_loop = """ + sum = __SMLAD(tensor_batch, kernel_batch, sum);""" + + elif in_dtype == "int32": + inner_loop = """ + // Compiles to a single MAC instruction + sum += tensor_batch * kernel_batch;""" + + else: + raise ValueError(f"No tensordot implementation exists for dtype '{in_dtype}'!") + + function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix) + return textwrap.dedent( + ( + f""" + #include <stdint.h> + #include <arm_nnsupportfunctions.h> + + #ifdef __cplusplus + extern "C" + #endif + __STATIC_FORCEINLINE int32_t {function_name}( + uint32_t *out, + uint32_t *tensor, + uint32_t *kernel) {{ + + uint32_t sum = 0; + + #pragma GCC unroll {tensor_h} + for (int i = 0; i < {tensor_h}; i++) {{ + #pragma GCC unroll {tensor_w // simd_lanes} + for (int j = 0; j < {tensor_w // simd_lanes}; j++) {{ + uint32_t tensor_batch = *tensor++; + uint32_t kernel_batch = *kernel++; + {inner_loop.strip()} + }} + tensor += {jump // simd_lanes}; + }} + out[0] = sum; + return 0; + }} + """ + ) + ) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py new file mode 100644 index 000000000000..ccd0c8e3ef32 --- /dev/null +++ 
b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -0,0 +1,271 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Implementations of several conv2d variations, all tensorized using tensordot and optimized for +Cortex-M DSP. Currently contains a standard conv2d and depthwise conv2d implementation, but could be +extended to add a grouped conv2d operator. Due to the way we tensorize, this schedule ONLY works +when the data and kernel layouts are NCHWxc and OIHWxi respectively, where x is the number of +input channels divided by the number of groups.""" + +import random +import string +from typing import Union, Tuple + +from tvm import te +from tvm.tir import indexdiv, indexmod +from tvm.topi.utils import traverse_inline +from tvm.topi.nn.pad import pad + +from .micro_kernel.tensordot import ( + make_intrin_tensordot, + tensordot_impl, +) + + +def _unpack_2d_argument(argument: Union[int, Tuple]) -> Tuple: + if isinstance(argument, int): + return (argument, argument) + assert len(argument) == 2 + return argument + + +def _check_no_dilation(dilation: Union[int, Tuple]) -> None: + """Takes a dilation argument as an integer or tuple, and makes sure both dimensions are 1. 
+ Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the + niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we + probably don't need to support it).""" + + dilation_h, dilation_w = _unpack_2d_argument(dilation) + assert dilation_h == dilation_w == 1 + + +def _unpack_padding(padding: Tuple) -> Tuple: + assert isinstance(padding, tuple) + if len(padding) == 2: + (pad_up, pad_down), (pad_left, pad_right) = padding + else: + pad_up, pad_left, pad_down, pad_right = padding + return pad_up, pad_left, pad_down, pad_right + + +def _pad_if_needed(data: te.tensor.Tensor, layout: str, padding: Tuple) -> te.tensor.Tensor: + """Performs padding on a te.tensor.Tensor object if necessary. If padding = (0, 0, 0, 0), the + input tensor is returned unmodified. We only care about tuples here - "VALID" and "SAME" padding + will be converted by the TFLite importer if present.""" + + pad_up, pad_left, pad_down, pad_right = padding + if not any(padding): + return data + + # We want to pad the "H" and "W" columns, and their position depends on the layout + pad_before, pad_after = [0, 0, 0, 0], [0, 0, 0, 0] + pad_before[layout.index("H")] = pad_up + pad_before[layout.index("W")] = pad_left + pad_after[layout.index("H")] = pad_down + pad_after[layout.index("W")] = pad_right + return pad(data, pad_before, pad_after, name="padded_data") + + +def _compute_output_dim( + data_dim: int, kernel_dim: int, pad_before: int, pad_after: int, stride: int +) -> int: + """Computes an output dimension of a convolution, given the data dimension, kernel dimension, + padding, and stride along that axis. Note that when stride > 1, this division will often not + be perfectly even.""" + return (data_dim + pad_before + pad_after - kernel_dim) // stride + 1 + + +def _get_suffix() -> str: + """Returns a random eight-character string to append to C function names. 
Prevents accidental + re-definition of functions if the same operator appears twice in a Relay graph.""" + return "".join(random.choices(string.ascii_uppercase, k=8)) + + +def conv2d_nhwc_ohwi_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype): + """Standard conv2d schedule that can be tensorized using tensordot.""" + + stride_h, stride_w = _unpack_2d_argument(strides) + pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) + _check_no_dilation(dilation) + + batch_size, data_h, data_w, in_channels = data.shape + output_channels, kernel_h, kernel_w, _ = kernel.shape + assert kernel.shape[3] == in_channels + + output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) + output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="rc") + + padded_data = _pad_if_needed(data, "NHWC", (pad_up, pad_left, pad_down, pad_right)) + return te.compute( + (batch_size, output_h, output_w, output_channels), + lambda n, y, x, c: te.sum( + padded_data[n, y * stride_h + kh_i, x * stride_w + kw_i, kc_i].astype(out_dtype) + * kernel[c, kh_i, kw_i, kc_i].astype(out_dtype), + axis=(kh_i, kw_i, kc_i), + ), + name="conv2d", + tag="conv2d_nhwc_ohwi_dsp", + ) + + +def _make_conv2d_tensorization(padded_data, kernel): + _, _, padded_w, in_channels = padded_data.shape + _, kernel_h, kernel_w, _ = kernel.shape + in_dtype = padded_data.dtype + suffix = _get_suffix() + assert in_dtype == kernel.dtype + + data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="kc_i") + + output_slice = te.compute( + (1,), 
+ lambda k: te.sum( + data_slice[kh_i, kw_i, kc_i].astype("int32") + * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), + axis=[kh_i, kw_i, kc_i], + ), + name="c", + ) + + # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the + # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. + height_stride = in_channels * padded_w if kernel_h > 1 else in_channels + jump = (padded_w - kernel_w) * in_channels + tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) + intrin_tensordot = make_intrin_tensordot( + (data_slice, kernel_slice, output_slice), + ([height_stride, in_channels, 1], [kernel_w * in_channels, in_channels, 1]), + tensordot_params, + ) + + tensordot_code = tensordot_impl(*tensordot_params) + return (intrin_tensordot, tensordot_code) + + +def depthwise_conv2d_nchw_oihw_dsp_compute( + _cfg, data, kernel, strides, padding, dilation, out_dtype +): + """Depthwise conv2d schedule that can be tensorized using tensordot.""" + + stride_h, stride_w = _unpack_2d_argument(strides) + pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) + _check_no_dilation(dilation) + + batch_size, in_channels, data_h, data_w = data.shape + _, c_mul, kernel_h, kernel_w = kernel.shape + output_channels = in_channels * c_mul + assert kernel.shape[0] == in_channels + + output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) + output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + padded_data = _pad_if_needed(data, "NCHW", (pad_up, pad_left, pad_down, pad_right)) + return te.compute( + (batch_size, output_channels, output_h, output_w), + lambda n, c, y, x: te.sum( + padded_data[ + n, + indexdiv(c, c_mul), + y * stride_h + kh_i, + x * stride_w + kw_i, + ].astype(out_dtype) + * kernel[indexdiv(c, c_mul), indexmod(c, c_mul), kh_i, 
kw_i].astype(out_dtype), + axis=(kh_i, kw_i), + ), + name="depthwise_conv2d", + tag="depthwise_conv2d_nchw_oihw_dsp", + ) + + +def _make_depthwise_conv2d_tensorization(padded_data, kernel): + _, _, _, padded_w = padded_data.shape + _, _, kernel_h, kernel_w = kernel.shape + + in_dtype = padded_data.dtype + suffix = _get_suffix() + assert in_dtype == kernel.dtype + + data_slice = te.placeholder((kernel_h, kernel_w), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + output_slice = te.compute( + (1,), + lambda k: te.sum( + data_slice[kh_i, kw_i].astype("int32") * kernel_slice[kh_i, kw_i].astype("int32"), + axis=[kh_i, kw_i], + ), + name="c", + ) + + jump = padded_w - kernel_w + tensordot_params = (in_dtype, kernel_h, jump, kernel_w, suffix) + intrin_tensordot = make_intrin_tensordot( + (data_slice, kernel_slice, output_slice), + ([padded_w, 1], [kernel_w, 1]), + tensordot_params, + ) + + tensordot_code = tensordot_impl(*tensordot_params) + return (intrin_tensordot, tensordot_code) + + +def tensordot_conv2ds_schedule(_cfg, outs): + """Schedule function using v7e-m DSP instructions for all the conv2d operators in this file. 
We + use one schedule function for them all, because they are tensorized with the same kernel.""" + + schedule = te.create_schedule([x.op for x in outs]) + + def _callback(operator): + if "conv2d" in operator.tag: + output = operator.output(0) + padded_data = output.op.input_tensors[0] + kernel = output.op.input_tensors[1] + + if operator.tag == "conv2d_nhwc_ohwi_dsp": + b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis + kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis + schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) + intrin, code = _make_conv2d_tensorization(padded_data, kernel) + + elif operator.tag == "depthwise_conv2d_nchw_oihw_dsp": + b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis + kh_ax, kw_ax = schedule[output].op.reduce_axis + schedule[output].reorder(b_ax, co_ax, y_ax, x_ax, kh_ax, kw_ax) + intrin, code = _make_depthwise_conv2d_tensorization(padded_data, kernel) + + else: + raise ValueError(f"Cannot tensorize {operator.tag} with tensordot!") + + schedule[output].tensorize(kh_ax, intrin) + schedule[output].pragma(b_ax, "import_c", code) + + traverse_inline(schedule, outs[-1].op, _callback) + return schedule diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index f1c6fb5aa4f4..f6ca03d32742 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -22,9 +22,8 @@ import numpy as np import tvm -from tvm import te +from tvm import relay, te from tvm.tir import bijective_layout, layout - from . import cpp, tag @@ -432,6 +431,33 @@ def get_shape(src_shape, src_layout, dst_layout): return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices])) +def change_constant_shape(src, src_layout, dst_layout): + """Makes a copy of a Relay constant, reshaping it to a new data layout. + + Parameter + --------- + src : relay.Constant + The Constant to be reformatted. + + src_layout : str + The current layout of the Relay constant. Must be alphabetic (e.g. NHWC + or OIHW, but not NCHW2c). 
+ + dst_layout : str + The desired layout of the new Relay constant. Must be alphabetic (e.g. NHWC + or OIHW, but not NCHW2c). + + Returns + ------- + dst_shape : relay.Constant + A copy of the Constant with the new layout. + """ + assert src_layout.isalpha() and dst_layout.isalpha() + axis_order = [src_layout.index(c) for c in dst_layout] + reshaped = np.transpose(src.data.numpy(), axis_order) + return relay.Constant(tvm.nd.array(reshaped)) + + def within_index(b, e, s, i): """Return a boolean value that indicates if i is within the given index. diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index f5ae6f51dbd7..f5de3b51b67d 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import AOT_CORSTONE300_RUNNER +from tvm.topi.utils import change_constant_shape class BasicConv2dTests: @@ -61,11 +62,7 @@ def test_conv2d( ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) input1 = relay.var("input", relay.TensorType(ishape, dtype)) - - if kernel_layout == "HWOI": - weight1 = relay.const(np.moveaxis(weight_data, 2, -1)) - elif kernel_layout == "HWIO": - weight1 = relay.const(weight_data) + weight1 = change_constant_shape(weight0, "HWIO", kernel_layout) out1 = relay.op.nn.conv2d( input1, @@ -150,5 +147,34 @@ class TestConv2d_HWIO(BasicConv2dTests): schedule_name = tvm.testing.parameter("conv2d_nhwc_spatial_pack.arm_cpu") +class TestConv2d_Tensordot(BasicConv2dTests): + data_shape, kernel_size, num_filter, strides, padding = tvm.testing.parameters( + # Disabled because these kernels are not an integral number of words + # ((1, 32, 32, 1), (3, 3), 12, 1, 0), + # ((1, 32, 10, 3), (3, 3), 16, 1, 0), + # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1)), + 
((4, 16, 16, 8), (5, 5), 8, 2, (0, 3, 3, 0)), + ((4, 16, 16, 8), (5, 5), 16, 2, (0, 3, 3, 0)), + ((4, 16, 16, 8), (5, 5), 8, 2, 0), + ((4, 16, 16, 8), (5, 5), 16, 2, 0), + ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0), + ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0), + ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), + ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2)), + ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2)), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), + ((1, 16, 16, 32), (1, 1), 64, 1, 0), + ) + dilation = tvm.testing.parameter(1) + dtype = tvm.testing.parameter("int8", "int16", "int32") + kernel_layout = tvm.testing.parameter("OHWI") + schedule_name = tvm.testing.parameter("conv2d_nhwc_ohwi_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 15ea2a31d864..36059c798cbb 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -185,5 +185,32 @@ class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") +class TestDepthwiseConv2d_Tensordot(BasicDepthwiseConv2dTests): + data_shape, kernel_size, num_filter, strides, padding, dtype = tvm.testing.parameters( + # Currently, our schedule requires kernel_w be divisible by the number of simd lanes given + # its dtype. This means 3x3 and 5x5 kernels do not work on int16 or int8 for now. If you had + # to, you could hack around this by padding the data and kernel. 
+ ((1, 8, 48, 48), (3, 3), 8, (1, 1), 1, "int32"), + ((1, 16, 48, 48), (3, 3), 16, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 32, 24, 24), (3, 3), 32, (1, 1), 1, "int32"), + ((1, 32, 24, 24), (3, 3), 32, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 64, 12, 12), (3, 3), 64, (1, 1), 1, "int32"), + ((1, 64, 12, 12), (3, 3), 64, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 128, 6, 6), (3, 3), 128, (1, 1), 1, "int32"), + ((1, 128, 6, 6), (3, 3), 128, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 256, 3, 3), (3, 3), 256, (1, 1), 1, "int32"), + ((1, 64, 25, 5), (3, 3), 64, (1, 1), 1, "int32"), + ((1, 8, 24, 24), (5, 5), 8, (1, 1), 1, "int32"), + ((1, 8, 24, 24), (3, 5), 8, (1, 1), 1, "int32"), + # These "evenly divisible" kernels work on smaller dtypes. + ((1, 8, 48, 48), (3, 2), 8, 1, 0, "int16"), + ((1, 8, 48, 48), (4, 4), 8, 1, 0, "int8"), + ) + dilation = tvm.testing.parameter(1) + data_layout = tvm.testing.parameter("NCHW") + kernel_layout = tvm.testing.parameter("OIHW") + schedule_name = tvm.testing.parameter("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main()