From 262aefc76a6cc8a47325fbf5339e40d09258e6dd Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 22 Jul 2021 19:24:22 +0100 Subject: [PATCH 01/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op This commit adds mainly the relay passes and ethosu_conv2d operator to relay. The relay passes include the legalizations and preprocessing of the relay graph in the paritioning. Co-authored-by: Matthew Barrett --- CMakeLists.txt | 1 + cmake/modules/contrib/EthosU.cmake | 21 + python/tvm/relay/backend/contrib/__init__.py | 1 + .../relay/backend/contrib/ethosu/__init__.py | 23 + .../relay/backend/contrib/ethosu/_ffi_api.py | 20 + .../relay/backend/contrib/ethosu/errors.py | 38 ++ .../relay/backend/contrib/ethosu/legalize.py | 200 ++++++++ .../backend/contrib/ethosu/op/__init__.py | 19 + .../backend/contrib/ethosu/op/convolution.py | 202 ++++++++ .../backend/contrib/ethosu/preprocess.py | 27 ++ .../backend/contrib/ethosu/te/__init__.py | 19 + .../backend/contrib/ethosu/te/convolution.py | 199 ++++++++ .../relay/backend/contrib/ethosu/te/dma.py | 299 ++++++++++++ .../tvm/relay/backend/contrib/ethosu/util.py | 198 ++++++++ .../relay/backend/contrib/ethosu/vela_api.py | 314 ++++++++++++ python/tvm/relay/op/contrib/ethosu.py | 251 ++++++++++ .../backend/contrib/ethosu/preprocess.cc | 268 +++++++++++ src/relay/op/contrib/ethosu/common.cc | 65 +++ src/relay/op/contrib/ethosu/common.h | 58 +++ src/relay/op/contrib/ethosu/convolution.cc | 212 ++++++++ .../contrib/test_ethosu/relay_ir_builder.py | 295 ++++++++++++ .../contrib/test_ethosu/test_legalize.py | 343 +++++++++++++ .../contrib/test_ethosu/test_preprocess.py | 343 +++++++++++++ .../contrib/test_ethosu/test_vela_api.py | 453 ++++++++++++++++++ tests/scripts/task_config_build_cpu.sh | 1 + 25 files changed, 3870 insertions(+) create mode 100644 cmake/modules/contrib/EthosU.cmake create mode 100644 python/tvm/relay/backend/contrib/ethosu/__init__.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/_ffi_api.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/errors.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/legalize.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/__init__.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/convolution.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/preprocess.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/__init__.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/convolution.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/dma.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/util.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/vela_api.py create mode 100644 python/tvm/relay/op/contrib/ethosu.py create mode 100644 src/relay/backend/contrib/ethosu/preprocess.cc create mode 100644 src/relay/op/contrib/ethosu/common.cc create mode 100644 src/relay/op/contrib/ethosu/common.h create mode 100644 src/relay/op/contrib/ethosu/convolution.cc create mode 100644 tests/python/contrib/test_ethosu/relay_ir_builder.py create mode 100644 tests/python/contrib/test_ethosu/test_legalize.py create mode 100644 tests/python/contrib/test_ethosu/test_preprocess.py create mode 100644 tests/python/contrib/test_ethosu/test_vela_api.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e02bb615aa41..62598cbdf4a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -403,6 +403,7 @@ include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) include(cmake/modules/contrib/EthosN.cmake) include(cmake/modules/contrib/CMSISNN.cmake) +include(cmake/modules/contrib/EthosU.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) diff --git a/cmake/modules/contrib/EthosU.cmake b/cmake/modules/contrib/EthosU.cmake new file mode 100644 index 000000000000..8f3e09b8179b --- /dev/null +++ b/cmake/modules/contrib/EthosU.cmake @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_ETHOSU) + file(GLOB ETHOSU_RELAY_CONTRIB_SRC src/relay/backend/contrib/ethosu/*) + list(APPEND COMPILER_SRCS ${ETHOSU_RELAY_CONTRIB_SRC}) +endif(USE_ETHOSU) \ No newline at end of file diff --git a/python/tvm/relay/backend/contrib/__init__.py b/python/tvm/relay/backend/contrib/__init__.py index 16b83612d797..9074e40af08b 100644 --- a/python/tvm/relay/backend/contrib/__init__.py +++ b/python/tvm/relay/backend/contrib/__init__.py @@ -16,3 +16,4 @@ # under the License. """external backend codegen modules for relay.""" from . import cmsisnn +from . import ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py new file mode 100644 index 000000000000..3f315a74cbaa --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm(R) Ethos(TM)-U NPU codegen modules for relay.""" +from . import util +from . import legalize +from . import preprocess +from . import errors +from . import vela_api +from .util import partition_for_ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py new file mode 100644 index 000000000000..a0175ba17a56 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for relay transformation passes.""" +import tvm._ffi + +tvm._ffi._init_api("relay.ext.ethosu", __name__) diff --git a/python/tvm/relay/backend/contrib/ethosu/errors.py b/python/tvm/relay/backend/contrib/ethosu/errors.py new file mode 100644 index 000000000000..8625ddc880b7 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/errors.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=super-init-not-called +"""This module is to hold all type of errors associated Arm(R) Ethos(TM)-U NPU Codegen""" + + +class EthosUCodegenError(Exception): + """Base class for all exceptions related to Codegen""" + + def __init__(self, data): + self.message = "EthosUCodegenError:" + data + + def __str__(self): + return self.message + + +class UnsupportedLayout(EthosUCodegenError): + """Raised when unsupported layout is encountered in the codegen""" + + def __init__(self, layout): + super().__init__(f"Unsupported Layout {layout}") + + def __str__(self): + return self.message diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py new file mode 100644 index 000000000000..2e54ffb25fc6 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel +""" A set of passes to legalize some of operations for the NPU""" +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.dataflow_pattern import DFPatternCallback +from tvm.relay.dataflow_pattern import wildcard +from tvm.relay.dataflow_pattern import is_op +from tvm.relay.dataflow_pattern import rewrite +from tvm.relay.backend.contrib.ethosu import op as ethosu_ops +from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout +from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.op.contrib import ethosu as ethosu_patterns + + +class SplitRewriter(DFPatternCallback): + """Convert split operations to bunch of strided_slice operations, + because codegen is going to be based on strided_slices that are + close to in/out boxes of Vela High-Level Command Stream (HLCS). + Moreover, Vela HLCS is a high-level description of the supported + hardware operator. + """ + + def __init__(self): + super().__init__(require_type=True) + self.split_in = wildcard() + self.pattern = is_op("split")(self.split_in) + + @staticmethod + def get_section_begin_coords(split): + """Currently, the split can take an array of indices or an integer + indicating the number of splits. This helper functions unifies + this by making it a array of section begins. + + Parameters + ---------- + split : relay.Expr + The relay expression for split operator + + Returns + ------- + section_begins : list + A list containing integers corresponding to section + begins + """ + indices_or_sections = split.attrs.indices_or_sections + input_shape = split.args[0].checked_type.shape + split_axis = split.attrs.axis + + if isinstance(indices_or_sections, tvm.ir.container.Array): + # 0 is the beginning of the first section. + return [0] + list(indices_or_sections) + split_axis_len = input_shape[split_axis].value + section_length = split_axis_len // indices_or_sections.value + section_begins = list(range(0, split_axis_len, section_length)) + return section_begins + + def callback(self, pre, post, node_map): + splits_types = dict() + split_input = post.args[0] + for idx, field_type in enumerate(post.checked_type.fields): + split = relay.TupleGetItem(post, idx) + splits_types[split] = field_type + + split_begins = list() + split_ends = list() + section_begins_in_split_axis = self.get_section_begin_coords(post) + for split_cord in section_begins_in_split_axis: + # first begin is [0, 0, ... , 0] + begin_shape = [0 for i in range(len(split_input.checked_type.shape))] + begin_shape[post.attrs.axis] = split_cord + split_begins.append(begin_shape) + + end_shape = list(split_input.checked_type.shape) + # Only the split axis coordinate changes + end_shape[post.attrs.axis] = split_cord + split_ends.append(end_shape) + + # Coordinates needs to be shifted left because beginning + # of the next section is the end of the previous + split_ends = split_ends[1:] + # Last section end is the shape of the tensor itself. + split_ends.append(list(split_input.checked_type.shape)) + + strided_slices = list() + for sb, se in zip(split_begins, split_ends): + strided_slices.append(relay.strided_slice(split_input, sb, se)) + + return relay.Tuple(strided_slices) + + +class EthosUConv2DRewriter(DFPatternCallback): + """Convert conv2d related composite functions to ethosu_conv2d operators""" + + def __init__(self): + super().__init__(require_type=True) + self.pattern = (wildcard().has_attr({"Composite": "ethosu.qnn_conv2d"}))(wildcard()) + + def callback(self, pre, post, node_map): + params = ethosu_patterns.QnnConv2DParams(post.op.body) + params.ifm.tensor = post.args[0] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + kernel_size_map = { + "HWIO": params.weights.shape[0:2], + "OHWI": params.weights.shape[1:3], + "HWOI": params.weights.shape[0:2], + } + if str(params.weights.layout) not in kernel_size_map.keys(): + raise UnsupportedLayout(str(params.weights.layout)) + activation_map = {"clip": "CLIP"} + weight_to_ohwi_transform_map = {"HWIO": [3, 0, 1, 2]} + weights_values = params.weights.values + weights_values_ohwi = np.transpose( + weights_values, weight_to_ohwi_transform_map[str(params.weights.layout)] + ) + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + scale_bias = vela_api.pack_biases( + biases=params.biases.tensor.data.asnumpy(), + ifm_scale=params.ifm.q_params.scale_f32, + ifm_dtype=np.dtype(params.ifm.dtype), + weight_scales=params.weights.q_params.scale_f32, + ofm_scale=params.ofm.q_params.scale_f32, + is_activation_tanh_or_sigmoid=activation in ["TANH", "SIGMOID"], + ) + ethosu_conv2d = ethosu_ops.ethosu_conv2d( + ifm=post.args[0], + weight=relay.const(weights_values_ohwi, params.weights.values.dtype), + scale_bias=relay.const(scale_bias, "uint8"), + lut=relay.const([], dtype="int8"), + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + weight_zero_point=int(params.weights.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + kernel_shape=kernel_size_map[str(params.weights.layout)], + ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]], + strides=params.strides, + padding=params.padding, + dilation=params.dilation, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + return ethosu_conv2d + + +class LegalizeEthosU: + """This is the wrapper class to call graph-rewrites to perform graph transformation + in a way such that the operations are replaced with hardware/codegen backend friendly + operations. + """ + + def __call__(self, func): + """The list of relay re-write passes need to be run to legalize + the external function for to be codegen'd. + + Parameters + ---------- + func : relay.function.Function + The external function + + Returns + ------- + func : relay.function.Function + The legalized external function + """ + func = rewrite(SplitRewriter(), func) + func = rewrite(EthosUConv2DRewriter(), func) + return func diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py new file mode 100644 index 000000000000..0406298f23f4 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Relay operators for the Arm(R) Ethos(TM)-U NPU" + +from .convolution import ethosu_conv2d diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py new file mode 100644 index 000000000000..790f0645af3f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operators for convolutions for Arm(R) Ethos(TM)-U NPU""" +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import conv2d_compute + + +def _extract_ethosu_conv2d_params(attrs, args): + """Get the parameters necessary to construct a compute TE + from a ethosu_conv2d Relay call.""" + ifm = args[0] + weight = args[1] + scale_bias = args[2] + lut = args[3] + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + weight_zero_point = attrs.weight_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + strides = attrs.strides + padding = attrs.padding + dilation = attrs.dilation + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + upscale = attrs.upscale + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.conv2d", "FTVMCompute") +def create_ethosu_conv2d_compute(attrs, args, out_type): + """Create an ethosu_conv2d compute op.""" + params = _extract_ethosu_conv2d_params(attrs, args) + op = conv2d_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.conv2d", "FTVMStrategy") +def conv2d_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_conv2d_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_conv2d", + ) + return strategy + + +def ethosu_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + kernel_shape, + ofm_channels, + strides=(1, 1), + padding=(0, 0, 0, 0), + dilation=(1, 1), + activation="NONE", + clip_min=0, + clip_max=0, + upscale="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", +): + """This is a quantized 2D convolution operation as supported by the + the NPU. It accepts either NHWC or NHCWB16 format + for the input data and OHWI format for the kernel weights. + + Reference: https://developer.arm.com/documentation/102420/0200/ + + Note that the per-channel weight scale and bias tensor must be + packed together into a combined tensor of uint80s. This is represented + in TVM by a (channels, 10) tensor of type uint8. For more detail, + refer to the Technical Reference Manual linked above. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + weight : tvm.relay.Expr + The weight tensor. + scale_bias : tvm.relay.Expr + The packed per-channel weight scale and bias tensor. + lut : tvm.relay.Expr + The look-up table values to use if activation = "LUT". + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : int + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + kernel_shape : tuple of int + The 2 dimensional kernel shape as (kernel_height, kernel_width). + ofm_channels : int + The number of OFM channels. + strides : tuple of int, optional + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple of int, optional + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : tuple of int, optional + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP" + clip_max : int, optional, + The maximum clipping value if activation = "CLIP" + upscale : str, optional + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + tvm.relay.Call + A call to the ethosu_conv2d op. + + """ + return _make.ethosu_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + kernel_shape, + ofm_channels, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/preprocess.py b/python/tvm/relay/backend/contrib/ethosu/preprocess.py new file mode 100644 index 000000000000..77035b5b0826 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/preprocess.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel +"""Set of passes to pre-process the IRModule prior to codegen""" +from . import _ffi_api + + +def preprocess_ext_io(): + """This function make the number of inputs going to / outputs coming out to/from + external function set to one. This is achieved via concatenation + of inputs and splitting of outputs in around the call to the external function. + """ + return _ffi_api.PreprocessExternalFuncIO() diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py new file mode 100644 index 000000000000..7ca5de3c160c --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tensor Expressions for the NPU""" + +from .convolution import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py new file mode 100644 index 000000000000..a11974025f2e --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for convolutions for the NPU""" +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def process_stride(stride): + """Process the striding into a common format. + + Parameters + ---------- + stride : Union[int, tuple, list] + The 2D striding. + int -> striding is the same in the height and width axis. + 2D -> striding specified as (stride height, stride width). + + Returns + ------- + int + The stride in the height axis. + int + The stride in the width axis. + + """ + assert isinstance(stride, int) or len(stride) == 2 + if isinstance(stride, int): + return stride, stride + + return stride + + +def process_dilation(dilation): + """Process the dilation into a common format. + + Parameters + ---------- + dilation : Union[int, tuple, list] + The 2D dilation. + int -> dilation is the same in the height and width axis. + 2D -> dilation specified as (dilation height, dilation width). + + Returns + ------- + int + The dilation in the height axis. + int + The dilation in the width axis. + + """ + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(dilation, int): + return dilation, dilation + + return dilation + + +def conv2d_compute( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, +): + """A compute operator representing the capabilities of a 2D convolution for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + weight : te.Tensor + The weight tensor. + scale_bias : te.Tensor + The packed per-channel weight scale and bias tensor. + lut : te.Tensor + The look-up table values to use if activation = "LUT". + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + strides : tuple + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : Union[int, tuple, list] + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + upscale : str + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. + + """ + assert ifm.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + stride_h, stride_w = strides + dilation_h, dilation_w = dilation + ofm_channels, kernel_h, kernel_w, ifm_channels = weight.shape + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute( + ifm, ifm_layout, ifm_zero_point, ifm_scale, weight.shape[3], padding + ) + + # 2D Convolution compute operation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + ofm_height = (dmaed_ifm.shape[1] - dilated_kernel_h) // stride_h + 1 + ofm_width = (dmaed_ifm.shape[2] - dilated_kernel_w) // stride_w + 1 + rc = te.reduce_axis((0, ifm_channels), name="rc") + rh = te.reduce_axis((0, kernel_h), name="ry") + rw = te.reduce_axis((0, kernel_w), name="rx") + + conv2d_attrs = { + "op": "ethosu_conv2d", + "weight_zero_point": weight_zero_point, + "activation": activation, + "upscale": upscale, + "clip_min": clip_min, + "clip_max": clip_max, + "stride_h": stride_h, + "stride_w": stride_w, + "dilation_h": dilation_h, + "dilation_w": dilation_w, + } + + conv = te.compute( + (1, ofm_height, ofm_width, ofm_channels), + lambda nn, hh, ww, cc: te.sum( + dmaed_ifm( + nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, rc + ).astype(ifm.dtype) + * weight[cc, rh, rw, rc].astype(ifm.dtype) + # This is a trick to load 10 elements of the scale_bias at once, not accurate maths + + (scale_bias[cc, 0] * scale_bias[cc, 9]), + axis=[rh, rw, rc], + ), + name="ethosu_conv2d", + attrs=conv2d_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(conv, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py new file mode 100644 index 000000000000..3f8c8d1e7eef --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unnecessary-lambda +"""Tensor Expressions for operations supported by the DMA engine""" +import tvm +from tvm import te +from tvm.topi.utils import equal_const_int + + +def _pad_tensor(tensor, pad_before, pad_after=None): + """Generate a padded tensor. + + Parameters + ---------- + tensor : te.Tensor + The tensor to pad. + pad_before : tuple of int + The 'before' padding on each axis. + pad_after : tuple of int + The 'after' padding on each axis. + Returns + ------- + _pad : callable + The padded tensor. + + """ + pad_after = pad_after or pad_before + dims = len(tensor.shape) + assert len(pad_before) == dims + assert len(pad_after) == dims + + def _pad(*indices): + not_zero = [] + index_tuple = [] + for i in range(dims): + if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0): + index_tuple.append(indices[i]) + else: + index_tuple.append(indices[i] - pad_before[i]) + not_zero.append(indices[i] >= pad_before[i]) + not_zero.append(indices[i] < tensor.shape[i] + pad_before[i]) + if not_zero: + not_zero = tvm.tir.all(*not_zero) + return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8")) + return tensor(*index_tuple) + + return _pad + + +def read_compute(tensor, layout, zero_point, scale): + """A TE compute operator to represent a read. + + Parameters + ---------- + tensor : te.Tensor + The tensor to read. + layout : str + The layout of the tensor, either NHWC or NHCWB16. + zero_point : int + The zero point of the tensor. + scale : float + The scale of the tensor. + + Returns + ------- + te.Tensor + The tensor having been read. + + """ + assert layout in {"NHWC", "NHCWB16"} + read_attrs = { + "op": "ethosu_read", + "layout": layout, + "zero_point": zero_point, + "scale": scale, + } + return te.compute(tensor.shape, lambda *i: tensor(*i), name="ethosu_read", attrs=read_attrs) + + +def write_compute(tensor, layout, zero_point, scale): + """A TE compute operator to represent a write. + + Parameters + ---------- + tensor : te.Tensor + The tensor to write. + layout : str + The layout of the tensor, either NHWC or NHCWB16. + zero_point : int + The zero point of the tensor. + scale : float + The scale of the tensor. + + Returns + ------- + te.Tensor + The tensor having been written. + + """ + assert layout in {"NHWC", "NHCWB16"} + write_attrs = { + "op": "ethosu_write", + "layout": layout, + "zero_point": zero_point, + "scale": scale, + } + return te.compute( + tensor.shape, + lambda *i: tensor(*i), + name="ethosu_write", + attrs=write_attrs, + ) + + +def convert_to_nhwc_compute(tensor, layout, channels): + """Converts a tensor into NHWC layout if it's in NHWCB16 layout. + + Parameters + ---------- + tensor : te.Tensor + The tensor to convert. + layout : str + The layout of the tensor, either NHWC or NHCWB16. + channels : int + The number of valid channels for the tensor. + + Returns + ------- + te.Tensor + The converted tensor in NHWC layout. + + """ + assert layout in {"NHWC", "NHCWB16"} + convert_to_nhwc_attrs = { + "op": "ethosu_convert_to_nhwc", + "layout": layout, + } + if layout == "NHCWB16": + return te.compute( + (tensor.shape[0], tensor.shape[1], tensor.shape[3], channels), + lambda nn, hh, ww, cc: tensor(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(cc, 16)), + name="ethosu_convert_to_nhwc", + attrs=convert_to_nhwc_attrs, + ) + + return te.compute( + tensor.shape, + lambda *i: tensor(*i), + name="ethosu_convert_to_nhwc", + attrs=convert_to_nhwc_attrs, + ) + + +def convert_to_nhcwb16_compute(tensor, layout, channels): + """Converts a tensor into NHCWB16 layout if it's in NHWC layout. + + Parameters + ---------- + tensor : te.Tensor + The tensor to convert. + layout : str + The layout of the tensor, either NHWC or NHCWB16. + channels : int + The number of valid channels for the tensor. + + Returns + ------- + te.Tensor + The converted tensor in NHCWB16 layout. + + """ + assert layout in {"NHWC", "NHCWB16"} + convert_to_nhcwb16_attrs = { + "op": "ethosu_convert_to_nhcwb16", + "layout": layout, + } + if layout == "NHCWB16": + out_channel_bricks = te.indexdiv(channels - 1, 16) + 1 + output_shape = (1, tensor.shape[1], out_channel_bricks, tensor.shape[2], 16) + return te.compute( + output_shape, + lambda nn, hh, cc, ww, cb: tvm.tir.if_then_else( + cc * 16 + cb < channels, + tensor(nn, hh, ww, cc * 16 + cb), + tvm.tir.IntImm(tensor.dtype, 0), + ), + name="ethosu_convert_to_nhcwb16", + attrs=convert_to_nhcwb16_attrs, + ) + + return te.compute( + tensor.shape, + lambda *i: tensor(*i), + name="ethosu_convert_to_nhcwb16", + attrs=convert_to_nhcwb16_attrs, + ) + + +def pad_compute(tensor, padding): + """Pad an NHWC tensor in the height and width axes. + + Parameters + ---------- + tensor : te.Tensor + The tensor to pad. + padding : tuple + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + + Returns + ------- + te.Tensor + The padded tensor. + + """ + pad_top, pad_left, pad_down, pad_right = padding + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + pad_attrs = { + "op": "ethosu_pad", + } + shape = tensor.shape + return te.compute( + (shape[0], shape[1] + pad_top + pad_down, shape[2] + pad_left + pad_right, shape[3]), + lambda nn, hh, ww, cc: _pad_tensor(tensor, pad_before, pad_after)(nn, hh, ww, cc), + name="ethosu_pad", + attrs=pad_attrs, + ) + + +def dma_ifm_compute(ifm, layout, zero_point, scale, channels, padding): + """A sequence of compute operators representing the DMA capabilities for an IFM. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map (IFM) tensor. + layout : str + The layout of the data, either NHWC or NHCWB16. + zero_point : int + The zero point of the data. + scale : float + The scale of the data. + channels : int + The number of valid channels for the data. + padding : Union[int, tuple, list] + The desired padding. + int -> padding applied to both height and width axes. + 2D -> padding applied equally on both sides of the (height, width) axes. + 4D -> padding applied as (top, left, bottom, right) + + Returns + ------- + te.Tensor + The dma-ed IFM tensor. + + """ + read_ifm = read_compute(ifm, layout, zero_point, scale) + convert_to_nhwc_ifm = convert_to_nhwc_compute(read_ifm, layout, channels) + return pad_compute(convert_to_nhwc_ifm, padding) + + +def dma_ofm_compute(ofm, layout, zero_point, scale, channels): + """A sequence of compute operators representing the DMA capabilities for an OFM. + + Parameters + ---------- + ofm : te.Tensor + The Output Feature Map (OFM) tensor. + layout : str + The layout of the data, either NHWC or NHCWB16. + zero_point : int + The zero point of the data. + scale : float + The scale of the data. + channels : int + The number of valid channels for the data. + + Returns + ------- + te.Tensor + The dma-ed OFM tensor. + + """ + convert_to_nhcwb16_ofm = convert_to_nhcwb16_compute(ofm, layout, channels) + return write_compute(convert_to_nhcwb16_ofm, layout, zero_point, scale) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py new file mode 100644 index 000000000000..45b3c6731809 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Helper utility Enums and Functions used through out codegen + +The enums are there to indicate which argument of each relay operator +corresponds with which input. +e.g., input zero point of qnn.conv2d is 4th argument(3rd index) + +The rest of the utility functions are misc. +Refer to the description inside such functions +""" + +from enum import Enum +import numpy as np + +from tvm import relay +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.backend.contrib.ethosu import preprocess + + +class QConv2DArgs(Enum): + """ + This is a helper enums to access the correct index + qnn conv2d arguments + """ + + ifm = 0 + weights = 1 + ifm_zero_point = 2 + weights_zero_point = 3 + ifm_scale = 4 + weights_scale = 5 + + +class RequantArgs(Enum): + """ + This is a helper enums to access the correct index + qnn requantize arguments + """ + + ifm_scale = 1 + ifm_zero_point = 2 + ofm_scale = 3 + ofm_zero_point = 4 + + +class BiasAddArgs(Enum): + """ + This is a helper enums to access the correct index + qnn bias_add arguments + """ + + biases = 1 + + +class ClipArgs(Enum): + """ + This is a helper enums to access the correct index + qnn bias_add arguments + """ + + a_min = 1 + a_max = 2 + + +class MaxPoolArgs(Enum): + """ + This is a helper enums to access the correct index + max pool arguments + """ + + ifm = 0 + + +class AddArgs(Enum): + """This is a helper enums to access the correct index + max pool arguments + """ + + ifm0 = 0 + ifm1 = 1 + ifm0_scale = 2 + ifm0_zero_point = 3 + ifm1_scale = 4 + ifm1_zero_point = 5 + ofm_scale = 6 + ofm_zero_point = 7 + + +def is_composite_func(func, name): + """ + This a method to check whether the call is to + a composite function of the "name". + """ + if not hasattr(func, "attrs"): + return False + if "Composite" not in func.attrs.keys(): + return False + composite_name = func.attrs["Composite"] + + if composite_name != name: + return False + return True + + +def get_range_for_dtype_str(dtype): + """ + Produce the min,max for a give data type. + + Parameters + ---------- + dtype : str + a type string (e.g., int8) + + Returns + ------- + type_info.min : int + the minimum of the range + type_info.max : int + the maximum of the range + """ + + try: + type_info = np.iinfo(dtype) + except ValueError: + type_info = np.finfo(dtype) + return type_info.min, type_info.max + + +def round_away_zero(f): + """round the number away from zero towards +inf / -inf""" + offset = -0.5 if (f < 0) else 0.5 + return np.trunc(f + offset) + + +def round_up(a, b): + """round up to a multiple of b""" + return ((a + b - 1) // b) * b + + +# pylint: disable=unused-argument +def partition_for_ethosu(mod, params=None, **opts): + """This helper function partition the relay graph as produced by the + relay frontend for a given model into external functions + to be presented to the codegen. + + Parameters + ---------- + mod : IRModule + The IRModule that gets generated from a relay frontend + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + mod : IRModule + The partitioned IRModule with external global functions + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + pattern = relay.op.contrib.get_pattern_table("ethosu") + mod = relay.transform.InferType()(mod) + mod = relay.transform.MergeComposite(pattern)(mod) + mod = relay.transform.AnnotateTarget("ethosu")(mod) + mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) + mod = relay.transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) + mod = preprocess.preprocess_ext_io()(mod) + return mod + + +def get_dim_value(layout, dim): + """This is a helper function to retrieve the value + of the dimension given the shape and the layout + """ + assert isinstance(layout, str) + assert dim in list(layout) + for idx, dim_char in enumerate(layout): + if dim_char == dim: + return idx + return None diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py new file mode 100644 index 000000000000..3e772a953c16 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +conversions between TVM and Vela. Therefore, all interactions with the +Vela API are supposed to go through this adapter, with the hope that +any changes to Vela API, TVM only needs to change this file. +The following conversion APIs are added : + *Obtaining the best block config + *Compressing weights + *Packing biases +""" +import logging +import math +import numpy as np +from ethosu.vela import api as vapi + +from tvm.relay.backend.contrib.ethosu import util + +# pylint: disable=invalid-name +logger = logging.getLogger("Ethos-U") + +VELA_TO_NP_DTYPES = { + vapi.NpuDataType.UINT8: np.uint8, + vapi.NpuDataType.UINT16: np.uint16, + vapi.NpuDataType.INT8: np.int8, + vapi.NpuDataType.INT16: np.int16, + vapi.NpuDataType.INT32: np.int32, +} + +SCALE_BIAS_LENGTH = 10 + + +def get_optimal_block_config(npu_op, accel_type): + """ + "The NPU's unit of work is known as a block. It will fetch block(s) from Input + Feature Map (IFM) and a compute block for Output Feature Map (OFM). + Therefore, we need to pick an optimal block configuration considering bandwidth + to bring IFM blocks and the number of OFM block computes need to happen + to cover the OFM as indicated by the npu op. + + Parameters + ---------- + npu_op : ethosu.vela.api.NpuOperation + The NPU operation and its params + accel_type : ethosu.vela.api.NpuAccelerator + The NPU accelerator variant + Returns + ------- + ethosu.vela.api.NpuShape3d : + The optimal block config for the operator + """ + all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type) + return _get_optimal_block_config(all_valid_block_configs) + + +def _get_optimal_block_config(all_valid_block_configs): + """An internal function to get block config with largest depth + and then highest volume/area""" + assert isinstance(all_valid_block_configs, list) + for block_cfg in all_valid_block_configs: + assert isinstance(block_cfg, vapi.NpuShape3D) + + # Getting the largest volume block for benchmarksing + all_valid_block_configs.sort( + key=lambda _cfg: _cfg.depth * _cfg.height * _cfg.width, reverse=True + ) + largest_volume_block_config = all_valid_block_configs[0] + largest_volume = ( + largest_volume_block_config.depth + * largest_volume_block_config.height + * largest_volume_block_config.width + ) + + all_valid_block_configs.sort(key=lambda _cfg: _cfg.depth, reverse=True) + max_d = all_valid_block_configs[0].depth + max_depth_block_configs = [_cfg for _cfg in all_valid_block_configs if _cfg.depth == max_d] + max_depth_block_configs.sort(key=lambda _cfg: _cfg.height * _cfg.width, reverse=True) + max_area = max_depth_block_configs[0].height * max_depth_block_configs[0].width + max_area_depth_block_configs = [ + _cfg for _cfg in max_depth_block_configs if _cfg.height * _cfg.width == max_area + ] + # This to get a deterministic anwser everytime + max_area_depth_block_configs.sort(key=lambda _cfg: _cfg.height, reverse=True) + assert len(max_area_depth_block_configs) > 0 + current_volume = ( + max_area_depth_block_configs[0].depth + * max_area_depth_block_configs[0].height + * max_area_depth_block_configs[0].width + ) + logger.info("Using block config=%s", max_area_depth_block_configs[0]) + logger.info( + "Quality of the block config w.r.t. max volume block config=%s", + 100.0 * (current_volume / largest_volume), + ) + return max_area_depth_block_configs[0] + + +def compress_weights( + weights, + weights_zp, + weights_layout, + ifm_bitdepth, + block_depth, + dilation, + accel_type, + is_depthwise=False, +): + """Obtain compressed weights from vela + + Parameters + ---------- + weights : numpy.ndarray + The raw weights + weights_zp : int + The zero point of the weights + weights_layout : str + A string literal indicating the layout + Supported values : HWIO, HWOI, OHWI + ifm_bitdepth : int + The bit depth of the ifm the weights are used with + block_depth : int + The depth of the optimal block config for the operator + dilation : tuple + A tuple of 2 elements indicating dilation in h and w + accel_type : ethosu.vela.api.NpuAccelerator + The NPU accelerator variant + is_depthwise : bool, Optional + This indicates whether the weights are compressed for depthwise convolution + + Returns + ------- + compressed_weights : bytearray + Compressed weights + """ + layout_transform_indices = {"HWIO": (3, 0, 1, 2), "HWOI": (2, 0, 1, 3), "OHWI": (0, 1, 2, 3)} + assert weights_layout in layout_transform_indices.keys() + assert isinstance(weights_zp, np.int64) + weights = weights.astype(np.int64) - weights_zp + # Vela needs the weights in OHWI layout + weights_ohwi = np.transpose(weights, layout_transform_indices[weights_layout]) + shape_ohwi = [ + weights.shape[layout_transform_indices[weights_layout][0]], + weights.shape[layout_transform_indices[weights_layout][1]], + weights.shape[layout_transform_indices[weights_layout][2]], + weights.shape[layout_transform_indices[weights_layout][3]], + ] + block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) + compressed_weights = vapi.npu_encode_weights( + accelerator=accel_type, + weights_volume=weights_ohwi, + dilation_xy=dilation, + ifm_bitdepth=ifm_bitdepth, + ofm_block_depth=block_depth, + is_depthwise=is_depthwise, + block_traversal=block_traversal, + ) + return compressed_weights + + +def calculate_block_traversal_mode(is_depthwise, weights_shape_ohwi, ifm_bitdepth): + """Calculate a block traversal mode given whether the op is depthwise convolution, + shape of weights and bit-depth of the ifm. + """ + + if is_depthwise: + return vapi.NpuBlockTraversal.DEPTH_FIRST + # Determine which block traversal strategy has better DPU utilization + kernel_size = weights_shape_ohwi[1] * weights_shape_ohwi[2] + depth_utilization = weights_shape_ohwi[3] / util.round_up( + weights_shape_ohwi[3], 32 if ifm_bitdepth == 8 else 16 + ) + part_kernel_utilization = (weights_shape_ohwi[3] / util.round_up(weights_shape_ohwi[3], 8)) * ( + kernel_size / util.round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2) + ) + if part_kernel_utilization >= depth_utilization or weights_shape_ohwi[3] <= 8: + # Part-kernel first is always better for ifm depths <= 8 + return vapi.NpuBlockTraversal.PART_KERNEL_FIRST + return vapi.NpuBlockTraversal.DEPTH_FIRST + + +def pack_biases( + biases, + ifm_scale, + ifm_dtype, + weight_scales, + ofm_scale, + is_activation_tanh_or_sigmoid=False, +): + """ + Obtain packed bias bytearray as the hardware requires from + Vela. + Parameters + ---------- + biases : numpy.ndarray + The values of biases + ifm_scale : float + The quantization scale parameter of input feature map + ifm_dtype : numpy.dtype + The data type of input feature map data. + weight_scales : numpy.ndarray + The quantization scale parameter of weight feature map + This could be a tuple if per-channel quantization is present. + ofm_scale : float + The quantization scale parameter of output feature map. + is_activation_tanh_or_sigmoid : bool + Indicates whether the fused activation function is tanh or sigmoid. + + Returns + ------- + scale_bias : numpy.ndarray + Packed scales/biases as the hardware requires them. + """ + # The BYOC infra should not partition anything else. + supported_ifm_dtypes = (np.uint8, np.int8, np.int16) + assert ifm_dtype in supported_ifm_dtypes + + if weight_scales.size == 1: + weight_scales = [weight_scales] * biases.size + + hw_bias_scales = _calculate_hw_bias_scales( + ifm_scale, weight_scales, ofm_scale, ifm_dtype, is_activation_tanh_or_sigmoid + ) + assert len(hw_bias_scales) == biases.size + biases = biases.astype("int64") + packed_biases = bytearray() + for idx, scale in enumerate(hw_bias_scales): + packed_biases.extend(vapi.npu_encode_bias(biases[idx], *scale)) + # Align to 16 + # remainder = (len(packed_biases)) % 16 + # if remainder > 0: + # packed_biases.extend(bytearray(16 - remainder)) + scale_bias = np.frombuffer(packed_biases, dtype=np.uint8) + scale_bias = np.reshape(scale_bias, (-1, 10)) + return scale_bias + + +def _quantize_scale(scale): + """Quantize floating point scale into 32-bit int scale with a 6-bit shift. + This is to be used with 8-bit data. + """ + mantissa, exponent = math.frexp(scale) + mantissa_scaled = mantissa * (1 << 31) + mantissa_scaled = int(util.round_away_zero(mantissa_scaled)) + required_shift = 31 - exponent + assert 0 <= required_shift < (1 << 6) + return mantissa_scaled, required_shift + + +def _reduced_quantize_scale(scale): + """A reduction of precision is required for 16 bit data.""" + mantissa_scaled, required_shift = _quantize_scale(scale) + # This is max a signed 16-bit number could represent + max_reduced_mantissa_scaled = (1 << 15) - 1 + # if the current value is larger than pre-scaled max_reduced_mantissa_scaled + # we need to saturate the anwser to max_reduced_mantissa_scaled + if mantissa_scaled >= max_reduced_mantissa_scaled << 16: + reduced_mantissa_scaled = max_reduced_mantissa_scaled + else: + reduced_mantissa_scaled = (mantissa_scaled + (1 << 15)) >> 16 + reduced_shift = required_shift - 16 + return reduced_mantissa_scaled, reduced_shift + + +def _calculate_hw_bias_scales( + ifm_scale, weight_scales, ofm_scale, ifm_dtype, is_faf_tanh_sigmoid=False +): + """This function will produce a scale that is calculated using scales of ifm, + weights and ofm. It is also important to note that if per-channel / per-value + quantization required they should go into hw bias scales""" + if is_faf_tanh_sigmoid: + ifm_scale = ifm_scale * 0x3000 + if ifm_dtype == np.uint8: + bias_scales = [np.double(ifm_scale * ws) / np.double(ofm_scale) for ws in weight_scales] + else: + assert ifm_dtype in (np.int8, np.int16) + ifm_scale_dbl = np.double(ifm_scale) + ofm_scale_dbl = np.double(ofm_scale) + bias_scales = [ifm_scale_dbl * np.double(ws) / ofm_scale_dbl for ws in weight_scales] + + if ifm_dtype == np.int16: + hw_bias_scales = [_reduced_quantize_scale(bs) for bs in bias_scales] + else: + assert ifm_dtype in (np.uint8, np.int8) + hw_bias_scales = [_quantize_scale(bs) for bs in bias_scales] + + return hw_bias_scales + + +def get_target_accel_type(): + """This is a helper function to convert cli accelerator type str argument + to NpuAccelerator""" + npu_accel_str_map = { + "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, + "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, + "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, + "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, + } + accel_type_str = util.get_accelerator_config() + assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" + return npu_accel_str_map[accel_type_str] diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py new file mode 100644 index 000000000000..f7fee928e90a --- /dev/null +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm(R) Ethos(TM)-U NPU supported operators.""" +import numpy as np + +from tvm.relay.expr import Constant +from tvm.relay.op.contrib.register import register_pattern_table +from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant +from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs +from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs +from tvm.relay.backend.contrib.ethosu.util import RequantArgs +from tvm.relay.backend.contrib.ethosu.util import get_dim_value +from ethosu.vela import api as vapi + + +def check_strides(strides): + """Checks whether strides are within the limits supported by the hardware""" + stride_range = (1, 3) + smin, smax = stride_range + if not smax >= strides[0] >= smin: + return False + if not smax >= strides[1] >= smin: + return False + return True + + +def check_valid_dtypes(tensor_params): + """Check whether dtypes are supported by the hardware""" + supported_dtypes = (np.uint8, np.int8) + for tep in tensor_params: + # Check for dtypes + if np.dtype(tep.dtype) not in supported_dtypes: + return False + # Check for shape sizes + if any(dimlen > 65536 for dimlen in tep.shape): + return False + return True + + +def check_weights(weights, dilation): + """Checks whether weight tensor is compatible with HW""" + dilated_height_range = (1, 64) + dilated_hxw_range = (1, 64 * 64) + weights_limit = 127 * 65536 + dilated_width = (weights.shape[get_dim_value(weights.layout, "W")] - 1) * dilation[0] + 1 + dilated_height = (weights.shape[get_dim_value(weights.layout, "H")] - 1) * dilation[1] + 1 + dh_min, dh_max = dilated_height_range + if not dh_min <= dilated_height <= dh_max: + return False + dilated_hxw = dilated_height * dilated_width + dhxw_min, dhxw_max = dilated_hxw_range + if not dhxw_min <= dilated_hxw <= dhxw_max: + return False + # A saturation upper bound check for accumulators + weights.values = weights.values - weights.q_params.zero_point + axis = ( + get_dim_value(weights.layout, "H"), + get_dim_value(weights.layout, "W"), + get_dim_value(weights.layout, "I"), + ) + sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis)) + if not sum_weights <= weights_limit: + return False + return True + + +def check_bias(bias): + """Check whether the bias values fit in 40 bits""" + if bias and bias.dtype == np.dtype("int64"): + valid = all(len(bin(bias_value)[2:]) <= 40 for bias_value in bias.values) + return valid + return True + + +def check_batch_size(ifm): + """Checks for the number of batches vela currently supports""" + if ifm.shape[0] != 1: + return False + return True + + +def check_dilation(dilation): + """Checks whether dilation is within the limits supported by the hardware""" + dilation_range = (1, 2) + dmin, dmax = dilation_range + if not dmin <= dilation[0] <= dmax: + return False + if not dmin <= dilation[1] <= dmax: + return False + return True + + +def check_padding(padding, bounds): + """Checks whether padding is within the limits supported by the hardware""" + if len(padding) != 4 or len(bounds) != 4: + return False + top, left, bottom, right = padding + topb, leftb, bottomb, rightb = bounds + if top > topb or left > leftb or bottom > bottomb or right > rightb: + return False + return True + + +class TensorParams: + """ + This class will parse a tvm Expr along with quantization scale + and zero point to populate parameters that are required + for the creation of tensors in Vela. + """ + + def __init__(self, tensor, layout=None, scale=None, zero_point=None): + self.tensor = tensor + if isinstance(tensor, Constant): + self.values = tensor.data.asnumpy() + else: + self.values = None + self.dtype = tensor.checked_type.dtype + self.shape = [int(i) for i in tensor.checked_type.shape] + self.layout = layout + + if scale is not None and zero_point is not None: + self.q_params = vapi.NpuQuantization( + scale.data.asnumpy().astype("float32"), zero_point.data.asnumpy().astype(self.dtype) + ) + else: + # put default values + self.q_params = vapi.NpuQuantization(1.0, 0) + + +class QnnConv2DParams: + """ + This class will parse a Call to a ethosu.qnn_conv2d_clip composite function + and extract quantization information of all the associated tensors. + """ + + composite_name = "ethosu.qnn_conv2d" + # The hardware only supports padding upto the numbers as follows + padding_bounds = [31, 31, 32, 32] + activation_map = {"clip": "CLIP"} + + def __init__(self, func_body): + activation = None + if str(func_body.op) in self.activation_map.keys(): + activation = func_body + requantize_op = activation.args[0] + else: + requantize_op = func_body + bias_add = requantize_op.args[0] + qnn_conv2d = bias_add.args[0] + data_layout = qnn_conv2d.attrs.data_layout + kernel_layout = qnn_conv2d.attrs.kernel_layout + # We consider the weights & biases as params as it should be a Constant + self.weights = TensorParams( + qnn_conv2d.args[QConv2DArgs.weights.value], + kernel_layout, + qnn_conv2d.args[QConv2DArgs.weights_scale.value], + qnn_conv2d.args[QConv2DArgs.weights_zero_point.value], + ) + + self.biases = TensorParams( + bias_add.args[BiasAddArgs.biases.value], + data_layout, + requantize_op.args[RequantArgs.ifm_scale.value], + requantize_op.args[RequantArgs.ifm_zero_point.value], + ) + self.ifm = TensorParams( + qnn_conv2d.args[QConv2DArgs.ifm.value], + data_layout, + qnn_conv2d.args[QConv2DArgs.ifm_scale.value], + qnn_conv2d.args[QConv2DArgs.ifm_zero_point.value], + ) + self.ofm = TensorParams( + func_body, + data_layout, + requantize_op.args[RequantArgs.ofm_scale.value], + requantize_op.args[RequantArgs.ofm_zero_point.value], + ) + self.padding = qnn_conv2d.attrs.padding + self.strides = qnn_conv2d.attrs.strides + self.dilation = qnn_conv2d.attrs.dilation + self.activation = activation + + # If groups are equal to channel, its a depthwise_conv2d + self.groups = qnn_conv2d.attrs.groups + self.is_depthwise = False + channels_axis = {"HWIO": 3, "HWOI": 2} + if qnn_conv2d.attrs.groups == self.weights.shape[channels_axis[kernel_layout]]: + self.is_depthwise = True + + def is_valid(self): + """ + Checks whether QnnConv2D with Clip has compatible attributes with HW + """ + tensor_params = [self.weights, self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params): + return False + if not check_weights(self.weights, self.dilation): + return False + if not check_bias(self.biases): + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_dilation(self.dilation): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + legal_groups = [1, self.ofm.shape[3]] + if self.groups not in legal_groups: + return False + # This should be a valid QnnDepthwise2DParams, not QnnConv2DParams + if self.is_depthwise: + return False + return True + + +def qnn_conv2d_pattern(): + """ + Create pattern for qnn.conv2D with optional fused relu + """ + qnn_conv2d = is_op("qnn.conv2d")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ).has_attr({"kernel_layout": "HWIO"}) + bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) + req = is_op("qnn.requantize")( + qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + clip_or_req = req.optional(is_op("clip")) + return clip_or_req + + +@register_pattern_table("ethosu") +def pattern_table(): + return [ + ("ethosu.qnn_conv2d", qnn_conv2d_pattern(), lambda pat: QnnConv2DParams(pat).is_valid()) + ] diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc new file mode 100644 index 000000000000..9a18868a7f86 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../op/make_op.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! + * \brief This expression rewriter will traverse the graph to find calls + * to all external functions. If they have multiple inputs and/or + * multiple outputs, the following has to be done : + * 1) If multiple inputs are present, they needed to be concat before the call. + * 2) Inside the external function they need to be split again to their original inputs. + * 3) If there are multiple outputs, they need to be concat at the end of external function. + * 4) Then, the concat output again need to be split and made the original tuple output in the + * main. + */ +class ExternalFuncIOHandler : public ExprRewriter { + public: + explicit ExternalFuncIOHandler(IRModule& module) : module_(module) {} + int count = 0; + + Function InferType(const Function& expr, const IRModule& m) { + IRModule mod(m); + mod->Update(mod->GetGlobalVar("main"), expr); + mod = transform::InferType()(mod); + return Downcast(mod->Lookup("main")); + } + + /*! + * \brief This function will take shape and compute + * the scalar size value for it to be use to create + * flat single dimensional tensors. + */ + int64_t CalcSize(const Array& shape) { + int size = 1; + for (auto dim_sz : shape) { + size = size * Downcast(dim_sz)->value; + } + return size; + } + + /*! + * \brief This will take a tensor and create a flattened + * tensor to be used by the concat. + */ + Expr CreateFlattenTensor(const Expr& input) { + auto ishape = Downcast>(Downcast(input->checked_type())->shape); + int flatten_size = CalcSize(ishape); + Array oshape = {Integer(flatten_size)}; + return MakeReshape(input, oshape); + } + + /*! + * \brief This will take flattened tensors and create + * a single concat'd tensor. + */ + Expr CreateConcatTensor(const Array& inputs) { + auto tuple = Tuple(inputs); + return MakeConcatenate(tuple, 0); + } + + /*! + * \brief This will take a flattened concat'd tensor and use the original inputs shapes + * to recreate a Tuple of the original set of tensors. + */ + Expr CreateSplitReshapedTensors(const Expr& input, const Array& original_args) { + Array> shapes; + Array flatten_tensor_sizes; + Array split_indices; + Array rets; + + int total_size = 0; + for (auto orig_arg : original_args) { + auto shape = Downcast>(Downcast(orig_arg->checked_type())->shape); + shapes.push_back(shape); + flatten_tensor_sizes.push_back(CalcSize(shape)); + if (total_size != 0) { + split_indices.push_back(total_size); + } + total_size += CalcSize(shape); + } + auto split_outs = MakeSplit(input, split_indices, 0); + for (unsigned int i = 0; i < shapes.size(); i++) { + auto split_out = TupleGetItem(split_outs, i); + split_out->checked_type_ = original_args[i]->checked_type_; + rets.push_back(MakeReshape(split_out, shapes[i])); + } + return Tuple(rets); + } + + /*! + * \brief Modify the external function to split the input as the original compute + * as required originally. Moreover, the outputs will be flattened and concat'd + * to make a single output. Finaly, the external function should only have a single input + * and a single output. + */ + Function ModifyExternalFunction(const Function& func, GlobalVar gv, const DataType& dtype) { + Array inputs; + Var ifms; + if (func->params.size() > 1) { + Array> shapes; + Array flatten_tensor_sizes; + Array split_indices; + + auto func_name = gv->name_hint; + int total_size = 0; + for (auto input : func->params) { + auto shape = Downcast>(Downcast(input->checked_type())->shape); + shapes.push_back(shape); + auto flat_size = CalcSize(shape); + flatten_tensor_sizes.push_back(flat_size); + if (total_size != 0) { + split_indices.push_back(total_size); + } + total_size += flat_size; + } + Array ifms_shape = {total_size}; + ifms = Var(func_name + "_ifms", TensorType(ifms_shape, dtype)); + auto split_outs = MakeSplit(ifms, split_indices, 0); + for (unsigned int i = 0; i < shapes.size(); i++) { + auto split_out = TupleGetItem(split_outs, i); + split_out->checked_type_ = func->params[i]->checked_type(); + inputs.push_back(MakeReshape(split_out, shapes[i])); + } + } else { + CHECK_EQ(func->params.size(), 1); + inputs.push_back(func->params[0]); + ifms = func->params[0]; + } + Map bind_map; + CHECK_EQ(func->params.size(), inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + bind_map.Set(func->params[i], inputs[i]); + } + auto core_compute_expr = Bind(func->body, bind_map); + + // Creation of wrapper inside the external function + Array params = {ifms}; + if (func->body->IsInstance()) { + auto tuple_out = func->body.as(); + Array reshaped_outputs; + for (unsigned int i = 0; i < tuple_out->fields.size(); i++) { + auto out = Downcast(core_compute_expr)->fields[i]; + out->checked_type_ = tuple_out->fields[i]->checked_type_; + reshaped_outputs.push_back(CreateFlattenTensor(out)); + } + auto concat_out = CreateConcatTensor(reshaped_outputs); + auto f = Function(params, concat_out, concat_out->checked_type_, {}, func->attrs); + return InferType(f, this->module_); + } else { + auto f = + Function(params, core_compute_expr, core_compute_expr->checked_type_, {}, func->attrs); + return InferType(f, this->module_); + } + } + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + auto post_call = Downcast(post); + + if (auto glb_var_node = post_call->op.as()) { + auto glb_var = GetRef(glb_var_node); + auto func = Downcast(module_->functions[glb_var]); + + // If the number of inputs and output are 1 --> no need to do anything + if (post_call->args.size() == 1 && !func->body->IsInstance()) { + return post; + } + if (auto compiler = func->GetAttr(attr::kCompiler)) { + if (compiler == "ethosu") { + auto ext_input = std::move(post_call->args[0]); + auto arg_dtype = Downcast(post_call->args[0]->checked_type())->dtype; + if (post_call->args.size() > 1) { + Array reshaped_inputs; + for (const auto& arg : post_call->args) { + // All arguments should be of same data type + CHECK_EQ(arg_dtype, Downcast(arg->checked_type())->dtype) + << "Currently NPU external functions require all inputs to be of same data " + "type"; + reshaped_inputs.push_back(CreateFlattenTensor(arg)); + } + ext_input = CreateConcatTensor(reshaped_inputs); + } + auto ext_func = ModifyExternalFunction(func, glb_var, arg_dtype); + Array new_args = {ext_input}; + module_->Add(glb_var, ext_func); + Expr new_call = Call(glb_var, new_args); + if (func->body->IsInstance()) { + auto orginal_tuple_out = Downcast(func->body); + new_call = CreateSplitReshapedTensors(new_call, orginal_tuple_out->fields); + } + return std::move(new_call); + } + } + } + return post; + } + + private: + IRModule module_; +}; + +IRModule PreprocessExternalFuncIO_(IRModule module) { + ExternalFuncIOHandler ex_func_io_handle(module); + auto func = GetRef(module->Lookup("main").as()); + auto preprocessed = PostOrderRewrite(func, &ex_func_io_handle); + module->Update(module->GetGlobalVar("main"), GetRef(preprocessed.as())); + return module; +} + +} // namespace ethosu +} // namespace contrib + +/*! + * \brief This is a pre-processing pass for all NPU external functions. + * Currently, the NPU runtime module expects a single input and a single output. + * Therefore, this pass will concat the inputs pre-call, split again inside ext. func, + * concat the output inside ext. func and re-split again after the call. + */ + +namespace transform { +Pass PreprocessExternalFuncIO() { + runtime::TypedPackedFunc pre_processed_ext_func = + [=](IRModule m, PassContext pc) { + auto _m = contrib::ethosu::PreprocessExternalFuncIO_(m); + return _m; + }; + auto preprocess_pass = + CreateModulePass(pre_processed_ext_func, 0, "PreprocessExternalFuncIO", {}); + return Sequential({preprocess_pass, InferType()}); +} + +TVM_REGISTER_GLOBAL("relay.ext.ethosu.PreprocessExternalFuncIO") + .set_body_typed(transform::PreprocessExternalFuncIO); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc new file mode 100644 index 000000000000..a7109ade71a0 --- /dev/null +++ b/src/relay/op/contrib/ethosu/common.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/op_common.cc + * \brief A set of utilities and common functionality for Arm(R) Ethos(TM)-U NPU QNN ops. + */ + +#include "common.h" + +#include "../../op_common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +Array EthosuInferKernelOutput(Array ifm_shape, String ifm_layout, + String ofm_layout, Array kernel_shape, + IndexExpr ofm_channels, Array dilation, + Array strides, Array padding) { + // In the case of NHCWB16, convert the ifm shape to NHW (C not required for this function) + if (ifm_layout == "NHCWB16") { + ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]}; + } + Array oshape({ifm_shape[0], 0, 0, ofm_channels}); + + IndexExpr dilated_ksize_y = 1 + (kernel_shape[0] - 1) * dilation[0]; + IndexExpr dilated_ksize_x = 1 + (kernel_shape[1] - 1) * dilation[1]; + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(padding, &pad_h, &pad_w); + oshape.Set(1, indexdiv(ifm_shape[1] + pad_h - dilated_ksize_y, strides[0]) + 1); + oshape.Set(2, indexdiv(ifm_shape[2] + pad_w - dilated_ksize_x, strides[1]) + 1); + + // If the ofm is NHCWB16, convert the layout + if (ofm_layout == "NHCWB16") { + int channel_bricks = 1 + (oshape[3].as()->value - 1) / 16; + oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16}; + } + + return oshape; +} + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h new file mode 100644 index 000000000000..b5377e6e8bdf --- /dev/null +++ b/src/relay/op/contrib/ethosu/common.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/common.h + * \brief Functions for all Arm(R) Ethos(TM)-U NPU operators to use. + */ + +#ifndef TVM_RELAY_OP_CONTRIB_ETHOSU_COMMON_H_ +#define TVM_RELAY_OP_CONTRIB_ETHOSU_COMMON_H_ + +#include + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Infer the output tensor shape for convolution and pooling operators. + * \param ifm_shape The shape of Input Feature Map. + * \param ifm_layout The layout of the IFM (NHWC or NHCWB16). + * \param ofm_layout The layout of the OFM (NHWC or NHCWB16). + * \param kernel_shape Kernel shape in format (height, width). + * \param ofm_channels The number of Output Feature Map channels. + * \param dilation The 2-dimensional dilation as (dilation_height, dilation_width). + * \param strides The 2 dimensional strides as (stride_height, stride_width). + * \param padding The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + * \return The shape of the output tensor. + */ +Array EthosuInferKernelOutput(Array ifm_shape, String ifm_layout, + String ofm_layout, Array kernel_shape, + IndexExpr ofm_channels, Array dilation, + Array strides, Array padding); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_CONTRIB_ETHOSU_COMMON_H_ diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc new file mode 100644 index 000000000000..ec5da6cd1c47 --- /dev/null +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/convolution.cc + * \brief Property def of the Arm(R) Ethos(TM)-U NPU convolution ops. + */ +#include "../../nn/convolution.h" + +#include +#include +#include +#include +#include + +#include "../../../qnn/utils.h" +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU convolution operator */ +struct EthosuConv2DAttrs : public tvm::AttrsNode { + double ifm_scale; + int ifm_zero_point; + int weight_zero_point; + double ofm_scale; + int ofm_zero_point; + Array kernel_shape; + IndexExpr ofm_channels; + Array strides; + Array padding; + Array dilation; + String activation; + int clip_min; + int clip_max; + String upscale; + tvm::String ifm_layout; + tvm::String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuConv2DAttrs, "relay.attrs.EthosuConv2DAttrs") { + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(weight_zero_point) + .describe("The quantization zero point for the weight tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(kernel_shape) + .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") + .set_default(NullValue >()); + TVM_ATTR_FIELD(ofm_channels) + .describe("The number of OFM channels.") + .set_default(NullValue()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("The 2 dimensional strides as (stride_height, stride_width)."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0, 0})) + .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right)."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("The 2 dimensional dilation as (dilation_height, dilation_width)."); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(upscale) + .describe( + "The 2x2 upscaling mode to apply to the Input Feature Map tensor. " + "'NONE' - no upscaling. " + "'NEAREST' - upscale using nearest neighbour. " + "'ZEROS' - upscale using zeros.") + .set_default("NONE"); + TVM_ATTR_FIELD(ifm_layout) + .set_default("NHWC") + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + TVM_ATTR_FIELD(ofm_layout) + .set_default("NHWC") + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuConv2DAttrs); + +bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 5); + const auto* ifm = types[0].as(); + const auto* weight = types[1].as(); + const auto* scale_bias = types[2].as(); + if (ifm == nullptr || weight == nullptr) return false; + const auto* param = attrs.as(); + CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr."; + CHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) + << "Expected ethosu_conv2d type(uint8) or type(int8) for ifm but was " << ifm->dtype; + CHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8)) + << "Expected ethosu_conv2d type(uint8) or type(int8) for weight but was " << weight->dtype; + CHECK(scale_bias->dtype == DataType::UInt(8)) + << "Expected ethosu_conv2d type(uint8) for scale_bias but was " << scale_bias->dtype; + + // The scale_bias should be provided as a tensor of size {ofm_channels, 10} + reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8))); + + // Assign weight type {ofm_channels, kernel_height, kernel_width, ifm_channels} + reporter->Assign(types[1], TensorType({param->ofm_channels, param->kernel_shape[0], + param->kernel_shape[1], weight->shape[3]}, + weight->dtype)); + + // Assign ofm type + auto ofm_shape = + EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape, + param->ofm_channels, param->dilation, param->strides, param->padding); + reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype)); + return true; +} + +Expr MakeEthosuConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut, double ifm_scale, + int ifm_zero_point, int weight_zero_point, double ofm_scale, + int ofm_zero_point, Array kernel_shape, IndexExpr ofm_channels, + Array strides, Array padding, Array dilation, + String activation, int clip_min, int clip_max, String upscale, + String ifm_layout, String ofm_layout) { + auto attrs = make_object(); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->weight_zero_point = weight_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->kernel_shape = std::move(kernel_shape); + attrs->ofm_channels = std::move(ofm_channels); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->upscale = std::move(upscale); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + static const Op& op = Op::Get("contrib.ethosu.conv2d"); + return Call(op, {ifm, weight, scale_bias, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_conv2d").set_body_typed(MakeEthosuConv2D); + +RELAY_REGISTER_OP("contrib.ethosu.conv2d") + .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized convolution operator. + +This Relay operator corresponds to the hardware-implemented quantized +convolution operation found on Ethos(TM)-U NPUs. It accepts either NHWC +or NHCWB16 format for the input data (input feature map, or IFM) and +OHWI format for the kernel weights. + +Reference: https://developer.arm.com/documentation/102420/0200/ + +Note that the per-channel weight scale and bias tensor must be packed together into +a combined tensor of uint80s. This is represented in TVM by a (channels, 10) tensor +of type uint8. For more detail, refer to the Technical Reference Manual linked above. + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **weight**: (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels) +- **scale_bias**: (ofm_channels, 10) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("weight", "Tensor", "The weight tensor.") + .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") + .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'.") + .set_support_level(11) + .add_type_rel("EthosuConv2D", EthosuConv2DRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/relay_ir_builder.py b/tests/python/contrib/test_ethosu/relay_ir_builder.py new file mode 100644 index 000000000000..6169a3e46520 --- /dev/null +++ b/tests/python/contrib/test_ethosu/relay_ir_builder.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helper module to build relay operations for testing""" + +from pathlib import Path +import numpy as np +import math + +import tvm +from tvm import relay +from tvm.relay.op.contrib import get_pattern_table +from tvm.relay import qnn +from tvm.relay.backend.contrib.ethosu.util import get_range_for_dtype_str + + +class TensorType: + """A data structure to capture tensor parameters""" + + def __init__(self): + self.shape = None + self.dtype = None + self.zp = None + self.sc = None + self.layout = None + + def get_dim_size(self, dim): + for idx, char in enumerate(self.layout): + if dim == char: + return self.shape[idx] + return None + + def get_dim_index(self, dim): + for idx, char in enumerate(self.layout): + if dim == char: + return idx + return None + + +class QnnConv2DParams: + """A data structure to capture relay.qnn.op.conv2D parameters""" + + def __init__(self, dtype): + self.ifm = TensorType() + self.ofm = TensorType() + self.kernel = TensorType() + + # default values + self.ifm.dtype = dtype + self.ifm.layout = "NHWC" + ifm_min, ifm_max = get_range_for_dtype_str(self.ifm.dtype) + self.ifm.zp = relay.const(np.random.randint(ifm_min, ifm_max), "int32") + self.ifm.sc = relay.const(np.random.random() * 2, "float32") + self.kernel.dtype = dtype + self.kernel.layout = "HWIO" + kernel_min, kernel_max = get_range_for_dtype_str(self.kernel.dtype) + self.kernel.zp = relay.const(np.random.randint(kernel_min, kernel_max), "int32") + self.kernel.sc = relay.const(np.random.random() * 2, "float32") + self.ofm.layout = "NHWC" + self.ofm.dtype = dtype + ofm_min, ofm_max = get_range_for_dtype_str(self.ofm.dtype) + self.ofm.zp = relay.const(np.random.randint(ofm_min, ofm_max), "int32") + self.ofm.sc = relay.const(np.random.random() * 2, "float32") + self.dilation = (1, 1) + + self.strides = None + self.pad = None + self.activation = "NONE" + self.clip_min = 0 + self.clip_max = 0 + + def update_output_qnn_params( + self, input_dtype="uint8", kernel_dtype="uint8", output_dtype="uint8" + ): + _, dtype_max = get_range_for_dtype_str(input_dtype) + input_max = self.ifm.sc.data.asnumpy() * (dtype_max - self.ifm.zp.data.asnumpy()) + input_min = -self.ifm.sc.data.asnumpy() * self.ifm.zp.data.asnumpy() + _, dtype_max = get_range_for_dtype_str(kernel_dtype) + kernel_max = np.max( + self.kernel.sc.data.asnumpy() * (dtype_max - self.kernel.zp.data.asnumpy()) + ) + kernel_min = np.min(-self.kernel.sc.data.asnumpy() * self.kernel.zp.data.asnumpy()) + kernel_h = self.kernel.get_dim_size("H") + kernel_w = self.kernel.get_dim_size("W") + channels = self.kernel.get_dim_size("I") + output_limits = [ + kernel_max * kernel_h * kernel_w * channels * input_max, + kernel_min * kernel_h * kernel_w * channels * input_max, + kernel_min * kernel_h * kernel_w * channels * input_min, + kernel_max * kernel_h * kernel_w * channels * input_min, + ] + output_max = max(output_limits) + output_min = min(output_limits) + dtype_min, dtype_max = get_range_for_dtype_str(input_dtype) + self.ofm.sc = relay.const((output_max - output_min) / (dtype_max - dtype_min), "float32") + self.ofm.zp = relay.const(-int(output_min / self.ofm.sc.data.asnumpy()), "int32") + + +class PoolingParams: + """A data structure to capture relay.op.max_pool2d / + relay.op.avg_pool2d parameters + """ + + def __init__(self, dtype): + self.type = None + self.size = None + self.strides = None + self.pad = None + self.layout = None + self.ifm = TensorType() + self.ofm = TensorType() + + # default values + self.ifm.dtype = dtype + self.ifm.layout = "NHWC" + self.ifm.zp = relay.const(np.random.randint(0, 255), "int32") + self.ifm.sc = relay.const(np.random.random() * 2, "float32") + self.ofm.zp = relay.const(np.random.randint(0, 255), "int32") + self.ofm.sc = relay.const(np.random.random() * 2, "float32") + self.ofm.dtype = dtype + self.dilation = (1, 1) + + +class AddParams: + """A data structure to capture relay.qnn.op.add parameters""" + + def __init__(self, dtype): + self.ifm0 = TensorType() + self.ifm1 = TensorType() + self.ofm = TensorType() + + # default values + self.ifm0.dtype = dtype + self.ifm0.zp = relay.const(np.random.randint(0, 255), "int32") + self.ifm0.sc = relay.const(np.random.random() * 2, "float32") + self.ifm1.dtype = dtype + self.ifm1.zp = relay.const(np.random.randint(0, 255), "int32") + self.ifm1.sc = relay.const(np.random.random() * 2, "float32") + self.update_output_qnn_params() + self.ofm.dtype = dtype + + def update_output_qnn_params(self): + ti = np.iinfo(self.ifm0.dtype) + dtype_min, dtype_max = int(ti.min), int(ti.max) + input1_max = self.ifm0.sc.data.asnumpy() * (dtype_max - self.ifm0.zp.data.asnumpy()) + input1_min = (dtype_min - self.ifm0.sc.data.asnumpy()) * self.ifm0.zp.data.asnumpy() + input2_max = self.ifm1.sc.data.asnumpy() * (dtype_max - self.ifm1.zp.data.asnumpy()) + input2_min = (dtype_min - self.ifm1.sc.data.asnumpy()) * self.ifm1.zp.data.asnumpy() + output_max = input1_max + input2_max + output_min = input1_min + input2_min + self.ofm.sc = relay.const((output_max - output_min) / dtype_max, "float32") + self.ofm.zp = relay.const( + (dtype_min - int(output_min / self.ofm.sc.data.asnumpy())), "int32" + ) + + +def get_pad_value(data, kernel, stride): + """Get the pad tuple of value for SAME padding""" + + out = int(math.ceil(float(data) / float(stride))) + pad = max(0, (out - 1) * stride + kernel - data) + pad_before = pad // 2 + pad_after = pad - pad_before + return pad_before, pad_after + + +def create_qnn_conv2d(qnn_conv2d_params, ifm_expr): + """Create a relay.Expr of relay.qnn.conv2D given the parameters""" + v_params = list() + params = { + "kernel_size": [ + qnn_conv2d_params.kernel.get_dim_size("H"), + qnn_conv2d_params.kernel.get_dim_size("W"), + ], + "strides": [qnn_conv2d_params.strides[0], qnn_conv2d_params.strides[1]], + "dilation": [qnn_conv2d_params.dilation[0], qnn_conv2d_params.dilation[1]], + "padding": [0, 0, 0, 0], + "data_layout": qnn_conv2d_params.ifm.layout, + } + dilated_kernel_h = ( + qnn_conv2d_params.dilation[0] * (qnn_conv2d_params.kernel.get_dim_size("H") - 1) + 1 + ) + dilated_kernel_w = ( + qnn_conv2d_params.dilation[1] * (qnn_conv2d_params.kernel.get_dim_size("W") - 1) + 1 + ) + if qnn_conv2d_params.pad == "SAME": + pad_top, pad_bottom = get_pad_value( + qnn_conv2d_params.ifm.get_dim_size("H"), dilated_kernel_h, qnn_conv2d_params.strides[0] + ) + pad_left, pad_right = get_pad_value( + qnn_conv2d_params.ifm.get_dim_size("W"), dilated_kernel_w, qnn_conv2d_params.strides[1] + ) + do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0) + if do_pad: + params["padding"] = [pad_top, pad_left, pad_bottom, pad_right] + qnn_conv2d_params.pad = params["padding"] + params["input_zero_point"] = qnn_conv2d_params.ifm.zp + params["kernel_zero_point"] = qnn_conv2d_params.kernel.zp + params["out_dtype"] = "int32" + params["input_scale"] = qnn_conv2d_params.ifm.sc + params["kernel_scale"] = qnn_conv2d_params.kernel.sc + params["channels"] = int(qnn_conv2d_params.kernel.get_dim_size("O")) + params["kernel_layout"] = qnn_conv2d_params.kernel.layout + k_shape = qnn_conv2d_params.kernel.shape + k_dtype = qnn_conv2d_params.kernel.dtype + w = tvm.nd.array( + np.random.randint( + np.iinfo(k_dtype).min, high=np.iinfo(k_dtype).max, size=k_shape, dtype=k_dtype + ) + ) + weight_expr = relay.const(w, k_dtype) + v_params.append(w) + qnn_conv2d_expr = qnn.op.conv2d(ifm_expr, weight_expr, **params) + b = tvm.nd.array( + np.random.randint( + 0, high=10, size=(qnn_conv2d_params.kernel.get_dim_size("O")), dtype="int32" + ) + ) + v_params.append(b) + bias_expr = relay.const(b, "int32") + bias = relay.nn.bias_add( + qnn_conv2d_expr, bias_expr, axis=qnn_conv2d_params.ifm.get_dim_index("C") + ) + bias_scale = relay.const( + qnn_conv2d_params.ifm.sc.data.asnumpy() * qnn_conv2d_params.kernel.sc.data.asnumpy(), + "float32", + ) + req_expr = relay.qnn.op.requantize( + bias, + bias_scale, # input zero scale + relay.const(0, "int32"), # input zero point + qnn_conv2d_params.ofm.sc, # output zero scale + qnn_conv2d_params.ofm.zp, # output zero point + out_dtype=qnn_conv2d_params.ofm.dtype, + ) + if qnn_conv2d_params.activation != "NONE": + assert qnn_conv2d_params.activation == "CLIP" + clip_expr = relay.clip(req_expr, qnn_conv2d_params.clip_min, qnn_conv2d_params.clip_max) + return clip_expr, v_params + + return req_expr, v_params + + +def create_pool2d(pooling_params, ifm_expr): + """Create a relay pooling operation""" + assert pooling_params.ifm.layout == "NHWC" + params = { + "pool_size": (pooling_params.size[0], pooling_params.size[1]), + "strides": (pooling_params.strides[0], pooling_params.strides[1]), + "padding": [0, 0], + "layout": "NHWC", + } + if pooling_params.pad == "SAME": + pad_top, pad_bottom = get_pad_value( + pooling_params.ifm.shape[1], pooling_params.size[0], pooling_params.strides[0] + ) + pad_left, pad_right = get_pad_value( + pooling_params.ifm.shape[2], pooling_params.size[1], pooling_params.strides[1] + ) + params["padding"] = [pad_top, pad_left, pad_bottom, pad_right] + if pooling_params.type == "MAX": + out = relay.op.nn.max_pool2d(ifm_expr, **params) + else: + assert pooling_params.type == "AVG" + out = relay.op.cast(ifm_expr, dtype="int32") + out = relay.op.nn.avg_pool2d(out, **params) + out = relay.op.cast(out, dtype=pooling_params.ofm.dtype) + return out + + +def create_qnn_add(ifm0_expr, ifm1_expr, add_params): + add = relay.qnn.op.add( + lhs=ifm0_expr, + rhs=ifm1_expr, + lhs_scale=add_params.ifm0.sc, + lhs_zero_point=add_params.ifm0.zp, + rhs_scale=add_params.ifm1.sc, + rhs_zero_point=add_params.ifm1.zp, + output_scale=add_params.ofm.sc, + output_zero_point=add_params.ofm.zp, + ) + return add diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py new file mode 100644 index 000000000000..f4c863ea4a3c --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument + +import pytest +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.backend.contrib import ethosu +from tvm.relay.backend.contrib.ethosu import legalize, preprocess +from tvm.relay.dataflow_pattern import * +from tvm.relay.op.contrib.ethosu import * +import relay_ir_builder + + +def test_split_indices_legalize(): + def create_graph(axis): + x = relay.var("x", shape=(1, 50, 50, 3)) + x_relu = relay.nn.relu(x) + split_o = relay.split(x_relu, [5, 20, 45], axis).tuple_value + return relay.Function([x], split_o) + + def expected_mod_axis1(): + expected_ir_string = """ + #[version = "0.0.5"] + def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 5, 50, 3), float32],\ + Tensor[(1, 15, 50, 3), float32],\ + Tensor[(1, 25, 50, 3), float32],\ + Tensor[(1, 5, 50, 3), float32]) { + %0 = nn.relu(%x) /* ty=Tensor[(1, 50, 50, 3), float32] */; + %1 = strided_slice(%0, begin=[0, 0, 0, 0], end=[1, 5, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 5, 50, 3), float32] */; + %2 = strided_slice(%0, begin=[0, 5, 0, 0], end=[1, 20, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 15, 50, 3), float32] */; + %3 = strided_slice(%0, begin=[0, 20, 0, 0], end=[1, 45, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 25, 50, 3), float32] */; + %4 = strided_slice(%0, begin=[0, 45, 0, 0], end=[1, 50, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 5, 50, 3), float32] */; + (%1, %2, %3, %4) + } + """ + return tvm.parser.fromtext(expected_ir_string) + + def expected_mod_axis2(): + expected_ir_string = """ + #[version = "0.0.5"] + def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 50, 5, 3), float32],\ + Tensor[(1, 50, 15, 3), float32],\ + Tensor[(1, 50, 25, 3), float32],\ + Tensor[(1, 50, 5, 3), float32]) { + %0 = nn.relu(%x) /* ty=Tensor[(1, 50, 50, 3), float32] */; + %1 = strided_slice(%0, begin=[0, 0, 0, 0], end=[1, 50, 5, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 5, 3), float32] */; + %2 = strided_slice(%0, begin=[0, 0, 5, 0], end=[1, 50, 20, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 15, 3), float32] */; + %3 = strided_slice(%0, begin=[0, 0, 20, 0], end=[1, 50, 45, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 25, 3), float32] */; + %4 = strided_slice(%0, begin=[0, 0, 45, 0], end=[1, 50, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 5, 3), float32] */; + (%1, %2, %3, %4) + } + """ + return tvm.parser.fromtext(expected_ir_string) + + mod_axis1 = tvm.IRModule() + mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1) + mod_axis1["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.SplitRewriter(), mod_axis1["tvmgen_default_ethosu_main_0"] + ) + expected_axis1 = expected_mod_axis1() + tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) + + mod_axis2 = tvm.IRModule() + mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2) + mod_axis2["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.SplitRewriter(), mod_axis2["tvmgen_default_ethosu_main_0"] + ) + expected_axis2 = expected_mod_axis2() + tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) + + +def test_split_sections_legalize(): + def create_graph(axis, sections): + x = relay.var("x", shape=(1, 50, 50, 3)) + x_abs = relay.abs(x) + split_o = relay.split(x_abs, sections, axis).tuple_value + outputs = list() + for section_idx in range(sections): + split_single_out = relay.TupleGetItem(split_o, section_idx) + tanh = relay.tanh(split_single_out) + outputs.append(tanh) + tuple_out = relay.Tuple(outputs) + return relay.Function([x], tuple_out) + + def expected_mod_axis1(): + expected_ir_string = """ + #[version = "0.0.5"] + def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 10, 50, 3), float32],\ + Tensor[(1, 10, 50, 3), float32],\ + Tensor[(1, 10, 50, 3), float32],\ + Tensor[(1, 10, 50, 3), float32],\ + Tensor[(1, 10, 50, 3), float32]) { + %0 = abs(%x) /* ty=Tensor[(1, 50, 50, 3), float32] */; + %1 = strided_slice(%0, begin=[0, 0, 0, 0], end=[1, 10, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 10, 50, 3), float32] */; + %2 = strided_slice(%0, begin=[0, 10, 0, 0], end=[1, 20, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 10, 50, 3), float32] */; + %3 = strided_slice(%0, begin=[0, 20, 0, 0], end=[1, 30, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 10, 50, 3), float32] */; + %4 = strided_slice(%0, begin=[0, 30, 0, 0], end=[1, 40, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 10, 50, 3), float32] */; + %5 = strided_slice(%0, begin=[0, 40, 0, 0], end=[1, 50, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 10, 50, 3), float32] */; + %6 = (%1, %2, %3, %4, %5); + %7 = %6.0; + %8 = tanh(%7) /* ty=Tensor[(1, 10, 50, 3), float32] */; + %9 = %6.1; + %10 = tanh(%9) /* ty=Tensor[(1, 10, 50, 3), float32] */; + %11 = %6.2; + %12 = tanh(%11) /* ty=Tensor[(1, 10, 50, 3), float32] */; + %13 = %6.3; + %14 = tanh(%13) /* ty=Tensor[(1, 10, 50, 3), float32] */; + %15 = %6.4; + %16 = tanh(%15) /* ty=Tensor[(1, 10, 50, 3), float32] */; + (%8, %10, %12, %14, %16) + } + """ + return tvm.parser.fromtext(expected_ir_string) + + def expected_mod_axis2(): + expected_ir_string = """ + #[version = "0.0.5"] + def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 50, 10, 3), float32],\ + Tensor[(1, 50, 10, 3), float32],\ + Tensor[(1, 50, 10, 3), float32],\ + Tensor[(1, 50, 10, 3), float32],\ + Tensor[(1, 50, 10, 3), float32]) { + %0 = abs(%x) /* ty=Tensor[(1, 50, 50, 3), float32] */; + %1 = strided_slice(%0, begin=[0, 0, 0, 0], end=[1, 50, 10, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 10, 3), float32] */; + %2 = strided_slice(%0, begin=[0, 0, 10, 0], end=[1, 50, 20, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 10, 3), float32] */; + %3 = strided_slice(%0, begin=[0, 0, 20, 0], end=[1, 50, 30, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 10, 3), float32] */; + %4 = strided_slice(%0, begin=[0, 0, 30, 0], end=[1, 50, 40, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 10, 3), float32] */; + %5 = strided_slice(%0, begin=[0, 0, 40, 0], end=[1, 50, 50, 3], strides=[1], axes=None)\ + /* ty=Tensor[(1, 50, 10, 3), float32] */; + %6 = (%1, %2, %3, %4, %5); + %7 = %6.0; + %8 = tanh(%7) /* ty=Tensor[(1, 50, 10, 3), float32] */; + %9 = %6.1; + %10 = tanh(%9) /* ty=Tensor[(1, 50, 10, 3), float32] */; + %11 = %6.2; + %12 = tanh(%11) /* ty=Tensor[(1, 50, 10, 3), float32] */; + %13 = %6.3; + %14 = tanh(%13) /* ty=Tensor[(1, 50, 10, 3), float32] */; + %15 = %6.4; + %16 = tanh(%15) /* ty=Tensor[(1, 50, 10, 3), float32] */; + (%8, %10, %12, %14, %16) + } + """ + return tvm.parser.fromtext(expected_ir_string) + + mod_axis1 = tvm.IRModule() + mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1, 5) + mod_axis1["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.SplitRewriter(), mod_axis1["tvmgen_default_ethosu_main_0"] + ) + expected_axis1 = expected_mod_axis1() + tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) + + mod_axis2 = tvm.IRModule() + mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2, 5) + mod_axis2["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.SplitRewriter(), mod_axis2["tvmgen_default_ethosu_main_0"] + ) + expected_axis2 = expected_mod_axis2() + tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) + + +def infer_type_function_pass(func): + mod = tvm.IRModule() + mod["test"] = func + mod = relay.transform.InferType()(mod) + return mod["test"] + + +def get_shape_expr(in_expr, out_expr): + main_f = relay.Function([in_expr], out_expr) + main_f = infer_type_function_pass(main_f) + shape = [int(i) for i in main_f.body.checked_type.shape] + return shape + + +INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = { + "HWIO": [1, 2, 3, 0], + "HWOI": [1, 2, 0, 3], + "OWHI": [0, 1, 2, 3], +} + + +def test_ethosu_conv2d_legalize(): + def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32) + c1_params.strides = (1, 1) + c1_params.pad = "VALID" + c1_params.activation = "CLIP" + c1_params.clip_min = 23 + c1_params.clip_max = 180 + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + f = relay.Function([input0], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c1_params] + + def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) + c1_params.strides = (2, 2) + c1_params.pad = "VALID" + c1_params.activation = "CLIP" + c1_params.clip_min = 10 + c1_params.clip_max = 240 + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c2_params.ifm.shape = c1_params.ofm.shape + c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) + c2_params.strides = (1, 1) + c2_params.pad = "SAME" + c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) + c2_params.ofm.shape = get_shape_expr(input0, c2) + + f = relay.Function([input0], c2) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c2_params, c1_params] + + def verify_tensor(tensor_type, expr): + assert list(tensor_type.shape) == list(expr.checked_type.shape) + assert str(tensor_type.dtype) == str(expr.checked_type.dtype) + + def verify_linear(ext_func, conv2d_params): + op = ext_func.body + for param in conv2d_params: + verify_tensor(param.ifm, op.args[0]) + verify_tensor(param.ofm, op) + + # This will be in OHWI layout + weights_ohwi = op.args[1].data.asnumpy() + weights_layout = str(param.kernel.layout) + weights = np.transpose(weights_ohwi, INVERSE_LAYOUT_TRANSFORM_OHWI_MAP[weights_layout]) + assert weights.shape == param.kernel.shape + assert weights.dtype == param.kernel.dtype + + assert list(op.args[2].checked_type.shape)[0] == weights_ohwi.shape[0] + + assert float(op.attrs.ifm_scale) == float(param.ifm.sc.data.asnumpy()) + assert int(op.attrs.ifm_zero_point) == int(param.ifm.zp.data.asnumpy()) + assert int(op.attrs.weight_zero_point) == int(param.kernel.zp.data.asnumpy()) + assert float(op.attrs.ofm_scale) == float(param.ofm.sc.data.asnumpy()) + assert int(op.attrs.ofm_zero_point) == int(param.ofm.zp.data.asnumpy()) + assert int(op.attrs.ofm_channels) == int(weights_ohwi.shape[0]) + assert list(op.attrs.padding) == list(param.pad) + assert list(op.attrs.strides) == list(param.strides) + assert list(op.attrs.dilation) == list(param.dilation) + assert str(op.attrs.activation) == str(param.activation) + assert int(op.attrs.clip_min) == int(param.clip_min) + assert int(op.attrs.clip_max) == int(param.clip_max) + op = op.args[0] + + test_cases = [ + (create_graph_single, ["input", (1, 299, 299, 3), "uint8"]), + (create_graph_double, ["input", (1, 128, 256, 4), "uint8"]), + ] + for test_case in test_cases: + mod, conv_params = test_case[0](*test_case[1]) + mod = ethosu.partition_for_ethosu(mod) + mod["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) + + +def test_ethosu_conv2d_legalize_errors(): + def create_graph_single_unsupported_ifm_layout( + input_tensor_name, input_tensor_shape, input_tensor_dtype + ): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.ifm.layout = "NCHW" + c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[1], 32) + c1_params.strides = (1, 1) + c1_params.pad = "VALID" + c1_params.activation = "CLIP" + c1_params.clip_min = 23 + c1_params.clip_max = 180 + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + f = relay.Function([input0], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c1_params] + + test_cases = [ + (create_graph_single_unsupported_ifm_layout, ["input", (1, 3, 299, 299), "uint8"]), + ] + + for test_case in test_cases: + mod, conv_params = test_case[0](*test_case[1]) + mod = ethosu.partition_for_ethosu(mod) + try: + mod["tvmgen_default_ethosu_main_0"] = rewrite( + legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + except Exception as e: + assert "EthosUCodegenError: Unsupported Layout NCHW" in e.args[0] diff --git a/tests/python/contrib/test_ethosu/test_preprocess.py b/tests/python/contrib/test_ethosu/test_preprocess.py new file mode 100644 index 000000000000..d22d1b7f98b4 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_preprocess.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument + +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu import preprocess + + +def set_func_attr(func, compile_name, symbol_name): + """ + Helper function to attach attributes to the external function. + """ + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", compile_name) + func = func.with_attr("global_symbol", symbol_name) + return func + + +def test_single_io(): + """ + This test will test the pass wont touch external functions that + have a single input and a single output. + """ + + def create_graph(): + def create_external_func1(mod_, compiler_name, symbol_name): + x_int = relay.var("x_int", shape=(10, 10)) + z0 = relay.nn.relu(x_int) + f1 = relay.Function([x_int], z0) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + x = relay.var("x", shape=(10, 10)) + + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + r = relay.Call(glb_symbol_f1, [x]) + main = relay.Function([x], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + mod = create_graph() + exp = create_graph() + mod = preprocess.preprocess_ext_io()(mod) + assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + + +def test_2ins_single_out(): + """ + The test is check two inputs and a single output of external function + """ + + def create_graph(): + def create_external_func1(mod_, compiler_name, symbol_name): + x_int = relay.var("x_int", shape=(10, 10)) + w0_int = relay.var("w0_int", shape=(10, 10)) + z0 = relay.add(x_int, w0_int) + + f1 = relay.Function([x_int, w0_int], z0) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + r = relay.Call(glb_symbol_f1, [x, w0]) + main = relay.Function([x, w0], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + def expected(): + def create_external_func1(mod_, compiler_name, symbol_name): + ifms_int = relay.var("ifms_int", shape=[200]) + + # splits + (x_int_flat, w0_int_flat) = relay.split(ifms_int, [100]) + # reshapes + x_int = relay.reshape(x_int_flat, newshape=(10, 10)) + w0_int = relay.reshape(w0_int_flat, newshape=(10, 10)) + + z0 = relay.add(x_int, w0_int) + f1 = relay.Function([ifms_int], z0) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + + # reshapes + x_reshaped = relay.reshape(x, newshape=100) + w0_reshaped = relay.reshape(w0, newshape=100) + + # concat + ifms = relay.concatenate((x_reshaped, w0_reshaped), 0) + + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + r = relay.Call(glb_symbol_f1, [ifms]) + main = relay.Function([x, w0], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + mod = create_graph() + exp = expected() + mod = preprocess.preprocess_ext_io()(mod) + assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + + +def test_single_in_2outs(): + """ + The test is to check a single input and two outputs of external function + """ + + def create_graph(): + def create_external_func1(mod_, compiler_name, symbol_name): + x_int = relay.var("x_int", shape=(10, 10)) + + p0 = relay.nn.relu(x_int) + q0 = relay.tanh(x_int) + f1_o_tuple = relay.Tuple([p0, q0]) + + f1 = relay.Function([x_int], f1_o_tuple) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + x = relay.var("x", shape=(10, 10)) + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + pq_tuple = relay.Call(glb_symbol_f1, [x]) + p0 = relay.TupleGetItem(pq_tuple, 0) + q0 = relay.TupleGetItem(pq_tuple, 1) + r = relay.concatenate((p0, q0), axis=0) + main = relay.Function([x], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + def expected(): + def create_external_func1(mod_, compiler_name, symbol_name): + x_int = relay.var("x_int", shape=(10, 10)) + + p0 = relay.nn.relu(x_int) + q0 = relay.tanh(x_int) + + # reshapes + p0_reshaped = relay.reshape(p0, newshape=100) + q0_reshaped = relay.reshape(q0, newshape=100) + ofms = relay.concatenate((p0_reshaped, q0_reshaped), 0) + + f1 = relay.Function([x_int], ofms) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + x = relay.var("x", shape=(10, 10)) + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + ofms = relay.Call(glb_symbol_f1, [x]) + + # splits + (p0_flat, q0_flat) = relay.split(ofms, [100]) + # reshapes + p0_flat_reshaped = relay.reshape(p0_flat, newshape=(10, 10)) + q0_flat_reshaped = relay.reshape(q0_flat, newshape=(10, 10)) + # original output + tuple_out = relay.Tuple([p0_flat_reshaped, q0_flat_reshaped]) + + p0 = relay.TupleGetItem(tuple_out, 0) + q0 = relay.TupleGetItem(tuple_out, 1) + r = relay.concatenate((p0, q0), axis=0) + main = relay.Function([x], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + mod = create_graph() + exp = expected() + mod = relay.transform.InferType()(mod) + mod = preprocess.preprocess_ext_io()(mod) + assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + + +def test_4ins_2outs(): + """ + The test is to check a 4 inputs and two outputs of external function. + This just stand as a general test for multiple ins/outs. + """ + + def create_graph(): + def create_external_func1(mod_, compiler_name, symbol_name): + x_int = relay.var("x_int", shape=(10, 10)) + w0_int = relay.var("w0_int", shape=(10, 10)) + w1_int = relay.var("w1_int", shape=(10, 10)) + w2_int = relay.var("w2_int", shape=(10, 10)) + + z0 = relay.add(x_int, w0_int) + p0 = relay.subtract(z0, w1_int) + q0 = relay.multiply(z0, w2_int) + f1_o_tuple = relay.Tuple([p0, q0]) + + f1 = relay.Function([x_int, w0_int, w1_int, w2_int], f1_o_tuple) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + w2 = relay.var("w2", shape=(10, 10)) + + glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0") + pq_tuple = relay.Call(glb_symbol_f1, [x, w0, w1, w2]) + + p0 = relay.TupleGetItem(pq_tuple, 0) + q0 = relay.TupleGetItem(pq_tuple, 1) + r = relay.concatenate((p0, q0), axis=0) + main = relay.Function([x, w0, w1, w2], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + def expected(): + def create_external_func1(mod_, compiler_name, symbol_name): + ifms_int = relay.var("ifms_int", shape=[400]) + + # splits + (x_int_flat, w0_int_flat, w1_int_flat, w2_int_flat) = relay.split( + ifms_int, [100, 200, 300] + ) + # reshapes + x_int = relay.reshape(x_int_flat, newshape=(10, 10)) + w0_int = relay.reshape(w0_int_flat, newshape=(10, 10)) + w1_int = relay.reshape(w1_int_flat, newshape=(10, 10)) + w2_int = relay.reshape(w2_int_flat, newshape=(10, 10)) + + z0 = relay.add(x_int, w0_int) + p0 = relay.subtract(z0, w1_int) + q0 = relay.multiply(z0, w2_int) + # f1_o_tuple = relay.Tuple([p0, q0]) + + # reshapes + p0_reshaped = relay.reshape(p0, newshape=100) + q0_reshaped = relay.reshape(q0, newshape=100) + ofms = relay.concatenate((p0_reshaped, q0_reshaped), 0) + + f1 = relay.Function([ifms_int], ofms) + f1 = set_func_attr(f1, compiler_name, symbol_name) + glb_f1 = relay.GlobalVar(symbol_name) + mod_[glb_f1] = f1 + mod_ = relay.transform.InferType()(mod_) + return glb_f1, mod_ + + mod = tvm.IRModule() + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + w2 = relay.var("w2", shape=(10, 10)) + + # reshapes + x_reshaped = relay.reshape(x, newshape=100) + w0_reshaped = relay.reshape(w0, newshape=100) + w1_reshaped = relay.reshape(w1, newshape=100) + w2_reshaped = relay.reshape(w2, newshape=100) + + # concat + ifms = relay.concatenate((x_reshaped, w0_reshaped, w1_reshaped, w2_reshaped), 0) + + # call + glb_func, mod = create_external_func1(mod, "ethosu", "ethosu_0") + ofms = relay.Call(glb_func, [ifms]) + + # splits + (p0_flat, q0_flat) = relay.split(ofms, [100]) + # reshapes + p0_flat_reshaped = relay.reshape(p0_flat, newshape=(10, 10)) + q0_flat_reshaped = relay.reshape(q0_flat, newshape=(10, 10)) + # original output + tuple_out = relay.Tuple([p0_flat_reshaped, q0_flat_reshaped]) + + p0 = relay.TupleGetItem(tuple_out, 0) + q0 = relay.TupleGetItem(tuple_out, 1) + + r = relay.concatenate((p0, q0), axis=0) + main = relay.Function([x, w0, w1, w2], r) + mod["main"] = main + mod = relay.transform.InferType()(mod) + return mod + + mod = create_graph() + exp = expected() + mod = preprocess.preprocess_ext_io()(mod) + assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + + +if __name__ == "__main__": + test_2ins_single_out() + test_single_io() + test_4ins_2outs() + test_single_in_2outs() diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py new file mode 100644 index 000000000000..ad3363c7691f --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np +from ethosu.vela import api as vapi +from unittest.mock import patch + +import tvm +from tvm import tir +from tvm.script import ty +from tvm.relay.backend.contrib.ethosu import vela_api + +ACCEL_TYPES = [ + vapi.NpuAccelerator.Ethos_U55_256, + vapi.NpuAccelerator.Ethos_U55_128, + vapi.NpuAccelerator.Ethos_U55_64, + vapi.NpuAccelerator.Ethos_U55_32, +] + + +"""Test case 1""" + + +@tvm.script.tir +class Module1: + def main( + placeholder: ty.handle, + placeholder_1: ty.handle, + placeholder_2: ty.handle, + ethosu_conv2d: ty.handle, + ) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = tir.match_buffer( + placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + placeholder_4 = tir.match_buffer( + placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + placeholder_5 = tir.match_buffer( + placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 + ) + ethosu_conv2d_1 = tir.match_buffer( + ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + # body + tir.evaluate( + tir.call_extern( + "ethosu_conv2d", + "uint8", + 8, + 8, + 3, + 8, + 0, + 8, + tir.load("uint8", placeholder_3.data, 0), + 0, + 0, + 0, + tir.float32(0.5), + 10, + "NHWC", + 24, + 3, + 1, + "uint8", + 8, + 8, + 16, + 8, + 0, + 8, + tir.load("uint8", ethosu_conv2d_1.data, 0), + 0, + 0, + 0, + tir.float32(0.25), + 14, + "NHWC", + 128, + 16, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + tir.load("uint8", placeholder_4.data, 0), + 0, + 12, + tir.load("uint8", placeholder_5.data, 0), + 0, + 0, + 0, + 0, + 0, + "CLIP", + 0, + 0, + "NONE", + dtype="uint8", + ) + ) + + __tvm_meta__ = None + + +"""Test case 2 with per-channel quantization""" + + +@tvm.script.tir +class Module2: + def main( + placeholder: ty.handle, + placeholder_1: ty.handle, + placeholder_2: ty.handle, + placeholder_6: ty.handle, + ethosu_conv2d: ty.handle, + ) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = tir.match_buffer( + placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + placeholder_4 = tir.match_buffer( + placeholder_1, [16, 1, 1, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + placeholder_5 = tir.match_buffer( + placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 + ) + # Per-channel weight scales + placeholder_7 = tir.match_buffer( + placeholder_6, [16], dtype="float32", elem_offset=0, align=128, offset_factor=1 + ) + ethosu_conv2d_1 = tir.match_buffer( + ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ) + # body + tir.evaluate( + tir.call_extern( + "ethosu_conv2d", + "uint8", + 8, + 8, + 3, + 8, + 0, + 8, + tir.load("uint8", placeholder_3.data, 0), + 0, + 0, + 0, + tir.float32(0.5), + 10, + "NHWC", + 24, + 3, + 1, + "uint8", + 8, + 8, + 16, + 8, + 0, + 8, + tir.load("uint8", ethosu_conv2d_1.data, 0), + 0, + 0, + 0, + tir.float32(0.25), + 14, + "NHWC", + 128, + 16, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + tir.load("uint8", placeholder_4.data, 0), + 0, + 12, + tir.load("uint8", placeholder_5.data, 0), + 0, + 0, + 0, + 0, + 0, + "CLIP", + 0, + 0, + "NONE", + dtype="uint8", + ) + ) + + __tvm_meta__ = None + + +def test_get_optimal_block_config(): + block_configs_cases = [ + { + "test": [ + vapi.NpuShape3D(10, 20, 8), + vapi.NpuShape3D(10, 30, 16), + vapi.NpuShape3D(10, 40, 32), + ], + "ref": vapi.NpuShape3D(10, 40, 32), + }, + { + "test": [ + vapi.NpuShape3D(10, 20, 8), + vapi.NpuShape3D(10, 50, 32), + vapi.NpuShape3D(10, 40, 32), + ], + "ref": vapi.NpuShape3D(10, 50, 32), + }, + { + "test": [ + vapi.NpuShape3D(50, 50, 8), + vapi.NpuShape3D(10, 30, 32), + vapi.NpuShape3D(8, 8, 64), + ], + "ref": vapi.NpuShape3D(8, 8, 64), + }, + ] + + for test_case in block_configs_cases: + assert vela_api._get_optimal_block_config(test_case["test"]) == test_case["ref"] + + +def test_compress_weights(): + test_vecs = [ + { + # Stimulus + "accel": vapi.NpuAccelerator.Ethos_U55_256, + "block_depth": 8, + "ifm_dtype": np.uint8, + "shape": (3, 3, 16, 64), + "layout": "HWIO", + "zero_point": np.int64(134), + "dilation": (1, 1), + "is_depthwise": False, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.PART_KERNEL_FIRST, + }, + { + # Stimulus + "accel": vapi.NpuAccelerator.Ethos_U55_256, + "block_depth": 8, + "ifm_dtype": np.uint8, + "shape": (3, 3, 32, 64), + "layout": "HWIO", + "zero_point": np.int64(134), + "dilation": (1, 1), + "is_depthwise": False, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.DEPTH_FIRST, + }, + { + # Stimulus + "accel": vapi.NpuAccelerator.Ethos_U55_256, + "block_depth": 8, + "ifm_dtype": np.int16, + "shape": (3, 3, 16, 64), + "layout": "HWIO", + "zero_point": np.int64(134), + "dilation": (1, 1), + "is_depthwise": False, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.DEPTH_FIRST, + }, + # Pass-through value check + { + # Stimulus + "accel": vapi.NpuAccelerator.Ethos_U55_128, + "block_depth": 16, + "ifm_dtype": np.uint8, + "shape": (243, 152, 7, 1), + "layout": "HWOI", + "zero_point": np.int64(110), + "dilation": (2, 2), + "is_depthwise": True, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.DEPTH_FIRST, + }, + { + # Stimulus + "accel": vapi.NpuAccelerator.Ethos_U55_128, + "block_depth": 32, + "ifm_dtype": np.uint8, + "shape": (64, 67, 35, 8), + "layout": "OHWI", + "zero_point": np.int64(100), + "dilation": (1, 2), + "is_depthwise": False, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.PART_KERNEL_FIRST, + }, + ] + + def verify(test_vec, mock_obj): + layout_transform_indices = { + "HWIO": (3, 0, 1, 2), + "HWOI": (2, 0, 1, 3), + "OHWI": (0, 1, 2, 3), + } + + assert mock_obj + mock_obj.assert_called_once() + assert mock_obj.call_args[1]["accelerator"] == test_vec["accel"] + assert mock_obj.call_args[1]["accelerator"] == test_vec["accel"] + ishape = test_vec["shape"] + shape_owhi = ( + ishape[layout_transform_indices[test_vec["layout"]][0]], + ishape[layout_transform_indices[test_vec["layout"]][1]], + ishape[layout_transform_indices[test_vec["layout"]][2]], + ishape[layout_transform_indices[test_vec["layout"]][3]], + ) + assert mock_obj.call_args[1]["weights_volume"].shape == shape_owhi + assert mock_obj.call_args[1]["dilation_xy"] == test_vec["dilation"] + assert mock_obj.call_args[1]["ifm_bitdepth"] == np.iinfo(test_vec["ifm_dtype"]).bits + assert mock_obj.call_args[1]["ofm_block_depth"] == test_vec["block_depth"] + assert mock_obj.call_args[1]["is_depthwise"] == test_vec["is_depthwise"] + assert mock_obj.call_args[1]["block_traversal"] == test_vec["block_traversal"] + + def create_mock(test_vec): + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights" + ) as mock_npu_encode_weights: + ifm_bitdepth = np.iinfo(test_vec["ifm_dtype"]).bits + ifm_dtype = test_vec["ifm_dtype"] + max = np.iinfo(ifm_dtype).max + min = np.iinfo(ifm_dtype).min + values = np.random.randint(min, max, test_vec["shape"], ifm_dtype) + compressed_weights = vela_api.compress_weights( + weights=values, + weights_zp=test_vec["zero_point"], + weights_layout=test_vec["layout"], + ifm_bitdepth=ifm_bitdepth, + block_depth=test_vec["block_depth"], + dilation=test_vec["dilation"], + accel_type=test_vec["accel"], + is_depthwise=test_vec["is_depthwise"], + ) + return mock_npu_encode_weights + return None + + for tv in test_vecs: + mock_obj = create_mock(tv) + verify(tv, mock_obj) + + +def test_pack_biases(): + test_vecs = [ + { + # Stimulus + "bias_length": 3, + "ifm_scale": np.single(1.11111111), + "ifm_dtype": np.uint8, + "weight_scales": np.array( + [np.single(0.91111111), np.single(1.01111111), np.single(1.11111111)] + ), + "ofm_scale": np.single(1.2), + "is_activation_tanh_or_sigmoid": False, + # Reference outputs + "hw_scales": (1811663288, 2010504240, 1104672703), + "hw_shifts": (31, 31, 30), + }, + { + # Stimulus + "bias_length": 3, + "ifm_scale": np.single(1.11111111), + "ifm_dtype": np.int8, + "weight_scales": np.array( + [np.single(0.91111111), np.single(1.01111111), np.single(1.11111111)] + ), + "ofm_scale": np.single(1.2), + "is_activation_tanh_or_sigmoid": False, + # Reference outputs + "hw_scales": (1811663185, 2010504312, 1104672720), + "hw_shifts": (31, 31, 30), + }, + { + # Stimulus + "bias_length": 3, + "ifm_scale": np.single(1.11111111), + "ifm_dtype": np.int16, + "weight_scales": np.array( + [np.single(0.91111111), np.single(1.01111111), np.single(1.11111111)] + ), + "ofm_scale": np.single(1.2), + "is_activation_tanh_or_sigmoid": False, + # Reference outputs + "hw_scales": (27644, 30678, 16856), + "hw_shifts": (15, 15, 14), + }, + ] + + def verify(test_vec, mock_obj, packed_biases): + assert mock_obj + for idx, val in enumerate(test_vec["bias_values"]): + assert val == mock_obj.call_args_list[idx][0][0] + assert test_vec["hw_scales"][idx] == mock_obj.call_args_list[idx][0][1] + assert test_vec["hw_shifts"][idx] == mock_obj.call_args_list[idx][0][2] + + def create_mock(test_vec): + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_bias" + ) as mock_npu_encode_bias: + mock_npu_encode_bias.return_value = bytearray(10) + ifm_dtype = test_vec["ifm_dtype"] + max = np.iinfo(ifm_dtype).max + min = np.iinfo(ifm_dtype).min + # tvm will always create biases in int32 + biases = np.random.randint(min, max, test_vec["bias_length"], np.int32) + packed_biases = vela_api.pack_biases( + biases=biases, + ifm_scale=test_vec["ifm_scale"], + ifm_dtype=test_vec["ifm_dtype"], + weight_scales=test_vec["weight_scales"], + ofm_scale=test_vec["ofm_scale"], + is_activation_tanh_or_sigmoid=test_vec["is_activation_tanh_or_sigmoid"], + ) + test_vec["bias_values"] = biases + return mock_npu_encode_bias, packed_biases + return None + + for _test_vec in test_vecs: + mock_obj, packed_biases = create_mock(_test_vec) + verify(_test_vec, mock_obj, packed_biases) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 4c9d50ec90bb..ace88b46ae0c 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -48,3 +48,4 @@ echo set\(USE_VITIS_AI ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake +echo set\(USE_ETHOSU ON\) >> config.cmake From 9a4188b207b44d7c72704efed3d08d6ed1aa43aa Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 20 Aug 2021 11:19:06 +0100 Subject: [PATCH 02/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * skipping the test if vela is not in the container. Change-Id: I68cc4259dc33e1473e460956978f364fbf6596d8 --- tests/python/contrib/test_ethosu/test_legalize.py | 2 ++ tests/python/contrib/test_ethosu/test_preprocess.py | 3 +++ tests/python/contrib/test_ethosu/test_vela_api.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index f4c863ea4a3c..358fa0cc7055 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name, unused-argument import pytest + +pytest.importorskip("ethosu.vela") import numpy as np import tvm diff --git a/tests/python/contrib/test_ethosu/test_preprocess.py b/tests/python/contrib/test_ethosu/test_preprocess.py index d22d1b7f98b4..2aeeed078942 100644 --- a/tests/python/contrib/test_ethosu/test_preprocess.py +++ b/tests/python/contrib/test_ethosu/test_preprocess.py @@ -16,6 +16,9 @@ # under the License. # pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") import numpy as np import tvm diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index ad3363c7691f..d9b22d10e9e2 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import pytest + +pytest.importorskip("ethosu.vela") import numpy as np from ethosu.vela import api as vapi from unittest.mock import patch From e059c89048a4a4d91a1a0be6cdeb0e6cc1b54392 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 24 Aug 2021 18:49:03 +0100 Subject: [PATCH 03/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * addressing Jared's comments Change-Id: Ief669f788c6bd1a1be1004cbce5129ed06b63c3c --- python/tvm/relay/backend/contrib/__init__.py | 2 +- .../relay/backend/contrib/ethosu/__init__.py | 2 +- .../relay/backend/contrib/ethosu/errors.py | 6 +- .../relay/backend/contrib/ethosu/legalize.py | 66 ++++++++------- .../backend/contrib/ethosu/preprocess.py | 14 +++- .../tvm/relay/backend/contrib/ethosu/util.py | 83 ++++++++----------- .../relay/backend/contrib/ethosu/vela_api.py | 23 +++-- python/tvm/relay/op/contrib/ethosu.py | 42 +++++----- .../backend/contrib/ethosu/preprocess.cc | 7 +- .../contrib/test_ethosu/test_legalize.py | 24 ++---- 10 files changed, 128 insertions(+), 141 deletions(-) diff --git a/python/tvm/relay/backend/contrib/__init__.py b/python/tvm/relay/backend/contrib/__init__.py index 9074e40af08b..5d2a3979b9e1 100644 --- a/python/tvm/relay/backend/contrib/__init__.py +++ b/python/tvm/relay/backend/contrib/__init__.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""external backend codegen modules for relay.""" +"""External backend codegen modules for Relay.""" from . import cmsisnn from . import ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index 3f315a74cbaa..f5c595462e73 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Arm(R) Ethos(TM)-U NPU codegen modules for relay.""" +"""Arm(R) Ethos(TM)-U NPU codegen modules for Relay.""" from . import util from . import legalize from . import preprocess diff --git a/python/tvm/relay/backend/contrib/ethosu/errors.py b/python/tvm/relay/backend/contrib/ethosu/errors.py index 8625ddc880b7..435c9c8337ef 100644 --- a/python/tvm/relay/backend/contrib/ethosu/errors.py +++ b/python/tvm/relay/backend/contrib/ethosu/errors.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=super-init-not-called -"""This module is to hold all type of errors associated Arm(R) Ethos(TM)-U NPU Codegen""" +"""This module defines all error types associated with the Arm(R) Ethos(TM)-U NPU code generator.""" class EthosUCodegenError(Exception): - """Base class for all exceptions related to Codegen""" + """Base class for all exceptions related to code generation""" def __init__(self, data): self.message = "EthosUCodegenError:" + data @@ -29,7 +29,7 @@ def __str__(self): class UnsupportedLayout(EthosUCodegenError): - """Raised when unsupported layout is encountered in the codegen""" + """Raised when unsupported layout is encountered during code generation.""" def __init__(self, layout): super().__init__(f"Unsupported Layout {layout}") diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 2e54ffb25fc6..6aa36dde52a2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, import-outside-toplevel +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """ A set of passes to legalize some of operations for the NPU""" import numpy as np import tvm from tvm import relay +from tvm import ir from tvm.relay.dataflow_pattern import DFPatternCallback from tvm.relay.dataflow_pattern import wildcard from tvm.relay.dataflow_pattern import is_op @@ -31,11 +32,10 @@ class SplitRewriter(DFPatternCallback): - """Convert split operations to bunch of strided_slice operations, - because codegen is going to be based on strided_slices that are - close to in/out boxes of Vela High-Level Command Stream (HLCS). - Moreover, Vela HLCS is a high-level description of the supported - hardware operator. + """This rewriting converts split operations into a sequence of + strided_slice operations, because codegen is going to be based + on strided_slices that will define the slice of the tensor that + will be fed to the consumer. """ def __init__(self): @@ -45,14 +45,16 @@ def __init__(self): @staticmethod def get_section_begin_coords(split): - """Currently, the split can take an array of indices or an integer - indicating the number of splits. This helper functions unifies - this by making it a array of section begins. + """Currently, the split operator takes an array of indices or an integer + indicating the number of splits. However, its an array of indices could + represent both cases, therefore this function just make it an array of + indices where each index represent the co-ordinate of beginning of each + section -- defines as section begins. Parameters ---------- split : relay.Expr - The relay expression for split operator + The Relay Call expression for a split operator Returns ------- @@ -106,8 +108,16 @@ def callback(self, pre, post, node_map): return relay.Tuple(strided_slices) +@ir.transform.module_pass(opt_level=1) +def SplitRewriterPass(mod, ctx): + for gv, func in mod.functions.items(): + func = rewrite(SplitRewriter(), func) + mod.update_func(gv, func) + return mod + + class EthosUConv2DRewriter(DFPatternCallback): - """Convert conv2d related composite functions to ethosu_conv2d operators""" + """Convert conv2d related composite functions into ethosu_conv2d operators""" def __init__(self): super().__init__(require_type=True) @@ -175,26 +185,22 @@ def callback(self, pre, post, node_map): return ethosu_conv2d +@ir.transform.module_pass(opt_level=1) +def EthosUConv2DRewriterPass(mod, ctx): + for gv, func in mod.functions.items(): + func = rewrite(EthosUConv2DRewriter(), func) + mod.update_func(gv, func) + return mod + + +@relay.transform.function_pass(opt_level=1) class LegalizeEthosU: - """This is the wrapper class to call graph-rewrites to perform graph transformation - in a way such that the operations are replaced with hardware/codegen backend friendly + """This is the pass to call graph-rewrites to perform graph transformation + in a way such that the operations are replaced with hardware/codegen supported operations. """ - def __call__(self, func): - """The list of relay re-write passes need to be run to legalize - the external function for to be codegen'd. - - Parameters - ---------- - func : relay.function.Function - The external function - - Returns - ------- - func : relay.function.Function - The legalized external function - """ - func = rewrite(SplitRewriter(), func) - func = rewrite(EthosUConv2DRewriter(), func) - return func + def transform_function(self, func, mod): + mod = SplitRewriterPass(mod) + mod = EthosUConv2DRewriterPass(mod) + return mod diff --git a/python/tvm/relay/backend/contrib/ethosu/preprocess.py b/python/tvm/relay/backend/contrib/ethosu/preprocess.py index 77035b5b0826..f2bd079c99c3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/preprocess.py +++ b/python/tvm/relay/backend/contrib/ethosu/preprocess.py @@ -15,13 +15,21 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel -"""Set of passes to pre-process the IRModule prior to codegen""" +"""Set of passes to pre-process the IRModule to support Arm(R)-Ethos(TM)-U +NPU code generation. These set of passes will mutate both the main and the +external functions. +""" from . import _ffi_api def preprocess_ext_io(): - """This function make the number of inputs going to / outputs coming out to/from - external function set to one. This is achieved via concatenation + """This pass mutates the number of inputs going to / outputs coming out to/from + external functions to one. This is achieved via concatenation of inputs and splitting of outputs in around the call to the external function. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to mutate the IO of the external functions and their calls. """ return _ffi_api.PreprocessExternalFuncIO() diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 45b3c6731809..85ca86cdfc24 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -15,11 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -Helper utility Enums and Functions used through out codegen - -The enums are there to indicate which argument of each relay operator -corresponds with which input. -e.g., input zero point of qnn.conv2d is 4th argument(3rd index) +Helper utility Enums and Functions used through out code generation. The rest of the utility functions are misc. Refer to the description inside such functions @@ -35,78 +31,67 @@ class QConv2DArgs(Enum): """ - This is a helper enums to access the correct index - qnn conv2d arguments + This is a helper enum to obtain the correct index + of qnn.conv2d arguments. """ - ifm = 0 - weights = 1 - ifm_zero_point = 2 - weights_zero_point = 3 - ifm_scale = 4 - weights_scale = 5 + IFM = 0 + WEIGHTS = 1 + IFM_ZERO_POINT = 2 + WEIGHTS_ZERO_POINT = 3 + IFM_SCALE = 4 + WEIGHTS_SCALE = 5 class RequantArgs(Enum): """ - This is a helper enums to access the correct index - qnn requantize arguments + This is a helper enum to obtain the correct index + of qnn.requantize arguments. """ - ifm_scale = 1 - ifm_zero_point = 2 - ofm_scale = 3 - ofm_zero_point = 4 + IFM_SCALE = 1 + IFM_ZERO_POINT = 2 + OFM_SCALE = 3 + OFM_ZERO_POINT = 4 class BiasAddArgs(Enum): """ - This is a helper enums to access the correct index - qnn bias_add arguments + This is a helper enums to obtain the correct index + of qnn.bias_add arguments. """ - biases = 1 + BIASES = 1 class ClipArgs(Enum): """ - This is a helper enums to access the correct index - qnn bias_add arguments + This is a helper enums to obtain the correct index + of clip arguments. """ - a_min = 1 - a_max = 2 + A_MIN = 1 + A_MAX = 2 -class MaxPoolArgs(Enum): - """ - This is a helper enums to access the correct index - max pool arguments +def is_composite_func(func, name): """ + This method checks whether the call is to + a composite function of a given name. - ifm = 0 + Parameters + ---------- + func : relay.Function + The header to be displayed along with the dump. + name : str + The candidate name to be checked -class AddArgs(Enum): - """This is a helper enums to access the correct index - max pool arguments + Returns + -------- + a boolean """ - ifm0 = 0 - ifm1 = 1 - ifm0_scale = 2 - ifm0_zero_point = 3 - ifm1_scale = 4 - ifm1_zero_point = 5 - ofm_scale = 6 - ofm_zero_point = 7 - - -def is_composite_func(func, name): - """ - This a method to check whether the call is to - a composite function of the "name". - """ if not hasattr(func, "attrs"): return False if "Composite" not in func.attrs.keys(): diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 3e772a953c16..129e6d81ae56 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -conversions between TVM and Vela. Therefore, all interactions with the -Vela API are supposed to go through this adapter, with the hope that -any changes to Vela API, TVM only needs to change this file. +This is an adapter module for conversions between TVM and Vela. The following conversion APIs are added : *Obtaining the best block config *Compressing weights @@ -119,7 +117,9 @@ def compress_weights( accel_type, is_depthwise=False, ): - """Obtain compressed weights from vela + """The NPU requires the weights to be compressed + to be executed. Therefore, this function calls into + the Vela APIs to compress the weights. Parameters ---------- @@ -201,8 +201,12 @@ def pack_biases( is_activation_tanh_or_sigmoid=False, ): """ - Obtain packed bias bytearray as the hardware requires from - Vela. + The NPU requires the each bias value to be packed with + output scale parameters in a 80-bit format (that is returned + via npu_encode_bias API). This function will pack such values + to a binary artifact that the NPU will use in the execution. + + Parameters ---------- biases : numpy.ndarray @@ -239,10 +243,6 @@ def pack_biases( packed_biases = bytearray() for idx, scale in enumerate(hw_bias_scales): packed_biases.extend(vapi.npu_encode_bias(biases[idx], *scale)) - # Align to 16 - # remainder = (len(packed_biases)) % 16 - # if remainder > 0: - # packed_biases.extend(bytearray(16 - remainder)) scale_bias = np.frombuffer(packed_biases, dtype=np.uint8) scale_bias = np.reshape(scale_bias, (-1, 10)) return scale_bias @@ -301,8 +301,7 @@ def _calculate_hw_bias_scales( def get_target_accel_type(): - """This is a helper function to convert cli accelerator type str argument - to NpuAccelerator""" + """This is a helper function to convert TVMC command line argument to NpuAccelerator type""" npu_accel_str_map = { "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index f7fee928e90a..8d8d68e4b1e7 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -28,7 +28,7 @@ def check_strides(strides): - """Checks whether strides are within the limits supported by the hardware""" + """This function checks whether strides are within the limits supported by the NPU""" stride_range = (1, 3) smin, smax = stride_range if not smax >= strides[0] >= smin: @@ -39,7 +39,7 @@ def check_strides(strides): def check_valid_dtypes(tensor_params): - """Check whether dtypes are supported by the hardware""" + """This function checks whether dtypes are supported by the NPU""" supported_dtypes = (np.uint8, np.int8) for tep in tensor_params: # Check for dtypes @@ -52,7 +52,7 @@ def check_valid_dtypes(tensor_params): def check_weights(weights, dilation): - """Checks whether weight tensor is compatible with HW""" + """This function checks whether weight tensor is compatible with the NPU""" dilated_height_range = (1, 64) dilated_hxw_range = (1, 64 * 64) weights_limit = 127 * 65536 @@ -79,7 +79,7 @@ def check_weights(weights, dilation): def check_bias(bias): - """Check whether the bias values fit in 40 bits""" + """This function checks whether the bias values fit in 40 bits""" if bias and bias.dtype == np.dtype("int64"): valid = all(len(bin(bias_value)[2:]) <= 40 for bias_value in bias.values) return valid @@ -87,14 +87,14 @@ def check_bias(bias): def check_batch_size(ifm): - """Checks for the number of batches vela currently supports""" + """This function checks for the number of batches vela currently supports""" if ifm.shape[0] != 1: return False return True def check_dilation(dilation): - """Checks whether dilation is within the limits supported by the hardware""" + """This function checks whether dilation is within the limits supported by the NPU""" dilation_range = (1, 2) dmin, dmax = dilation_range if not dmin <= dilation[0] <= dmax: @@ -105,7 +105,7 @@ def check_dilation(dilation): def check_padding(padding, bounds): - """Checks whether padding is within the limits supported by the hardware""" + """This function checks whether padding is within the limits supported by the NPU""" if len(padding) != 4 or len(bounds) != 4: return False top, left, bottom, right = padding @@ -148,7 +148,7 @@ class QnnConv2DParams: """ composite_name = "ethosu.qnn_conv2d" - # The hardware only supports padding upto the numbers as follows + # The NPU only supports padding upto the numbers as follows padding_bounds = [31, 31, 32, 32] activation_map = {"clip": "CLIP"} @@ -165,29 +165,29 @@ def __init__(self, func_body): kernel_layout = qnn_conv2d.attrs.kernel_layout # We consider the weights & biases as params as it should be a Constant self.weights = TensorParams( - qnn_conv2d.args[QConv2DArgs.weights.value], + qnn_conv2d.args[QConv2DArgs.WEIGHTS.value], kernel_layout, - qnn_conv2d.args[QConv2DArgs.weights_scale.value], - qnn_conv2d.args[QConv2DArgs.weights_zero_point.value], + qnn_conv2d.args[QConv2DArgs.WEIGHTS_SCALE.value], + qnn_conv2d.args[QConv2DArgs.WEIGHTS_ZERO_POINT.value], ) self.biases = TensorParams( - bias_add.args[BiasAddArgs.biases.value], + bias_add.args[BiasAddArgs.BIASES.value], data_layout, - requantize_op.args[RequantArgs.ifm_scale.value], - requantize_op.args[RequantArgs.ifm_zero_point.value], + requantize_op.args[RequantArgs.IFM_SCALE.value], + requantize_op.args[RequantArgs.IFM_ZERO_POINT.value], ) self.ifm = TensorParams( - qnn_conv2d.args[QConv2DArgs.ifm.value], + qnn_conv2d.args[QConv2DArgs.IFM.value], data_layout, - qnn_conv2d.args[QConv2DArgs.ifm_scale.value], - qnn_conv2d.args[QConv2DArgs.ifm_zero_point.value], + qnn_conv2d.args[QConv2DArgs.IFM_SCALE.value], + qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value], ) self.ofm = TensorParams( func_body, data_layout, - requantize_op.args[RequantArgs.ofm_scale.value], - requantize_op.args[RequantArgs.ofm_zero_point.value], + requantize_op.args[RequantArgs.OFM_SCALE.value], + requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) self.padding = qnn_conv2d.attrs.padding self.strides = qnn_conv2d.attrs.strides @@ -203,7 +203,7 @@ def __init__(self, func_body): def is_valid(self): """ - Checks whether QnnConv2D with Clip has compatible attributes with HW + This function checks whether QnnConv2D has compatible attributes with the NPU """ tensor_params = [self.weights, self.ifm, self.ofm] if not check_valid_dtypes(tensor_params): @@ -231,7 +231,7 @@ def is_valid(self): def qnn_conv2d_pattern(): """ - Create pattern for qnn.conv2D with optional fused relu + This function creates the pattern for qnn.conv2D with optional fused RELU activation. """ qnn_conv2d = is_op("qnn.conv2d")( wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 9a18868a7f86..a1b9cb9d0b38 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -47,7 +47,7 @@ namespace ethosu { */ class ExternalFuncIOHandler : public ExprRewriter { public: - explicit ExternalFuncIOHandler(IRModule& module) : module_(module) {} + explicit ExternalFuncIOHandler(const IRModule& module) : module_(module) {} int count = 0; Function InferType(const Function& expr, const IRModule& m) { @@ -125,7 +125,8 @@ class ExternalFuncIOHandler : public ExprRewriter { * to make a single output. Finaly, the external function should only have a single input * and a single output. */ - Function ModifyExternalFunction(const Function& func, GlobalVar gv, const DataType& dtype) { + Function ModifyExternalFunction(const Function& func, const GlobalVar& gv, + const DataType& dtype) { Array inputs; Var ifms; if (func->params.size() > 1) { @@ -230,7 +231,7 @@ class ExternalFuncIOHandler : public ExprRewriter { IRModule module_; }; -IRModule PreprocessExternalFuncIO_(IRModule module) { +IRModule PreprocessExternalFuncIO_(const IRModule& module) { ExternalFuncIOHandler ex_func_io_handle(module); auto func = GetRef(module->Lookup("main").as()); auto preprocessed = PostOrderRewrite(func, &ex_func_io_handle); diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 358fa0cc7055..52fa64153583 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -81,17 +81,13 @@ def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tenso mod_axis1 = tvm.IRModule() mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1) - mod_axis1["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.SplitRewriter(), mod_axis1["tvmgen_default_ethosu_main_0"] - ) + mod_axis1 = legalize.SplitRewriterPass(mod_axis1) expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2) - mod_axis2["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.SplitRewriter(), mod_axis2["tvmgen_default_ethosu_main_0"] - ) + mod_axis2 = legalize.SplitRewriterPass(mod_axis2) expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) @@ -181,17 +177,13 @@ def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tenso mod_axis1 = tvm.IRModule() mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1, 5) - mod_axis1["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.SplitRewriter(), mod_axis1["tvmgen_default_ethosu_main_0"] - ) + mod_axis1 = legalize.SplitRewriterPass(mod_axis1) expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2, 5) - mod_axis2["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.SplitRewriter(), mod_axis2["tvmgen_default_ethosu_main_0"] - ) + mod_axis2 = legalize.SplitRewriterPass(mod_axis2) expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) @@ -302,9 +294,7 @@ def verify_linear(ext_func, conv2d_params): for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) - mod["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] - ) + mod = legalize.EthosUConv2DRewriterPass(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -338,8 +328,6 @@ def create_graph_single_unsupported_ifm_layout( mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) try: - mod["tvmgen_default_ethosu_main_0"] = rewrite( - legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] - ) + mod = legalize.EthosUConv2DRewriterPass(mod) except Exception as e: assert "EthosUCodegenError: Unsupported Layout NCHW" in e.args[0] From 9517161426629da50caa938273f2f1e27ba939a0 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 25 Aug 2021 10:23:29 +0100 Subject: [PATCH 04/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * addressing Elen's comments Change-Id: Iad6315bb63f12ba318deb9c5c9eff7459ff58c48 --- python/tvm/relay/op/contrib/ethosu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 8d8d68e4b1e7..d055aec3d72e 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -143,7 +143,7 @@ def __init__(self, tensor, layout=None, scale=None, zero_point=None): class QnnConv2DParams: """ - This class will parse a Call to a ethosu.qnn_conv2d_clip composite function + This class will parse a Call to a ethosu.qnn_conv2d composite function and extract quantization information of all the associated tensors. """ From 6d525019c83384b2d18ecdd2ad3ffce4f832f0ec Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 25 Aug 2021 11:57:25 +0100 Subject: [PATCH 05/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * cleanup passes Change-Id: I8e1cbedd2c4d3d0cdff481d775d9eb0577e44456 --- .../relay/backend/contrib/ethosu/legalize.py | 40 ++++++++++++------- .../contrib/test_ethosu/test_legalize.py | 12 +++--- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 6aa36dde52a2..58a7ba252c04 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -109,11 +109,17 @@ def callback(self, pre, post, node_map): @ir.transform.module_pass(opt_level=1) -def SplitRewriterPass(mod, ctx): - for gv, func in mod.functions.items(): - func = rewrite(SplitRewriter(), func) - mod.update_func(gv, func) - return mod +class LegalizeSplit: + """This is the pass that wraps SplitRewriter""" + + def transform_module(self, mod, ctx): + for gv, func in mod.functions.items(): + func = rewrite(SplitRewriter(), func) + mod.update_func(gv, func) + return mod + + def __call__(self, *args, **kwargs): + pass class EthosUConv2DRewriter(DFPatternCallback): @@ -186,21 +192,27 @@ def callback(self, pre, post, node_map): @ir.transform.module_pass(opt_level=1) -def EthosUConv2DRewriterPass(mod, ctx): - for gv, func in mod.functions.items(): - func = rewrite(EthosUConv2DRewriter(), func) - mod.update_func(gv, func) - return mod +class LegalizeEthosUConv2D: + """This is the pass that wraps the EthosUConv2DRewriter""" + def transform_module(self, mod, ctx): + for gv, func in mod.functions.items(): + func = rewrite(EthosUConv2DRewriter(), func) + mod.update_func(gv, func) + return mod -@relay.transform.function_pass(opt_level=1) + def __call__(self, *args, **kwargs): + pass + + +@ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation in a way such that the operations are replaced with hardware/codegen supported operations. """ - def transform_function(self, func, mod): - mod = SplitRewriterPass(mod) - mod = EthosUConv2DRewriterPass(mod) + def transform_module(self, mod, ctx): + mod = LegalizeSplit()(mod) + mod = LegalizeEthosUConv2D()(mod) return mod diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 52fa64153583..08143f84d4bd 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -81,13 +81,13 @@ def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tenso mod_axis1 = tvm.IRModule() mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1) - mod_axis1 = legalize.SplitRewriterPass(mod_axis1) + mod_axis1 = legalize.LegalizeSplit()(mod_axis1) expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2) - mod_axis2 = legalize.SplitRewriterPass(mod_axis2) + mod_axis2 = legalize.LegalizeSplit()(mod_axis2) expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) @@ -177,13 +177,13 @@ def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tenso mod_axis1 = tvm.IRModule() mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1, 5) - mod_axis1 = legalize.SplitRewriterPass(mod_axis1) + mod_axis1 = legalize.LegalizeSplit()(mod_axis1) expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2, 5) - mod_axis2 = legalize.SplitRewriterPass(mod_axis2) + mod_axis2 = legalize.LegalizeSplit()(mod_axis2) expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) @@ -294,7 +294,7 @@ def verify_linear(ext_func, conv2d_params): for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) - mod = legalize.EthosUConv2DRewriterPass(mod) + mod = legalize.LegalizeEthosUConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -328,6 +328,6 @@ def create_graph_single_unsupported_ifm_layout( mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) try: - mod = legalize.EthosUConv2DRewriterPass(mod) + mod = legalize.LegalizeEthosUConv2D()(mod) except Exception as e: assert "EthosUCodegenError: Unsupported Layout NCHW" in e.args[0] From 148c43aeada474c8b89963a7231ed173d990a7d6 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Sep 2021 13:17:29 +0100 Subject: [PATCH 06/14] Update TE comments Change-Id: I7e65c2714d017c8a4b64986b111a6b51d128c963 --- python/tvm/relay/backend/contrib/ethosu/te/dma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index 3f8c8d1e7eef..9b3027275460 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name,unnecessary-lambda -"""Tensor Expressions for operations supported by the DMA engine""" +"""Tensor Expressions for operations supported by the NPU DMA engine""" import tvm from tvm import te from tvm.topi.utils import equal_const_int @@ -62,7 +62,7 @@ def _pad(*indices): def read_compute(tensor, layout, zero_point, scale): - """A TE compute operator to represent a read. + """A tensor expression which represents a read. Parameters ---------- @@ -92,7 +92,7 @@ def read_compute(tensor, layout, zero_point, scale): def write_compute(tensor, layout, zero_point, scale): - """A TE compute operator to represent a write. + """A tensor expression which represents a write. Parameters ---------- From 82ee99d8e9f294ee24b46cb3e8ecc65bd39eb6a8 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Sep 2021 14:12:21 +0100 Subject: [PATCH 07/14] Address ekalda's comments in TE Change-Id: I55cfbb3787c0aacdadf46c4859dff39287e65ddc --- .../backend/contrib/ethosu/te/convolution.py | 50 ------------------- .../relay/backend/contrib/ethosu/te/dma.py | 11 ++-- 2 files changed, 4 insertions(+), 57 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index a11974025f2e..030e75d23193 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -20,56 +20,6 @@ from .dma import dma_ofm_compute, dma_ifm_compute -def process_stride(stride): - """Process the striding into a common format. - - Parameters - ---------- - stride : Union[int, tuple, list] - The 2D striding. - int -> striding is the same in the height and width axis. - 2D -> striding specified as (stride height, stride width). - - Returns - ------- - int - The stride in the height axis. - int - The stride in the width axis. - - """ - assert isinstance(stride, int) or len(stride) == 2 - if isinstance(stride, int): - return stride, stride - - return stride - - -def process_dilation(dilation): - """Process the dilation into a common format. - - Parameters - ---------- - dilation : Union[int, tuple, list] - The 2D dilation. - int -> dilation is the same in the height and width axis. - 2D -> dilation specified as (dilation height, dilation width). - - Returns - ------- - int - The dilation in the height axis. - int - The dilation in the width axis. - - """ - assert isinstance(dilation, int) or len(dilation) == 2 - if isinstance(dilation, int): - return dilation, dilation - - return dilation - - def conv2d_compute( ifm, weight, diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index 9b3027275460..25b9d4b43a7a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -44,8 +44,8 @@ def _pad_tensor(tensor, pad_before, pad_after=None): assert len(pad_after) == dims def _pad(*indices): - not_zero = [] - index_tuple = [] + not_zero = [] # A list of padding conditions that aren't trivial (zero padding) + index_tuple = [] # The indices with which to access the padded tensor for i in range(dims): if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0): index_tuple.append(indices[i]) @@ -256,11 +256,8 @@ def dma_ifm_compute(ifm, layout, zero_point, scale, channels, padding): The scale of the data. channels : int The number of valid channels for the data. - padding : Union[int, tuple, list] - The desired padding. - int -> padding applied to both height and width axes. - 2D -> padding applied equally on both sides of the (height, width) axes. - 4D -> padding applied as (top, left, bottom, right) + padding : tuple + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). Returns ------- From 982152ceee48581c5159b110847ad4a9925acd4a Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 1 Sep 2021 20:04:52 +0100 Subject: [PATCH 08/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op *addressing chris's comments *addressing Nicola's comments Change-Id: Id02788ddcdbc3679e0da37b2fa614cded0a4c1f5 --- python/tvm/relay/backend/contrib/ethosu/errors.py | 3 --- python/tvm/relay/backend/contrib/ethosu/legalize.py | 12 +++--------- python/tvm/relay/backend/contrib/ethosu/util.py | 4 ++-- python/tvm/relay/backend/contrib/ethosu/vela_api.py | 1 + python/tvm/relay/op/contrib/ethosu.py | 6 +++++- src/relay/op/contrib/ethosu/convolution.cc | 2 +- tests/python/contrib/test_ethosu/test_legalize.py | 10 +++++++--- tests/python/contrib/test_ethosu/test_preprocess.py | 5 +---- 8 files changed, 20 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/errors.py b/python/tvm/relay/backend/contrib/ethosu/errors.py index 435c9c8337ef..65f3711838be 100644 --- a/python/tvm/relay/backend/contrib/ethosu/errors.py +++ b/python/tvm/relay/backend/contrib/ethosu/errors.py @@ -33,6 +33,3 @@ class UnsupportedLayout(EthosUCodegenError): def __init__(self, layout): super().__init__(f"Unsupported Layout {layout}") - - def __str__(self): - return self.message diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 58a7ba252c04..cabb23031c7f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter -""" A set of passes to legalize some of operations for the NPU""" +"""A set of passes to legalize some of operations for the NPU""" import numpy as np import tvm @@ -58,7 +58,7 @@ def get_section_begin_coords(split): Returns ------- - section_begins : list + section_begins : list[int] A list containing integers corresponding to section begins """ @@ -71,16 +71,10 @@ def get_section_begin_coords(split): return [0] + list(indices_or_sections) split_axis_len = input_shape[split_axis].value section_length = split_axis_len // indices_or_sections.value - section_begins = list(range(0, split_axis_len, section_length)) - return section_begins + return list(range(0, split_axis_len, section_length)) def callback(self, pre, post, node_map): - splits_types = dict() split_input = post.args[0] - for idx, field_type in enumerate(post.checked_type.fields): - split = relay.TupleGetItem(post, idx) - splits_types[split] = field_type - split_begins = list() split_ends = list() section_begins_in_split_axis = self.get_section_begin_coords(post) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 85ca86cdfc24..50f3d133ff34 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -128,13 +128,13 @@ def get_range_for_dtype_str(dtype): def round_away_zero(f): - """round the number away from zero towards +inf / -inf""" + """Round the number away from zero towards +inf / -inf""" offset = -0.5 if (f < 0) else 0.5 return np.trunc(f + offset) def round_up(a, b): - """round up to a multiple of b""" + """Round up to a multiple of b""" return ((a + b - 1) // b) * b diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 129e6d81ae56..3b5b3808aad0 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -56,6 +56,7 @@ def get_optimal_block_config(npu_op, accel_type): The NPU operation and its params accel_type : ethosu.vela.api.NpuAccelerator The NPU accelerator variant + Returns ------- ethosu.vela.api.NpuShape3d : diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index d055aec3d72e..1559cddd3683 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -247,5 +247,9 @@ def qnn_conv2d_pattern(): @register_pattern_table("ethosu") def pattern_table(): return [ - ("ethosu.qnn_conv2d", qnn_conv2d_pattern(), lambda pat: QnnConv2DParams(pat).is_valid()) + ( + QnnConv2DParams.composite_name, + qnn_conv2d_pattern(), + lambda pat: QnnConv2DParams(pat).is_valid(), + ) ] diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index ec5da6cd1c47..660b3d7ddad1 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -180,7 +180,7 @@ RELAY_REGISTER_OP("contrib.ethosu.conv2d") This Relay operator corresponds to the hardware-implemented quantized convolution operation found on Ethos(TM)-U NPUs. It accepts either NHWC -or NHCWB16 format for the input data (input feature map, or IFM) and +or NHCWB16 format for the input data (Input Feature Map, or IFM) and OHWI format for the kernel weights. Reference: https://developer.arm.com/documentation/102420/0200/ diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 08143f84d4bd..39d84ecc91b2 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -327,7 +327,11 @@ def create_graph_single_unsupported_ifm_layout( for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) - try: + with pytest.raises( + tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" + ): mod = legalize.LegalizeEthosUConv2D()(mod) - except Exception as e: - assert "EthosUCodegenError: Unsupported Layout NCHW" in e.args[0] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_preprocess.py b/tests/python/contrib/test_ethosu/test_preprocess.py index 2aeeed078942..f2c7b0afafd8 100644 --- a/tests/python/contrib/test_ethosu/test_preprocess.py +++ b/tests/python/contrib/test_ethosu/test_preprocess.py @@ -340,7 +340,4 @@ def create_external_func1(mod_, compiler_name, symbol_name): if __name__ == "__main__": - test_2ins_single_out() - test_single_io() - test_4ins_2outs() - test_single_in_2outs() + pytest.main([__file__]) From 7af999ccbba2351dd574d629f52383c717bf4c48 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 2 Sep 2021 18:15:41 +0100 Subject: [PATCH 09/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op *addressing missed 'hidden' comments of Chris *addressing one missed comment of Nicola *adding type annotations Change-Id: Iadf4907b311e195731dbbed571e95a266341db8f --- .../relay/backend/contrib/ethosu/legalize.py | 35 +++++++---- .../backend/contrib/ethosu/op/convolution.py | 44 ++++++------- .../backend/contrib/ethosu/preprocess.py | 3 +- .../backend/contrib/ethosu/te/convolution.py | 40 ++++++------ .../relay/backend/contrib/ethosu/te/dma.py | 27 +++++--- .../tvm/relay/backend/contrib/ethosu/util.py | 24 ++++---- .../relay/backend/contrib/ethosu/vela_api.py | 61 +++++++++++-------- python/tvm/relay/op/contrib/ethosu.py | 16 ++--- .../backend/contrib/ethosu/preprocess.cc | 8 +-- src/relay/op/contrib/ethosu/common.cc | 14 ++--- src/relay/op/contrib/ethosu/convolution.cc | 6 +- .../contrib/test_ethosu/test_legalize.py | 4 +- 12 files changed, 156 insertions(+), 126 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index cabb23031c7f..2eb67417bc02 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" +from typing import List import numpy as np import tvm @@ -44,7 +45,7 @@ def __init__(self): self.pattern = is_op("split")(self.split_in) @staticmethod - def get_section_begin_coords(split): + def get_section_begin_coords(split: tvm.relay.Expr) -> List[int]: """Currently, the split operator takes an array of indices or an integer indicating the number of splits. However, its an array of indices could represent both cases, therefore this function just make it an array of @@ -53,12 +54,12 @@ def get_section_begin_coords(split): Parameters ---------- - split : relay.Expr + split : tvm.relay.Expr The Relay Call expression for a split operator Returns ------- - section_begins : list[int] + section_begins : List[int] A list containing integers corresponding to section begins """ @@ -73,7 +74,9 @@ def get_section_begin_coords(split): section_length = split_axis_len // indices_or_sections.value return list(range(0, split_axis_len, section_length)) - def callback(self, pre, post, node_map): + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: split_input = post.args[0] split_begins = list() split_ends = list() @@ -106,10 +109,12 @@ def callback(self, pre, post, node_map): class LegalizeSplit: """This is the pass that wraps SplitRewriter""" - def transform_module(self, mod, ctx): - for gv, func in mod.functions.items(): + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): func = rewrite(SplitRewriter(), func) - mod.update_func(gv, func) + mod.update_func(global_var, func) return mod def __call__(self, *args, **kwargs): @@ -123,7 +128,9 @@ def __init__(self): super().__init__(require_type=True) self.pattern = (wildcard().has_attr({"Composite": "ethosu.qnn_conv2d"}))(wildcard()) - def callback(self, pre, post, node_map): + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: params = ethosu_patterns.QnnConv2DParams(post.op.body) params.ifm.tensor = post.args[0] channels_map = { @@ -189,10 +196,12 @@ def callback(self, pre, post, node_map): class LegalizeEthosUConv2D: """This is the pass that wraps the EthosUConv2DRewriter""" - def transform_module(self, mod, ctx): - for gv, func in mod.functions.items(): + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): func = rewrite(EthosUConv2DRewriter(), func) - mod.update_func(gv, func) + mod.update_func(global_var, func) return mod def __call__(self, *args, **kwargs): @@ -206,7 +215,9 @@ class LegalizeEthosU: operations. """ - def transform_module(self, mod, ctx): + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: mod = LegalizeSplit()(mod) mod = LegalizeEthosUConv2D()(mod) return mod diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py index 790f0645af3f..5365d6bd28db 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=unused-argument """Relay operators for convolutions for Arm(R) Ethos(TM)-U NPU""" +from typing import Tuple + import tvm from tvm.relay.op import _make from tvm.topi.generic import schedule_injective @@ -89,27 +91,27 @@ def conv2d_strategy_ethosu(attrs, inputs, out_type, target): def ethosu_conv2d( - ifm, - weight, - scale_bias, - lut, - ifm_scale, - ifm_zero_point, - weight_zero_point, - ofm_scale, - ofm_zero_point, - kernel_shape, - ofm_channels, - strides=(1, 1), - padding=(0, 0, 0, 0), - dilation=(1, 1), - activation="NONE", - clip_min=0, - clip_max=0, - upscale="NONE", - ifm_layout="NHWC", - ofm_layout="NHWC", -): + ifm: tvm.relay.Expr, + weight: tvm.relay.Expr, + scale_bias: tvm.relay.Expr, + lut: tvm.relay.Expr, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + kernel_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int] = (1, 1), + padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + dilation: Tuple[int, int] = (1, 1), + activation: str = "NONE", + clip_min: int = 0, + clip_max: int = 0, + upscale: str = "NONE", + ifm_layout: str = "NHWC", + ofm_layout: str = "NHWC", +) -> tvm.relay.Call: """This is a quantized 2D convolution operation as supported by the the NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. diff --git a/python/tvm/relay/backend/contrib/ethosu/preprocess.py b/python/tvm/relay/backend/contrib/ethosu/preprocess.py index f2bd079c99c3..b7540f5c2fe7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/preprocess.py +++ b/python/tvm/relay/backend/contrib/ethosu/preprocess.py @@ -19,10 +19,11 @@ NPU code generation. These set of passes will mutate both the main and the external functions. """ +import tvm from . import _ffi_api -def preprocess_ext_io(): +def preprocess_ext_io() -> tvm.transform.Pass: """This pass mutates the number of inputs going to / outputs coming out to/from external functions to one. This is achieved via concatenation of inputs and splitting of outputs in around the call to the external function. diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 030e75d23193..d12e3908f5e1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -16,30 +16,32 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Tensor Expressions for convolutions for the NPU""" +from typing import Tuple, Union, List + from tvm import te from .dma import dma_ofm_compute, dma_ifm_compute def conv2d_compute( - ifm, - weight, - scale_bias, - lut, - ifm_scale, - ifm_zero_point, - weight_zero_point, - ofm_scale, - ofm_zero_point, - strides, - padding, - dilation, - activation, - clip_min, - clip_max, - upscale, - ifm_layout, - ofm_layout, -): + ifm: te.Tensor, + weight: te.Tensor, + scale_bias: te.Tensor, + lut: te.Tensor, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + strides: Tuple[int, int], + padding: Tuple[int, int, int, int], + dilation: Union[int, Tuple[int, int], List[int]], + activation: str, + clip_min: int, + clip_max: int, + upscale: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: """A compute operator representing the capabilities of a 2D convolution for the NPU. Parameters diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index 25b9d4b43a7a..b774fce576d1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -16,12 +16,14 @@ # under the License. # pylint: disable=invalid-name,unnecessary-lambda """Tensor Expressions for operations supported by the NPU DMA engine""" +from typing import Callable, Tuple + import tvm from tvm import te from tvm.topi.utils import equal_const_int -def _pad_tensor(tensor, pad_before, pad_after=None): +def _pad_tensor(tensor: te.Tensor, pad_before: tuple, pad_after: tuple = None) -> Callable: """Generate a padded tensor. Parameters @@ -61,7 +63,7 @@ def _pad(*indices): return _pad -def read_compute(tensor, layout, zero_point, scale): +def read_compute(tensor: te.Tensor, layout: str, zero_point: int, scale: float) -> te.Tensor: """A tensor expression which represents a read. Parameters @@ -91,7 +93,7 @@ def read_compute(tensor, layout, zero_point, scale): return te.compute(tensor.shape, lambda *i: tensor(*i), name="ethosu_read", attrs=read_attrs) -def write_compute(tensor, layout, zero_point, scale): +def write_compute(tensor: te.Tensor, layout: str, zero_point: int, scale: float) -> te.Tensor: """A tensor expression which represents a write. Parameters @@ -126,7 +128,7 @@ def write_compute(tensor, layout, zero_point, scale): ) -def convert_to_nhwc_compute(tensor, layout, channels): +def convert_to_nhwc_compute(tensor: te.Tensor, layout: str, channels: int) -> te.Tensor: """Converts a tensor into NHWC layout if it's in NHWCB16 layout. Parameters @@ -165,7 +167,7 @@ def convert_to_nhwc_compute(tensor, layout, channels): ) -def convert_to_nhcwb16_compute(tensor, layout, channels): +def convert_to_nhcwb16_compute(tensor: te.Tensor, layout: str, channels: int) -> te.Tensor: """Converts a tensor into NHCWB16 layout if it's in NHWC layout. Parameters @@ -210,7 +212,7 @@ def convert_to_nhcwb16_compute(tensor, layout, channels): ) -def pad_compute(tensor, padding): +def pad_compute(tensor: te.Tensor, padding: tuple): """Pad an NHWC tensor in the height and width axes. Parameters @@ -241,7 +243,14 @@ def pad_compute(tensor, padding): ) -def dma_ifm_compute(ifm, layout, zero_point, scale, channels, padding): +def dma_ifm_compute( + ifm: te.Tensor, + layout: str, + zero_point: int, + scale: float, + channels: int, + padding: Tuple[int, int, int, int], +) -> te.Tensor: """A sequence of compute operators representing the DMA capabilities for an IFM. Parameters @@ -270,7 +279,9 @@ def dma_ifm_compute(ifm, layout, zero_point, scale, channels, padding): return pad_compute(convert_to_nhwc_ifm, padding) -def dma_ofm_compute(ofm, layout, zero_point, scale, channels): +def dma_ofm_compute( + ofm: te.Tensor, layout: str, zero_point: int, scale: float, channels: int +) -> te.Tensor: """A sequence of compute operators representing the DMA capabilities for an OFM. Parameters diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 50f3d133ff34..66a3ed2ab504 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -22,8 +22,10 @@ """ from enum import Enum +from typing import Union, Tuple, Dict, Optional import numpy as np +import tvm from tvm import relay from tvm.relay.build_module import bind_params_by_name from tvm.relay.backend.contrib.ethosu import preprocess @@ -74,7 +76,7 @@ class ClipArgs(Enum): A_MAX = 2 -def is_composite_func(func, name): +def is_composite_func(func: relay.Function, name: str) -> bool: """ This method checks whether the call is to a composite function of a given name. @@ -98,12 +100,10 @@ def is_composite_func(func, name): return False composite_name = func.attrs["Composite"] - if composite_name != name: - return False - return True + return composite_name == name -def get_range_for_dtype_str(dtype): +def get_range_for_dtype_str(dtype: str) -> Tuple[int, int]: """ Produce the min,max for a give data type. @@ -127,28 +127,30 @@ def get_range_for_dtype_str(dtype): return type_info.min, type_info.max -def round_away_zero(f): +def round_away_zero(f: Union[float, np.double, np.single, np.float32, np.float64]) -> np.float64: """Round the number away from zero towards +inf / -inf""" offset = -0.5 if (f < 0) else 0.5 return np.trunc(f + offset) -def round_up(a, b): +def round_up(a: int, b: int) -> int: """Round up to a multiple of b""" return ((a + b - 1) // b) * b # pylint: disable=unused-argument -def partition_for_ethosu(mod, params=None, **opts): +def partition_for_ethosu( + mod: tvm.ir.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None, **opts +): """This helper function partition the relay graph as produced by the relay frontend for a given model into external functions to be presented to the codegen. Parameters ---------- - mod : IRModule + mod : tvm.ir.IRModule The IRModule that gets generated from a relay frontend - params : Optional[Dict[str, NDArray]] + params : Optional[Dict[str, tvm.runtime.NDArray]] Constant input parameters. Returns @@ -171,7 +173,7 @@ def partition_for_ethosu(mod, params=None, **opts): return mod -def get_dim_value(layout, dim): +def get_dim_value(layout: str, dim: int): """This is a helper function to retrieve the value of the dimension given the shape and the layout """ diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 3b5b3808aad0..f201abf27994 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -23,6 +23,7 @@ """ import logging import math +from typing import Tuple, Optional, List import numpy as np from ethosu.vela import api as vapi @@ -42,7 +43,9 @@ SCALE_BIAS_LENGTH = 10 -def get_optimal_block_config(npu_op, accel_type): +def get_optimal_block_config( + npu_op: vapi.NpuOperation, accel_type: vapi.NpuAccelerator +) -> vapi.NpuShape3D: """ "The NPU's unit of work is known as a block. It will fetch block(s) from Input Feature Map (IFM) and a compute block for Output Feature Map (OFM). @@ -59,21 +62,21 @@ def get_optimal_block_config(npu_op, accel_type): Returns ------- - ethosu.vela.api.NpuShape3d : + ethosu.vela.api.NpuShape3D : The optimal block config for the operator """ all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type) return _get_optimal_block_config(all_valid_block_configs) -def _get_optimal_block_config(all_valid_block_configs): +def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> vapi.NpuShape3D: """An internal function to get block config with largest depth and then highest volume/area""" assert isinstance(all_valid_block_configs, list) for block_cfg in all_valid_block_configs: assert isinstance(block_cfg, vapi.NpuShape3D) - # Getting the largest volume block for benchmarksing + # Getting the largest volume block for benchmarking all_valid_block_configs.sort( key=lambda _cfg: _cfg.depth * _cfg.height * _cfg.width, reverse=True ) @@ -109,15 +112,15 @@ def _get_optimal_block_config(all_valid_block_configs): def compress_weights( - weights, - weights_zp, - weights_layout, - ifm_bitdepth, - block_depth, - dilation, - accel_type, - is_depthwise=False, -): + weights: np.ndarray, + weights_zp: int, + weights_layout: str, + ifm_bitdepth: int, + block_depth: int, + dilation: Tuple[int, int], + accel_type: vapi.NpuAccelerator, + is_depthwise: Optional[bool] = False, +) -> bytearray: """The NPU requires the weights to be compressed to be executed. Therefore, this function calls into the Vela APIs to compress the weights. @@ -172,7 +175,9 @@ def compress_weights( return compressed_weights -def calculate_block_traversal_mode(is_depthwise, weights_shape_ohwi, ifm_bitdepth): +def calculate_block_traversal_mode( + is_depthwise: bool, weights_shape_ohwi: List[int], ifm_bitdepth: int +) -> vapi.NpuBlockTraversal: """Calculate a block traversal mode given whether the op is depthwise convolution, shape of weights and bit-depth of the ifm. """ @@ -194,13 +199,13 @@ def calculate_block_traversal_mode(is_depthwise, weights_shape_ohwi, ifm_bitdept def pack_biases( - biases, - ifm_scale, - ifm_dtype, - weight_scales, - ofm_scale, - is_activation_tanh_or_sigmoid=False, -): + biases: np.ndarray, + ifm_scale: float, + ifm_dtype: np.dtype, + weight_scales: np.ndarray, + ofm_scale: float, + is_activation_tanh_or_sigmoid: bool = False, +) -> np.ndarray: """ The NPU requires the each bias value to be packed with output scale parameters in a 80-bit format (that is returned @@ -249,7 +254,7 @@ def pack_biases( return scale_bias -def _quantize_scale(scale): +def _quantize_scale(scale: float) -> Tuple[int, int]: """Quantize floating point scale into 32-bit int scale with a 6-bit shift. This is to be used with 8-bit data. """ @@ -261,7 +266,7 @@ def _quantize_scale(scale): return mantissa_scaled, required_shift -def _reduced_quantize_scale(scale): +def _reduced_quantize_scale(scale: float) -> Tuple[int, int]: """A reduction of precision is required for 16 bit data.""" mantissa_scaled, required_shift = _quantize_scale(scale) # This is max a signed 16-bit number could represent @@ -277,8 +282,12 @@ def _reduced_quantize_scale(scale): def _calculate_hw_bias_scales( - ifm_scale, weight_scales, ofm_scale, ifm_dtype, is_faf_tanh_sigmoid=False -): + ifm_scale: float, + weight_scales: List[float], + ofm_scale: float, + ifm_dtype: np.dtype, + is_faf_tanh_sigmoid: bool = False, +) -> List[Tuple[int, int]]: """This function will produce a scale that is calculated using scales of ifm, weights and ofm. It is also important to note that if per-channel / per-value quantization required they should go into hw bias scales""" @@ -301,7 +310,7 @@ def _calculate_hw_bias_scales( return hw_bias_scales -def get_target_accel_type(): +def get_target_accel_type() -> vapi.NpuAccelerator: """This is a helper function to convert TVMC command line argument to NpuAccelerator type""" npu_accel_str_map = { "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 1559cddd3683..9aa64fe52d51 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -73,9 +73,7 @@ def check_weights(weights, dilation): get_dim_value(weights.layout, "I"), ) sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis)) - if not sum_weights <= weights_limit: - return False - return True + return sum_weights <= weights_limit def check_bias(bias): @@ -88,9 +86,7 @@ def check_bias(bias): def check_batch_size(ifm): """This function checks for the number of batches vela currently supports""" - if ifm.shape[0] != 1: - return False - return True + return ifm.shape[0] == 1 def check_dilation(dilation): @@ -110,9 +106,7 @@ def check_padding(padding, bounds): return False top, left, bottom, right = padding topb, leftb, bottomb, rightb = bounds - if top > topb or left > leftb or bottom > bottomb or right > rightb: - return False - return True + return not (top > topb or left > leftb or bottom > bottomb or right > rightb) class TensorParams: @@ -224,9 +218,7 @@ def is_valid(self): if self.groups not in legal_groups: return False # This should be a valid QnnDepthwise2DParams, not QnnConv2DParams - if self.is_depthwise: - return False - return True + return not self.is_depthwise def qnn_conv2d_pattern(): diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index a1b9cb9d0b38..ac52844091b4 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -64,8 +64,8 @@ class ExternalFuncIOHandler : public ExprRewriter { */ int64_t CalcSize(const Array& shape) { int size = 1; - for (auto dim_sz : shape) { - size = size * Downcast(dim_sz)->value; + for (auto dim_size : shape) { + size = size * Downcast(dim_size)->value; } return size; } @@ -77,8 +77,8 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr CreateFlattenTensor(const Expr& input) { auto ishape = Downcast>(Downcast(input->checked_type())->shape); int flatten_size = CalcSize(ishape); - Array oshape = {Integer(flatten_size)}; - return MakeReshape(input, oshape); + Array output_shape = {Integer(flatten_size)}; + return MakeReshape(input, output_shape); } /*! diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index a7109ade71a0..bdda81bc7708 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/op/contrib/ethosu/op_common.cc + * \file src/relay/op/contrib/ethosu/common.cc * \brief A set of utilities and common functionality for Arm(R) Ethos(TM)-U NPU QNN ops. */ @@ -40,22 +40,22 @@ Array EthosuInferKernelOutput(Array ifm_shape, String ifm_ if (ifm_layout == "NHCWB16") { ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]}; } - Array oshape({ifm_shape[0], 0, 0, ofm_channels}); + Array output_shape({ifm_shape[0], 0, 0, ofm_channels}); IndexExpr dilated_ksize_y = 1 + (kernel_shape[0] - 1) * dilation[0]; IndexExpr dilated_ksize_x = 1 + (kernel_shape[1] - 1) * dilation[1]; IndexExpr pad_h, pad_w; GetPaddingHeightWidth(padding, &pad_h, &pad_w); - oshape.Set(1, indexdiv(ifm_shape[1] + pad_h - dilated_ksize_y, strides[0]) + 1); - oshape.Set(2, indexdiv(ifm_shape[2] + pad_w - dilated_ksize_x, strides[1]) + 1); + output_shape.Set(1, indexdiv(ifm_shape[1] + pad_h - dilated_ksize_y, strides[0]) + 1); + output_shape.Set(2, indexdiv(ifm_shape[2] + pad_w - dilated_ksize_x, strides[1]) + 1); // If the ofm is NHCWB16, convert the layout if (ofm_layout == "NHCWB16") { - int channel_bricks = 1 + (oshape[3].as()->value - 1) / 16; - oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16}; + int channel_bricks = 1 + (output_shape[3].as()->value - 1) / 16; + output_shape = {output_shape[0], output_shape[1], channel_bricks, output_shape[2], 16}; } - return oshape; + return output_shape; } } // namespace ethosu diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 660b3d7ddad1..4abc451d84d8 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -54,8 +54,8 @@ struct EthosuConv2DAttrs : public tvm::AttrsNode { int clip_min; int clip_max; String upscale; - tvm::String ifm_layout; - tvm::String ofm_layout; + String ifm_layout; + String ofm_layout; TVM_DECLARE_ATTRS(EthosuConv2DAttrs, "relay.attrs.EthosuConv2DAttrs") { TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); @@ -68,7 +68,7 @@ struct EthosuConv2DAttrs : public tvm::AttrsNode { .describe("The quantization zero point for the Output Feature Map tensor."); TVM_ATTR_FIELD(kernel_shape) .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(ofm_channels) .describe("The number of OFM channels.") .set_default(NullValue()); diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 39d84ecc91b2..758554e509bb 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -34,8 +34,8 @@ def test_split_indices_legalize(): def create_graph(axis): x = relay.var("x", shape=(1, 50, 50, 3)) x_relu = relay.nn.relu(x) - split_o = relay.split(x_relu, [5, 20, 45], axis).tuple_value - return relay.Function([x], split_o) + split_output = relay.split(x_relu, [5, 20, 45], axis).tuple_value + return relay.Function([x], split_output) def expected_mod_axis1(): expected_ir_string = """ From 76905af932ac242fae44322e4bf87ac7719eeb89 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 3 Sep 2021 07:31:32 +0100 Subject: [PATCH 10/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op *further type fixes and one missed comment Change-Id: I6da69fd95d17dfeaf5940da4f8d8c8ea142b39d2 --- .../relay/backend/contrib/ethosu/_ffi_api.py | 2 +- .../relay/backend/contrib/ethosu/legalize.py | 12 +-- .../backend/contrib/ethosu/op/convolution.py | 8 +- .../backend/contrib/ethosu/preprocess.py | 6 +- .../backend/contrib/ethosu/te/convolution.py | 6 +- .../relay/backend/contrib/ethosu/te/dma.py | 16 ++-- .../tvm/relay/backend/contrib/ethosu/util.py | 8 +- .../relay/backend/contrib/ethosu/vela_api.py | 6 +- python/tvm/relay/op/contrib/ethosu.py | 90 ++++++++++--------- src/relay/op/contrib/ethosu/convolution.cc | 2 +- 10 files changed, 80 insertions(+), 76 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py index a0175ba17a56..ccf1039a6994 100644 --- a/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for relay transformation passes.""" -import tvm._ffi +import tvm._ffi # type: ignore tvm._ffi._init_api("relay.ext.ethosu", __name__) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 2eb67417bc02..82b7f1e68cee 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -17,19 +17,19 @@ # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" from typing import List -import numpy as np +import numpy as np # type: ignore -import tvm +import tvm # type: ignore from tvm import relay from tvm import ir -from tvm.relay.dataflow_pattern import DFPatternCallback +from tvm.relay.dataflow_pattern import DFPatternCallback # type: ignore from tvm.relay.dataflow_pattern import wildcard from tvm.relay.dataflow_pattern import is_op from tvm.relay.dataflow_pattern import rewrite -from tvm.relay.backend.contrib.ethosu import op as ethosu_ops -from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout +from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore +from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api -from tvm.relay.op.contrib import ethosu as ethosu_patterns +from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore class SplitRewriter(DFPatternCallback): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py index 5365d6bd28db..b159830ceaa9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -18,10 +18,10 @@ """Relay operators for convolutions for Arm(R) Ethos(TM)-U NPU""" from typing import Tuple -import tvm -from tvm.relay.op import _make -from tvm.topi.generic import schedule_injective -from tvm.relay.op.op import OpStrategy +import tvm # type: ignore +from tvm.relay.op import _make # type: ignore +from tvm.topi.generic import schedule_injective # type: ignore +from tvm.relay.op.op import OpStrategy # type: ignore from tvm.relay.op import strategy as _strategy from ..te import conv2d_compute diff --git a/python/tvm/relay/backend/contrib/ethosu/preprocess.py b/python/tvm/relay/backend/contrib/ethosu/preprocess.py index b7540f5c2fe7..795adfc2fb1f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/preprocess.py +++ b/python/tvm/relay/backend/contrib/ethosu/preprocess.py @@ -19,8 +19,8 @@ NPU code generation. These set of passes will mutate both the main and the external functions. """ -import tvm -from . import _ffi_api +import tvm # type: ignore +from . import _ffi_api # type: ignore def preprocess_ext_io() -> tvm.transform.Pass: @@ -33,4 +33,4 @@ def preprocess_ext_io() -> tvm.transform.Pass: ret : tvm.transform.Pass The registered pass to mutate the IO of the external functions and their calls. """ - return _ffi_api.PreprocessExternalFuncIO() + return _ffi_api.PreprocessExternalFuncIO() # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index d12e3908f5e1..40015ac296a6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -18,7 +18,7 @@ """Tensor Expressions for convolutions for the NPU""" from typing import Tuple, Union, List -from tvm import te +from tvm import te # type: ignore from .dma import dma_ofm_compute, dma_ifm_compute @@ -34,7 +34,7 @@ def conv2d_compute( ofm_zero_point: int, strides: Tuple[int, int], padding: Tuple[int, int, int, int], - dilation: Union[int, Tuple[int, int], List[int]], + dilation: Union[Tuple[int, int], List[int]], activation: str, clip_min: int, clip_max: int, @@ -68,7 +68,7 @@ def conv2d_compute( The 2 dimensional strides as (stride_height, stride_width). padding : tuple The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). - dilation : Union[int, tuple, list] + dilation : Union[Tuple[int, int], List[int]] The 2 dimensional dilation as (dilation_height, dilation_width). activation : str The activation function to use. diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index b774fce576d1..d19c8c56f7c2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -16,14 +16,16 @@ # under the License. # pylint: disable=invalid-name,unnecessary-lambda """Tensor Expressions for operations supported by the NPU DMA engine""" -from typing import Callable, Tuple +from typing import Callable, Tuple, Optional, List -import tvm +import tvm # type: ignore from tvm import te -from tvm.topi.utils import equal_const_int +from tvm.topi.utils import equal_const_int # type: ignore -def _pad_tensor(tensor: te.Tensor, pad_before: tuple, pad_after: tuple = None) -> Callable: +def _pad_tensor( + tensor: te.Tensor, pad_before: List[int], pad_after: Optional[List[int]] = None +) -> Callable: """Generate a padded tensor. Parameters @@ -212,7 +214,7 @@ def convert_to_nhcwb16_compute(tensor: te.Tensor, layout: str, channels: int) -> ) -def pad_compute(tensor: te.Tensor, padding: tuple): +def pad_compute(tensor: te.Tensor, padding: tuple) -> te.Tensor: """Pad an NHWC tensor in the height and width axes. Parameters @@ -229,8 +231,8 @@ def pad_compute(tensor: te.Tensor, padding: tuple): """ pad_top, pad_left, pad_down, pad_right = padding - pad_before = [0, pad_top, pad_left, 0] - pad_after = [0, pad_down, pad_right, 0] + pad_before = [0, int(pad_top), int(pad_left), 0] + pad_after = [0, int(pad_down), int(pad_right), 0] pad_attrs = { "op": "ethosu_pad", } diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 66a3ed2ab504..e9d89d33e6f0 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -23,12 +23,12 @@ from enum import Enum from typing import Union, Tuple, Dict, Optional -import numpy as np +import numpy as np # type: ignore -import tvm +import tvm # type: ignore from tvm import relay -from tvm.relay.build_module import bind_params_by_name -from tvm.relay.backend.contrib.ethosu import preprocess +from tvm.relay.build_module import bind_params_by_name # type: ignore +from tvm.relay.backend.contrib.ethosu import preprocess # type: ignore class QConv2DArgs(Enum): diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index f201abf27994..027b26837d44 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -24,10 +24,10 @@ import logging import math from typing import Tuple, Optional, List -import numpy as np -from ethosu.vela import api as vapi +import numpy as np # type: ignore +from ethosu.vela import api as vapi # type: ignore -from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.backend.contrib.ethosu import util # type: ignore # pylint: disable=invalid-name logger = logging.getLogger("Ethos-U") diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 9aa64fe52d51..0da81101c77b 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -15,19 +15,47 @@ # specific language governing permissions and limitations # under the License. """Arm(R) Ethos(TM)-U NPU supported operators.""" -import numpy as np - -from tvm.relay.expr import Constant -from tvm.relay.op.contrib.register import register_pattern_table -from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant -from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs +from typing import List, Tuple, Callable +import numpy as np # type: ignore + +import tvm # type: ignore +from tvm.relay.expr import Constant # type: ignore +from tvm.relay.op.contrib.register import register_pattern_table # type: ignore +from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant # type: ignore +from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs from tvm.relay.backend.contrib.ethosu.util import RequantArgs from tvm.relay.backend.contrib.ethosu.util import get_dim_value -from ethosu.vela import api as vapi +from ethosu.vela import api as vapi # type: ignore + + +class TensorParams: + """ + This class will parse a tvm Expr along with quantization scale + and zero point to populate parameters that are required + for the creation of tensors in Vela. + """ + + def __init__(self, tensor, layout=None, scale=None, zero_point=None): + self.tensor = tensor + if isinstance(tensor, Constant): + self.values = tensor.data.asnumpy() + else: + self.values = None + self.dtype = tensor.checked_type.dtype + self.shape = [int(i) for i in tensor.checked_type.shape] + self.layout = layout + + if scale is not None and zero_point is not None: + self.q_params = vapi.NpuQuantization( + scale.data.asnumpy().astype("float32"), zero_point.data.asnumpy().astype(self.dtype) + ) + else: + # put default values + self.q_params = vapi.NpuQuantization(1.0, 0) -def check_strides(strides): +def check_strides(strides: List[int]) -> bool: """This function checks whether strides are within the limits supported by the NPU""" stride_range = (1, 3) smin, smax = stride_range @@ -38,7 +66,7 @@ def check_strides(strides): return True -def check_valid_dtypes(tensor_params): +def check_valid_dtypes(tensor_params: List[TensorParams]) -> bool: """This function checks whether dtypes are supported by the NPU""" supported_dtypes = (np.uint8, np.int8) for tep in tensor_params: @@ -51,7 +79,7 @@ def check_valid_dtypes(tensor_params): return True -def check_weights(weights, dilation): +def check_weights(weights: TensorParams, dilation: List[int]): """This function checks whether weight tensor is compatible with the NPU""" dilated_height_range = (1, 64) dilated_hxw_range = (1, 64 * 64) @@ -76,7 +104,7 @@ def check_weights(weights, dilation): return sum_weights <= weights_limit -def check_bias(bias): +def check_bias(bias: TensorParams): """This function checks whether the bias values fit in 40 bits""" if bias and bias.dtype == np.dtype("int64"): valid = all(len(bin(bias_value)[2:]) <= 40 for bias_value in bias.values) @@ -84,12 +112,12 @@ def check_bias(bias): return True -def check_batch_size(ifm): +def check_batch_size(ifm: TensorParams): """This function checks for the number of batches vela currently supports""" return ifm.shape[0] == 1 -def check_dilation(dilation): +def check_dilation(dilation: List[int]): """This function checks whether dilation is within the limits supported by the NPU""" dilation_range = (1, 2) dmin, dmax = dilation_range @@ -100,7 +128,7 @@ def check_dilation(dilation): return True -def check_padding(padding, bounds): +def check_padding(padding: List[int], bounds: List[int]): """This function checks whether padding is within the limits supported by the NPU""" if len(padding) != 4 or len(bounds) != 4: return False @@ -109,32 +137,6 @@ def check_padding(padding, bounds): return not (top > topb or left > leftb or bottom > bottomb or right > rightb) -class TensorParams: - """ - This class will parse a tvm Expr along with quantization scale - and zero point to populate parameters that are required - for the creation of tensors in Vela. - """ - - def __init__(self, tensor, layout=None, scale=None, zero_point=None): - self.tensor = tensor - if isinstance(tensor, Constant): - self.values = tensor.data.asnumpy() - else: - self.values = None - self.dtype = tensor.checked_type.dtype - self.shape = [int(i) for i in tensor.checked_type.shape] - self.layout = layout - - if scale is not None and zero_point is not None: - self.q_params = vapi.NpuQuantization( - scale.data.asnumpy().astype("float32"), zero_point.data.asnumpy().astype(self.dtype) - ) - else: - # put default values - self.q_params = vapi.NpuQuantization(1.0, 0) - - class QnnConv2DParams: """ This class will parse a Call to a ethosu.qnn_conv2d composite function @@ -146,7 +148,7 @@ class QnnConv2DParams: padding_bounds = [31, 31, 32, 32] activation_map = {"clip": "CLIP"} - def __init__(self, func_body): + def __init__(self, func_body: tvm.relay.Function): activation = None if str(func_body.op) in self.activation_map.keys(): activation = func_body @@ -195,7 +197,7 @@ def __init__(self, func_body): if qnn_conv2d.attrs.groups == self.weights.shape[channels_axis[kernel_layout]]: self.is_depthwise = True - def is_valid(self): + def is_valid(self) -> bool: """ This function checks whether QnnConv2D has compatible attributes with the NPU """ @@ -221,7 +223,7 @@ def is_valid(self): return not self.is_depthwise -def qnn_conv2d_pattern(): +def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ This function creates the pattern for qnn.conv2D with optional fused RELU activation. """ @@ -237,7 +239,7 @@ def qnn_conv2d_pattern(): @register_pattern_table("ethosu") -def pattern_table(): +def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ ( QnnConv2DParams.composite_name, diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 4abc451d84d8..bad10bf66f3a 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -19,7 +19,7 @@ /*! * \file src/relay/op/contrib/ethosu/convolution.cc - * \brief Property def of the Arm(R) Ethos(TM)-U NPU convolution ops. + * \brief Operator definitions for the Arm(R) Ethos(TM)-U NPU convolution ops. */ #include "../../nn/convolution.h" From 23c9b102e579a84777d020648fda7d9534624673 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 3 Sep 2021 07:54:37 +0100 Subject: [PATCH 11/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op *missed comment split_o Change-Id: I4a4b19ff2cd18e8f568a63ae827f44358ed85b8e --- tests/python/contrib/test_ethosu/test_legalize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 758554e509bb..52f6995c3aaa 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -96,10 +96,10 @@ def test_split_sections_legalize(): def create_graph(axis, sections): x = relay.var("x", shape=(1, 50, 50, 3)) x_abs = relay.abs(x) - split_o = relay.split(x_abs, sections, axis).tuple_value + split_output = relay.split(x_abs, sections, axis).tuple_value outputs = list() for section_idx in range(sections): - split_single_out = relay.TupleGetItem(split_o, section_idx) + split_single_out = relay.TupleGetItem(split_output, section_idx) tanh = relay.tanh(split_single_out) outputs.append(tanh) tuple_out = relay.Tuple(outputs) From 0de4415732cbbc9d1a5c1d8509d25afda754a84a Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 7 Sep 2021 17:39:01 +0100 Subject: [PATCH 12/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * adding mypy check Change-Id: Iaf58dbba2a9d8e1098a10c589d91b63c7efe646d --- tests/scripts/task_mypy.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index b05acb090c2f..8507f311e9da 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -25,3 +25,6 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transofrm package." mypy --check-untyped-defs python/tvm/tir/transform/ + +echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." +mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ From 538a132e980aec166970fd0037f11ae894a4bf5b Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 7 Sep 2021 18:01:30 +0100 Subject: [PATCH 13/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op *removing premature insertion of get_accel_type utility Change-Id: I210512e00a5eb46adf23d1d72eb16432db526d25 --- python/tvm/relay/backend/contrib/ethosu/vela_api.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 027b26837d44..72ae18123b3d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -308,16 +308,3 @@ def _calculate_hw_bias_scales( hw_bias_scales = [_quantize_scale(bs) for bs in bias_scales] return hw_bias_scales - - -def get_target_accel_type() -> vapi.NpuAccelerator: - """This is a helper function to convert TVMC command line argument to NpuAccelerator type""" - npu_accel_str_map = { - "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, - "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, - "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, - "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, - } - accel_type_str = util.get_accelerator_config() - assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" - return npu_accel_str_map[accel_type_str] From ed68938be5a64139d9b59c8b3d5203de95ed9025 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 8 Sep 2021 07:18:20 +0100 Subject: [PATCH 14/14] Arm(R) Ethos(TM)-U NPU Relay passes and Conv2D op * rebase fixes Change-Id: I06c9b536a7598646efce2b664fcc405aa6008203 --- python/tvm/relay/backend/contrib/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/__init__.py b/python/tvm/relay/backend/contrib/__init__.py index 5d2a3979b9e1..bfc5b79bb2ee 100644 --- a/python/tvm/relay/backend/contrib/__init__.py +++ b/python/tvm/relay/backend/contrib/__init__.py @@ -16,4 +16,3 @@ # under the License. """External backend codegen modules for Relay.""" from . import cmsisnn -from . import ethosu