From cfed2bdb450717926b048002582d58fa06ded7fe Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 5 Aug 2021 17:21:26 +0100 Subject: [PATCH 1/7] Arm(R) Ethos(TM)-U NPU TIR compiler with conv2d support This commit adds the lowering passes necessary to lower an NPU Relay module down to a TIR module that can be compiled for the NPU. Conv2d is supported as the first NPU operator. An intermediate TE stage between Relay and TIR allows support for scheduling the operators. Co-authored-by: Manupa Karunaratne --- .../relay/backend/contrib/ethosu/__init__.py | 1 + .../backend/contrib/ethosu/tir/__init__.py | 17 + .../backend/contrib/ethosu/tir/compiler.py | 199 +++++++ .../backend/contrib/ethosu/tir/convolution.py | 106 ++++ .../relay/backend/contrib/ethosu/tir/dma.py | 291 ++++++++++ .../backend/contrib/ethosu/tir/passes.py | 475 +++++++++++++++ .../backend/contrib/ethosu/tir/scheduler.py | 277 +++++++++ .../relay/backend/contrib/ethosu/tir/spec.py | 263 +++++++++ .../backend/contrib/ethosu/tir/transform.py | 61 ++ .../relay/backend/contrib/ethosu/tir/utils.py | 174 ++++++ .../contrib/ethosu/tir_to_cs_translator.py | 332 +++++++++++ .../tvm/relay/backend/contrib/ethosu/util.py | 14 + .../relay/backend/contrib/ethosu/vela_api.py | 48 ++ .../backend/contrib/ethosu/compiler_attrs.cc | 73 +++ .../backend/contrib/ethosu/to_te_graph.cc | 234 ++++++++ tests/python/contrib/test_ethosu/infra.py | 117 ++++ .../contrib/test_ethosu/test_attr_passing.py | 44 ++ .../contrib/test_ethosu/test_compiler.py | 45 ++ .../test_ethosu/test_encode_constants.py | 273 +++++++++ .../test_ethosu/test_extract_constants.py | 97 ++++ .../contrib/test_ethosu/test_lower_to_te.py | 63 ++ .../test_ethosu/test_replace_conv2d.py | 547 ++++++++++++++++++ .../contrib/test_ethosu/test_replace_copy.py | 75 +++ .../contrib/test_ethosu/test_scheduler.py | 148 +++++ .../contrib/test_ethosu/test_vela_api.py | 101 ++++ 25 files changed, 4075 insertions(+) create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/__init__.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/compiler.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/convolution.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/dma.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/passes.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/spec.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/transform.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/utils.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py create mode 100644 src/relay/backend/contrib/ethosu/compiler_attrs.cc create mode 100644 src/relay/backend/contrib/ethosu/to_te_graph.cc create mode 100644 tests/python/contrib/test_ethosu/infra.py create mode 100644 tests/python/contrib/test_ethosu/test_attr_passing.py create mode 100644 tests/python/contrib/test_ethosu/test_compiler.py create mode 100644 tests/python/contrib/test_ethosu/test_encode_constants.py create mode 100644 tests/python/contrib/test_ethosu/test_extract_constants.py create mode 100644 tests/python/contrib/test_ethosu/test_lower_to_te.py create mode 100644 tests/python/contrib/test_ethosu/test_replace_conv2d.py create mode 100644 tests/python/contrib/test_ethosu/test_replace_copy.py create mode 100644 tests/python/contrib/test_ethosu/test_scheduler.py diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index f5c595462e73..2b424ebb5dec 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -20,4 +20,5 @@ from . import preprocess from . import errors from . import vela_api +from . import tir_to_cs_translator from .util import partition_for_ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py b/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py new file mode 100644 index 000000000000..cc285e5241cd --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm(R) Ethos(TM)-U NPU TIR codegen modules.""" diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py new file mode 100644 index 000000000000..c59a386fefbb --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +from tvm.driver.build_module import get_binds + +from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants +from .scheduler import schedule + + +def lower_ethosu(sch, args, const_dict, name="main"): + """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. + + The resulting TIR module will contain a single function + that comprises of a sequence of tir.extern_calls to NPU + operations. + + Parameters + ---------- + sch : tvm.te.Schedule + The schedule to be lowered. + args : Union[list of tvm.te.Tensor, TEGraph] + The input/output tensors. + const_dict : dict of int to numpy.ndarray + The constant dictionary. + name : str, optional + The name of the lowered primitive function. + + Returns + ------- + mod : tvm.IRModule + The lowered TIR module. + const_dict : dict of int to numpy.ndarray + The modified constant dictionary. + + """ + if not isinstance(args, list): + args = list(args.inputs) + list(args.outputs) + # config setup + curr_pass_ctx = tvm.ir.transform.PassContext.current() + curr_cfg = dict() + for key, value in curr_pass_ctx.config.items(): + curr_cfg[key] = value + tir_compiler_cfg = { + "tir.LoopPartition": { + "partition_const_loop": True, + "no_unroll_loop_with_extent_one": True, + }, + "tir.UnrollLoop": {"auto_max_depth": -1}, + } + # Merge two configs + curr_cfg = {**curr_cfg, **tir_compiler_cfg} + + sch = sch.normalize() + bounds = tvm.te.schedule.InferBound(sch) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True) + + compact = tvm.te.schedule.VerifyCompactBuffer(stmt) + binds, arg_list = get_binds(args, compact, None) + func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + + func = func.with_attr("global_symbol", name) + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) + with tvm.transform.PassContext(config=curr_cfg): + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.UnrollLoop()(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + mod = RemoveZeroStores()(mod) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = ReplaceOperators()(mod) + mod = tvm.tir.transform.RemoveNoOp()(mod) + mod, const_dict = EncodeConstants(const_dict)(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + mod = tvm.tir.transform.RemoveNoOp()(mod) + return mod, const_dict + + +def lower_to_te(prim_func): + """Lower a Relay primitive function to a Tensor Expression graph. + + Parameters + ---------- + prim_func : tvm.relay.Function + The Relay function to lowerethosu_runtime([]). + + Returns + ------- + out : TEGraph + The lowered Tensor Expression graph. + + """ + f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE") + return f(prim_func) + + +class ExtractConstants(ExprMutator): + """The actual mutator pass to extract the constants from a function and replace them with + Vars so the function can be lowered to a TE graph. Additionally returns all the values of + the constants extracted.""" + + def __init__(self): + super().__init__() + self.constants = [] + + def visit_constant(self, const): + if isinstance(const.checked_type, relay.ty.TensorType): + if const.checked_type.concrete_shape != (): + self.constants.append(const.data.asnumpy()) + name = "p" + str(len(self.constants)) + return relay.var(type_annotation=const.checked_type, name_hint=name) + + return const + + def visit_function(self, fn): + new_body = self.visit(fn.body) + new_params = list(relay.analysis.free_vars(new_body)) + return relay.Function(new_params, new_body) + + def extract_constants(self, func): + new_func = self.visit(func) + return new_func, self.constants + + +def extract_constants(func): + """Extract the constants from a function and replace them with + Vars so the function can be lowered to a TE graph. Additionally + returns all the values of the constants extracted. + + Parameters + ---------- + func : tvm.relay.Function + The Relay function from which to extract constants. + + Returns + ------- + new_func : tvm.relay.Function + The Relay function with constants replaced by vars. + const_dict : dict of int to numpy.ndarray + A dict of the extracted constants keyed by their param index. + + """ + const_dict = {} + params = len(func.params) + new_func, consts = ExtractConstants().extract_constants(func) + for i, const in enumerate(consts): + const_dict[params + i] = const + + new_func = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(new_func))["main"] + return new_func, const_dict + + +def lower_to_tir(func, cascader=None): + """Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target. + + The Relay function should only contain operations supported + by the NPU. + + Parameters + ---------- + func : tvm.relay.Function + The Relay function to lower. + cascader : Callable + An optional cascading function, + + Returns + ------- + mod : tvm.IRModule + The lowered TIR module. + consts : dict of int to numpy.ndarray + A dict of the extracted constants keyed by their param index. + + """ + func, consts = extract_constants(func) + mod = tvm.IRModule.from_expr(func) + func = relay.transform.InferType()(mod)["main"] + te_graph = lower_to_te(func) + s = schedule(te_graph, consts, cascader) + mod, consts = lower_ethosu(s, te_graph, consts) + return mod, consts diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py new file mode 100644 index 000000000000..69d0e457e33b --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the convolution operators in TIR.""" +import tvm +from ..vela_api import SCALE_BIAS_LENGTH +from .utils import get_outer_loops, get_op_attrs, get_base_address +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution + + +def get_conv2d_params(stmt, producers, consumers): + """Get the parameters necessary to construct a call_extern for a 2D convolution. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convolution loop nest. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + Serial2DConvolution + The parameters needed to construct a 2D convolution. + output_pointer : tvm.tir.Var + The output pointer of the convolution operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the convolution output pointer. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + rc = rw.body + compute = rc.body.value.b + input_pointer = compute.a.a.buffer_var + output_pointer = rc.body.buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=int(attrs["dilation_w"]), + dilation_h=int(attrs["dilation_h"]), + ) + # Get scale_bias info + scale_bias_mul = compute.b + if isinstance(scale_bias_mul, tvm.tir.Cast): + scale_bias_mul = scale_bias_mul.value + scale_bias_load = scale_bias_mul.a + scale_bias_base = get_base_address(scale_bias_load.index) + serial_scale_bias = SerialAddressRange( + address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=SCALE_BIAS_LENGTH * serial_ofm[3], + ) + # Get weight info + weight_load = compute.a.b + weight_base = get_base_address(weight_load.index) + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, + ) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + Serial2DConvolution( + ifm=serial_ifm, + ofm=serial_ofm, + kernel=serial_kernel, + weight=serial_weight, + weight_zero_point=attrs["weight_zero_point"], + scale_bias=serial_scale_bias, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py new file mode 100644 index 000000000000..ecd402d63309 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the DMA operators in TIR.""" +import tvm +from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs +from .spec import SerialFeatureMap, SerialPadding + + +def get_pad_params(stmt): + """Get the padding parameters from a pad loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a pad loop nest. + + Returns + ------- + pad : SerialPadding + The serializable padding. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + _, body = get_op_attrs(stmt) + n, h, w, c, _, inner = get_outer_loops(body, "NHWC") + output_pointer = inner.buffer_var + pad = SerialPadding(top=0, left=0, bottom=0, right=0) + if isinstance(inner.value, tvm.tir.Call): + input_pointer = inner.value.args[1].buffer_var + else: + input_pointer = inner.value.buffer_var + return pad, input_pointer, output_pointer + + padded_shape = [n.extent, h.extent, w.extent, c.extent] + + def _visit(expr): + if isinstance(expr, tvm.tir.expr.LT): + var = expr.a + val = expr.b + if var == h.loop_var: + pad.bottom = padded_shape[1] - val + else: + pad.right = padded_shape[2] - val + elif isinstance(expr, tvm.tir.expr.LE): + var = expr.b + val = expr.a + if var == h.loop_var: + pad.top = val + else: + pad.left = val + + cond = inner.value.args[0] + tvm.tir.stmt_functor.post_order_visit(cond, _visit) + return ( + pad, + input_pointer, + output_pointer, + ) + + +def get_convert_to_nhwc_params(stmt): + """Get the true number of channels from a convert_to_nhwc loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convert_to_nhwc loop nest. + + Returns + ------- + int + The true number of channels. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + _, body = get_op_attrs(stmt) + _, _, _, c, _, inner = get_outer_loops(body, "NHWC") + output_pointer = inner.buffer_var + input_pointer = inner.value.buffer_var + return c.extent, input_pointer, output_pointer + + +def get_convert_to_nhcwb16_params(stmt): + """Get the true number of channels from a convert_to_nhcwb16 loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convert_to_nhcwb16 loop nest. + + Returns + ------- + out_channels : int + The true number of channels. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) + output_pointer = inner.buffer_var + if isinstance(inner.value, tvm.tir.Call): + cond = inner.value.args[0] + out_channels = cond.b.value + input_pointer = inner.value.args[1].buffer_var + else: + input_pointer = inner.value.buffer_var + out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent + + return out_channels, input_pointer, output_pointer + + +def get_read_params(stmt): + """Get the feature map parameters from a read loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a read loop nest. + + Returns + ------- + SerialFeatureMap + The serializable feature map. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var + stride_vars = [h.loop_var, w.loop_var, c.loop_var] + strides = get_strides(inner.value.index, stride_vars) + base_address = get_base_address(inner.value.index) + data_type = inner.buffer_var.type_annotation.element_type.dtype + return ( + SerialFeatureMap( + data_type=data_type, + height=h.extent, + width=w.extent, + channels=c.extent, + tile_height_0=h.extent, + tile_height_1=0, + tile_width_0=w.extent, + tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address), + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=attrs["scale"], + zero_point=attrs["zero_point"], + layout=attrs["layout"], + stride_h=strides[0], + stride_w=strides[1], + stride_c=strides[2], + ), + input_pointer, + output_pointer, + ) + + +def get_write_params(stmt): + """Get the feature map parameters from a write loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a write loop nest. + + Returns + ------- + SerialFeatureMap + The serializable feature map. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var + stride_vars = [h.loop_var, w.loop_var, c.loop_var] + strides = get_strides(inner.index, stride_vars) + base_address = get_base_address(inner.index) + data_type = inner.buffer_var.type_annotation.element_type.dtype + return ( + SerialFeatureMap( + data_type=data_type, + height=h.extent, + width=w.extent, + channels=c.extent, + tile_height_0=h.extent, + tile_height_1=0, + tile_width_0=w.extent, + tile_address_0=tvm.tir.Load(data_type, inner.buffer_var, base_address), + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=attrs["scale"], + zero_point=attrs["zero_point"], + layout=attrs["layout"], + stride_h=strides[0], + stride_w=strides[1], + stride_c=strides[2], + ), + input_pointer, + output_pointer, + ) + + +def get_ifm_params(pointer, producers): + """Get the parameters associated with the DMA capabilities for an IFM. + + Parameters + ---------- + pointer : tvm.tir.Var + The pointer that the IFM DMA pipeline produces. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + + Returns + ------- + serial_ifm : SerialFeatureMap + The serializable IFM. + serial_padding : SerialPadding + The serializable padding. + + """ + pad = producers[pointer] + serial_padding, input_pointer, _ = get_pad_params(pad) + convert_to_nhwc = producers[input_pointer] + in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) + read = producers[input_pointer] + serial_ifm, _, _ = get_read_params(read) + serial_ifm.channels = in_channels + return serial_ifm, serial_padding + + +def get_ofm_params(pointer, consumers): + """Get the parameters associated with the DMA capabilities for an OFM. + + Parameters + ---------- + pointer : tvm.tir.Var + The pointer that the OFM DMA pipeline consumes. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + serial_ifm : SerialFeatureMap + The serializable OFM. + output_pointer : tvm.tir.Var + The pointer that the OFM DMA pipeline produces. + + """ + convert_to_nhcwb16 = consumers[pointer] + out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) + write = consumers[output_pointer] + serial_ofm, _, output_pointer = get_write_params(write) + serial_ofm.channels = out_channels + return serial_ofm, output_pointer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py new file mode 100644 index 000000000000..75a2b5b3362b --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" +import numpy as np + +import tvm +from tvm.relay.backend.contrib.ethosu import vela_api +from .convolution import get_conv2d_params +from .transform import get_copy_params +from .utils import get_weights_pointer, get_scale_bias_pointer + + +def RemoveZeroStores(): + """This pass removes stores which just store zero to initialise buffers. + + We don't codegen these stores and they otherwise considerably reduce + the simplicity of the static traversal of convolution.""" + + def _remove_zero_store(stmt): + if isinstance(stmt.value, tvm.tir.IntImm) and int(stmt.value) == 0: + return tvm.tir.Evaluate(tvm.tir.IntImm("uint8", 0)) + return stmt + + def _ftransform(f, mod, ctx): + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, _remove_zero_store, None, ["tir.Store"]) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.remove_zero_stores" + ) + + +def ReplaceOperators(): + """Replace operators represented as explicit loop nests with call_externs + to NPU operators.""" + op_map = { + "ethosu_conv2d": get_conv2d_params, + "ethosu_copy": get_copy_params, + } + pointer_to_producer = {} + pointer_to_consumer = {} + replace_output_pointer = {} + pointer_to_extents = {} + + def _resolve_pointers(stmt): + """This pass determines information about the pointers present in the IR. + In particular, it associates pointers with both the operations that + produce them and the operations that consume them through the + pointer_to_producer and pointer_to_consumer dicts. + + Additionally, it determines the extent (size/shape) of each pointer which + is required for the _replace_pointers pass which runs later.""" + loads = [] + + def _get_loads(stmt): + if isinstance(stmt, tvm.tir.Load): + loads.append(stmt.buffer_var) + + if isinstance(stmt, tvm.tir.Allocate): + pointer_to_extents[stmt.buffer_var] = stmt.extents + if isinstance(stmt.body[0], tvm.tir.AttrStmt): + if stmt.body[0].attr_key == "pragma_op": + pointer_to_producer[stmt.buffer_var] = stmt.body[0] + + elif isinstance(stmt, tvm.tir.AttrStmt): + if stmt.attr_key == "pragma_op": + tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) + for load_buffer in loads: + pointer_to_consumer[load_buffer] = stmt + + def _replace_operator(stmt): + """Replace operators with call_externs, having derived the parameters + from the relevant TIR expressions/statements. + + Note the complexity of this pass is mostly from the concept of 'replace + pointers'. A call_extern may in principle require information from several + loop nests in TIR (each corresponding to a different TE compute op). For + example, a convolution operator will have other TE compute ops before and + after corresponding to the input/output DMA functionality. Therefore, when + the 'central' convolution op is replaced with a call_extern, the memory + from the final DMA output op must be hoisted to the location/scope of + the call_extern. + + The is done by replacing the pointer corresponding to the current operation + with the 'true' output operator through the replace_output_pointer dict. + Because of this, the param_func must provide a replace_pointer if the op + isn't the true output but instead a no_compile op is.""" + if isinstance(stmt, tvm.tir.AttrStmt): + op_name = stmt.value.value + if stmt.attr_key == "pragma_op" and op_name in op_map: + # Get the parameters for the extern call + param_func = op_map[op_name] + info, output_pointer, replace_pointer = param_func( + stmt, pointer_to_producer, pointer_to_consumer + ) + if replace_pointer is not None: + replace_output_pointer[output_pointer] = replace_pointer + # Make the extern call + irb = tvm.tir.ir_builder.create() + irb.emit(tvm.tir.call_extern("handle", op_name, *info)) + return irb.get() + return None + + def _remove_no_compile(stmt): + """Certain operators are marked as 'no compile' operators. This means they + should be removed from the IR as they are compiled as part of other operators. + The IFM DMA operations are an example of this, as they don't get compiled + independently but instead get compiled into the operator they're associated with, + e.g. a conv2d. + + There are potentially 3 parts to remove for an operator: the memory scope, the + allocate for its output and the compute nest itself. For the memory scope and + allocate, we can check if the pointer they reference is produced by a 'no compile' + operator. For the compute nest, we can just check the op pragma.""" + if isinstance(stmt, tvm.tir.AttrStmt): + # Remove memory scopes + if stmt.node in pointer_to_producer: + producer_attr = pointer_to_producer[stmt.node] + if ( + producer_attr.attr_key == "pragma_op" + and producer_attr.value.value not in op_map + ): + return stmt.body + + # Remove compute nests + if stmt.attr_key == "pragma_op" and stmt.value.value not in op_map: + return tvm.tir.Evaluate(0) + + if isinstance(stmt, tvm.tir.Allocate): + # Remove allocates + if stmt.buffer_var in pointer_to_producer: + op_attr = pointer_to_producer[stmt.buffer_var] + if op_attr.attr_key == "pragma_op" and op_attr.value.value not in op_map: + return stmt.body + return None + + def _replace_pointers(stmt): + if isinstance(stmt, tvm.tir.AttrStmt): + # If the attribute references a pointer that needs replacing + if stmt.node in replace_output_pointer: + replace_pointer = replace_output_pointer[stmt.node] + # If the pointer doesn't have an extent registered to it, + # this means the pointer is to a Buffer. In this case, we + # just want to delete the memory scope attribute + if replace_pointer not in pointer_to_extents: + return stmt.body + # Otherwise, rewrite the memory scope attribute with the new pointer + return tvm.tir.AttrStmt( + replace_output_pointer[stmt.node], stmt.attr_key, stmt.value, stmt.body + ) + + if isinstance(stmt, tvm.tir.Allocate): + # If the allocate allocates a pointer that needs replacing + if stmt.buffer_var in replace_output_pointer: + replace_pointer = replace_output_pointer[stmt.buffer_var] + # If the pointer doesn't have an extent registered to it, + # this means the pointer is to a Buffer. In this case, we + # just want to delete the allocation statement + if replace_pointer not in pointer_to_extents: + return stmt.body + # Otherwise, rewrite the allocation statement with the new pointer + # and the new extent + replace_type = replace_pointer.type_annotation.element_type.dtype + replace_extents = pointer_to_extents[replace_pointer] + return tvm.tir.Allocate( + replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body + ) + return None + + def _post_transform(stmt): + # Replace operators with call_externs + result = _replace_operator(stmt) + # Remove operators that don't need compiling + result = result or _remove_no_compile(stmt) + # Replace necessary pointers that were removed in the previous step + return result or _replace_pointers(stmt) + + def _ftransform(f, mod, ctx): + tvm.tir.stmt_functor.post_order_visit(f.body, _resolve_pointers) + return f.with_body( + tvm.tir.stmt_functor.ir_transform( + f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate"] + ) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.replace_operators" + ) + + +def DivideConstants(const_dict): + """This pass rewrites the IR and constant dict such that all constant + accesses are at 0 offset and full length (i.e. they read the whole buffer). + + Where necessary, new constants are created in order to ensure the rewrite + can take place. As an example, if a convolution is tiled along the channels + axis, the accesses to the weights will need to be offset. This pass will + create new constants consisting of 'slices' of the weights so each tile + of the compute can access one of these 'slices'. + + The purpose of this pass is to transform the IR into a form we can apply + constant encoding to (which will compress weights and encode biases).""" + buffer_to_const = {} + new_buffers = [] + new_consts = [] + keep_buffers = set() + new_const_dict = {} + + def _visit(stmt): + new_args = [] + for i, arg in enumerate(stmt.args): + if isinstance(arg, tvm.tir.expr.Load): + # If we're trying to load a buffer that maps to a constant + if arg.buffer_var in buffer_to_const: + const = buffer_to_const[arg.buffer_var] + offset = int(arg.index) + # Note by convention the arg after a constant read is the length of the read + length = int(stmt.args[i + 1]) + # If it's anything other than a full read, create a new buffer + if offset != 0 or len(const) != length: + new_consts.append(const[offset : offset + length]) + new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) + new_buffers.append(new_buffer) + new_args.append(tvm.tir.expr.Load(new_buffer.dtype, new_buffer.data, 0)) + continue + keep_buffers.add(arg.buffer_var) + + new_args.append(arg) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + + def _ftransform(f, mod, ctx): + for i, param in enumerate(f.params): + if i in const_dict: + buffer_to_const[param] = const_dict[i].flatten() + buffer_to_const[f.buffer_map[param].data] = const_dict[i].flatten() + + new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"]) + # Both the params and buffer map need updating for the newly introduced buffers + new_params = [] + new_buffer_map = {} + for i, param in enumerate(f.params): + buffer = f.buffer_map[param] + pointer = buffer.data + if pointer in buffer_to_const: + if pointer not in keep_buffers: + continue + new_const_dict[len(new_params)] = const_dict[i] + new_params.append(param) + new_buffer_map[param] = buffer + + for i, new_buffer in enumerate(new_buffers): + handle = tvm.tir.Var("placeholder", "handle") + new_params.append(handle) + new_buffer_map[handle] = new_buffer + new_const_dict[len(new_params) - 1] = new_consts[i] + + new_f = tvm.tir.PrimFunc(new_params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + return new_f + + def _divide_constants(mod): + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.divide_constants" + ) + new_func = transform_func(mod) + return new_func, new_const_dict + + return _divide_constants + + +def EncodeConstants(const_dict): + """the NPU requires that weights are compressed and bias/scales are 'encoded', both + of which are performed by this pass. + + This pass modifies both the constant dict to contain the post-encoding values of the + constants and the IR to adjust buffer types/sizes/accesses so they align with the + encoded constants. Calls to the Vela API are made to perform the actual compression/ + encoding. + + """ + new_const_dict = {} + buffer_to_const = {} + pointer_to_buffer = {} + rewrite_buffer = {} + rewrite_pointer = {} + accel_type = vela_api.get_target_accel_type() + + def _align_scale_bias(tir_extern_call, bias): + """Align the scale_bias to 16 bytes.""" + value_bytes = bytearray() + value_bytes.extend(bias.tobytes()) + # Align to 16 + remainder = (len(value_bytes)) % 16 + if remainder > 0: + value_bytes.extend(bytearray(16 - remainder)) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _encode_weights(tir_extern_call, weights): + """Encode the weights for a TIR extern call.""" + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _new_buffer(old_buffer, new_value): + """Create a new buffer and add the old buffer and its pointer to the + rewriting maps.""" + new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) + pointer_to_buffer[new_buffer.data] = new_buffer + rewrite_buffer[old_buffer] = new_buffer + rewrite_pointer[old_buffer.data] = new_buffer.data + buffer_to_const[new_buffer] = new_value + + def _visit_encode_pre(stmt): + if isinstance(stmt, tvm.tir.Call): + # Handle copies as a special-case by propagating the buffer information + # from the read to the write pointer. + if stmt.args[0] == "ethosu_copy": + read_pointer = stmt.args[1].buffer_var + if read_pointer in pointer_to_buffer: + write_pointer = stmt.args[3].buffer_var + # Assert writing to the base of the write_var (pre-StorageRewrite) + assert stmt.args[3].index == 0 + assert stmt.args[1].index == 0 + pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] + else: + # Encode the weights + weights_pointer = get_weights_pointer(stmt) + if weights_pointer is not None: + assert weights_pointer in pointer_to_buffer + weights_buffer = pointer_to_buffer[weights_pointer] + weights_value = buffer_to_const[weights_buffer] + new_weights_value = _encode_weights(stmt, weights_value) + _new_buffer(weights_buffer, new_weights_value) + # Align the scale_bias to 16 bytes + scale_bias_pointer = get_scale_bias_pointer(stmt) + if scale_bias_pointer is not None: + assert scale_bias_pointer in pointer_to_buffer + scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] + scale_bias_value = buffer_to_const[scale_bias_buffer] + new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) + _new_buffer(scale_bias_buffer, new_scale_bias_value) + + def _visit_encode_post(stmt): + # Because encoding may change the data type (e.g. bias to uint8) and type information + # is stored in pointer vars, it's necessary to rewrite all the pointers which point + # to encoded data. + if isinstance(stmt, tvm.tir.Allocate): + allocate_pointer = stmt.buffer_var + if allocate_pointer in pointer_to_buffer: + buffer = pointer_to_buffer[allocate_pointer] + if buffer in rewrite_buffer: # If the pointer needs rewriting + # Create a new pointer var with the type of the new buffer + new_buffer = rewrite_buffer[buffer] + storage_type = tvm.ir.PrimType(new_buffer.dtype) + new_pointer = tvm.tir.Var( + allocate_pointer.name, + tvm.ir.PointerType(storage_type, buffer.scope()), + allocate_pointer.span, + ) + # Set the new pointer to resolve to the new buffer + pointer_to_buffer[new_pointer] = new_buffer + # Add the old pointer to the pointer rewriting dict + rewrite_pointer[allocate_pointer] = new_pointer + + def _visit_rewrite(stmt): + if isinstance(stmt, tvm.tir.Call): + # For extern calls, we need to rewrite pairs of arguments corresponding to + # base address load and the length of the load. + new_args = [stmt.args[0]] + for i in range(1, len(stmt.args)): + # If the previous argument was a load, the current should be a length + if isinstance(stmt.args[i - 1], tvm.tir.Load): + load = stmt.args[i - 1] + pointer = load.buffer_var + if pointer in pointer_to_buffer: + new_args.append(np.prod(list(pointer_to_buffer[pointer].shape))) + continue + new_args.append(stmt.args[i]) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + if isinstance(stmt, tvm.tir.Allocate): + # Where a pointer needs rewriting, the allocate for it must be rewritten + allocate_pointer = stmt.buffer_var + if allocate_pointer in pointer_to_buffer: + if pointer_to_buffer[allocate_pointer] in rewrite_buffer: + new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]] + new_pointer = rewrite_pointer[allocate_pointer] + return tvm.tir.Allocate( + new_pointer, + new_buffer.dtype, + new_buffer.shape, + stmt.condition, + stmt.body, + stmt.span, + ) + # The following rewrites would be better expressed by just rewriting the Vars, however + # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite + # the nodes which contain the Vars. + if isinstance(stmt, tvm.tir.Load): + load_pointer = stmt.buffer_var + if load_pointer in rewrite_pointer: + new_pointer = rewrite_pointer[load_pointer] + element_type = new_pointer.type_annotation.element_type.dtype + return tvm.tir.Load( + element_type, new_pointer, stmt.index, stmt.predicate, stmt.span + ) + if isinstance(stmt, tvm.tir.AttrStmt): + node_pointer = stmt.node + if node_pointer in rewrite_pointer: + return tvm.tir.AttrStmt( + rewrite_pointer[node_pointer], stmt.attr_key, stmt.value, stmt.body, stmt.span + ) + return None + + def _ftransform(f, mod, ctx): + for i, param in enumerate(f.params): + if i in const_dict: + buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param] + + # First analyse what needs to be rewritten + new_body = tvm.tir.stmt_functor.ir_transform( + f.body, _visit_encode_pre, _visit_encode_post, ["tir.Call", "tir.Allocate"] + ) + # Then perform the rewrites + new_body = tvm.tir.stmt_functor.ir_transform( + f.body, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.Load", "tir.AttrStmt"] + ) + new_buffer_map = {} + # Rewrite the buffer map and const dict to instead use the encoded versions + for i, param in enumerate(f.params): + buffer = f.buffer_map[param] + if buffer in rewrite_buffer: + new_buffer = rewrite_buffer[buffer] + new_buffer_map[param] = new_buffer + new_value = buffer_to_const[new_buffer] + new_const_dict[i] = new_value + elif buffer in buffer_to_const: + new_const_dict[i] = buffer_to_const[buffer] + new_buffer_map[param] = buffer + else: + new_buffer_map[param] = buffer + + new_f = tvm.tir.PrimFunc(f.params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + return new_f + + def _encode_constants(mod): + mod, divided_const_dict = DivideConstants(const_dict)(mod) + const_dict.clear() + for key, value in divided_const_dict.items(): + const_dict[key] = value + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.encode_constants" + ) + new_func = transform_func(mod) + return new_func, new_const_dict + + return _encode_constants diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py new file mode 100644 index 000000000000..fd52e1821cb6 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Different schedulers for Arm(R) Ethos(TM)-U NPU""" +import tvm + + +def schedule(te_graph, const_dict, cascader=None): + """Schedule a TE graph for NPU compilation. + + Parameters + ---------- + te_graph + The TE graph to schedule. + const_dict : dict of int to numpy.ndarray + The constant dictionary. + cascader : callable, optional + A cascading function to apply optimizing scheduling + to the graph. + + Returns + ------- + s : tvm.te.Schedule + The completed schedule for the graph. + + """ + s = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + if cascader: + cascader(te_graph, const_dict, s) + inline_no_ops(te_graph, s) + schedule_pragmas(s) + schedule_cache_reads(s) + return s + + +def tile_nd(s, tensor, tile): + """Scheduling utility to perform N-dimensional tiling. + + Parameters + ---------- + s : tvm.te.Schedule + The schedule to apply the tiling to. + tensor : tvm.te.Tensor + The tensor to apply the tiling to. + tile : tuple + The N-dimensional tile size. + + Returns + ------- + outer_indices : list of tvm.tir.IterVar + The outer iteration variables. + inner_indices : list of tvm.tir.IterVar + The inner iteration variables. + + """ + outer_indices = [] + inner_indices = [] + for i, size in enumerate(tile): + outer, inner = s[tensor].split(tensor.op.axis[i], size) + outer_indices.append(outer) + inner_indices.append(inner) + + s[tensor].reorder(*outer_indices, *inner_indices) + return outer_indices, inner_indices + + +def total_cascader(stripe_size): + """A demo/test cascader which tries to cascade every op in the graph together. + + The desired output stride size should be specified. Note this only works + for single output graphs. + + Parameters + ---------- + stripe_size : tuple + The output stripe size. + + Returns + ------- + func : callable + The cascading function. + + """ + + def _cascader(te_graph, const_dict, sch): + scheduled = set() + + def _visit(tensor, stage, ax): + if tensor not in scheduled and isinstance(tensor.op, tvm.te.ComputeOp): + sch[tensor].compute_at(stage, ax) + scheduled.add(tensor) + for input_tensor in tensor.op.input_tensors: + _visit(input_tensor, stage, ax) + + assert len(te_graph.outputs) == 1 + out = te_graph.outputs[0] + oi, _ = tile_nd(sch, out, stripe_size) + for ax in oi: + sch[out].unroll(ax) + for input_tensor in out.op.input_tensors: + _visit(input_tensor, sch[out], oi[-1]) + + return _cascader + + +def copy_constants(): + """A simple planner which copies all constant data from FLASH -> SRAM. + + Returns + ------- + planner : callable + The planning function. + """ + + def _planner(te_graph, const_dict, sch): + planned = set() + + def _visit(tensor, reader): + if tensor is not planned: + planned.add(tensor) + if isinstance(tensor.op, tvm.te.PlaceholderOp): + index = list(te_graph.inputs).index(tensor) + if index in const_dict: + sch.cache_read(tensor, "global", [reader]) + + elif isinstance(tensor.op, tvm.te.ComputeOp): + for input_tensor in tensor.op.input_tensors: + _visit(input_tensor, tensor) + + for output_tensor in te_graph.outputs: + _visit(output_tensor, None) + + return _planner + + +def schedule_pragmas(sch): + """Add pragmas to the operators that require them. + + This adds the pragmas used for codegen to the NPU ops. + They are taken directly from the TE compute op's attributes. + Modifies the schedule in-place. + + Parameters + ---------- + sch : tvm.te.Schedule + The schedule. + + """ + + def _add_pragmas(stage, ax): + if "op" in [attr for attr, val in stage.op.attrs.items()]: + stage.pragma(ax, "op", stage.op.attrs["op"]) + for attr, val in stage.op.attrs.items(): + if attr != "op": + stage.pragma(ax, str(attr), val) + + for stage in sch.stages: + if ( + isinstance(stage.op, tvm.te.ComputeOp) + and len(stage.op.axis) + len(stage.op.reduce_axis) > 0 + ): + # The logic ensures the pragmas are assigned to the inner tiling loops + # rather than the outer ones (which end up getting unrolled). + num_inner_loops = len(stage.op.axis) + len(stage.op.reduce_axis) + ax = stage.leaf_iter_vars[-num_inner_loops] + _add_pragmas(stage, ax) + + +def schedule_cache_reads(sch): + """Schedule cache reads that have been introduced. + + There are two things we need to happen to cache_read stages. They should be tagged + with the 'ethosu_copy' pragma and have all their axes fused to make them 1D. + + Parameters + ---------- + sch : tvm.te.Schedule + The schedule. + + """ + + def _detect_cache_read(stage): + # Try and detect cache_reads by checking if the compute op is identity + if isinstance(stage.op, tvm.te.ComputeOp): + op = stage.op + if "ethosu" in op.name: + return False + axes = op.axis + if len(op.input_tensors) == 1: + tensor = op.input_tensors[0] + try: + identity_op = tensor(*axes) + except ValueError: + return False + if tvm.tir.analysis.expr_deep_equal(identity_op, op.body[0]): + return True + return False + + for stage in sch.stages: + if _detect_cache_read(stage): + fax = stage.fuse(*stage.op.axis) + stage.pragma(fax, "op", "ethosu_copy") + + +def inline_no_ops(te_graph, sch): + """Inline 'no-ops' - operations that in principle do nothing. + + Modifies the schedule in-place. For now we inline reshape and + strided slice - more could be added. + + Parameters + ---------- + te_graph + The TE graph. + sch : tvm.te.Schedule + The schedule. + + """ + no_ops = {"T_reshape", "T_strided_slice"} + scheduled = set() + + def _visit(tensor): + if tensor not in scheduled and isinstance(tensor.op, tvm.te.ComputeOp): + if tensor.op.name in no_ops: + sch[tensor].compute_inline() + scheduled.add(tensor) + for input_tensor in tensor.op.input_tensors: + _visit(input_tensor) + + for out in te_graph.outputs: + _visit(out) + + +class Convolution2DCompute: + """A helper class to manipulate the series of compute ops that make up a 2D convolution.""" + + def __init__(self, read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write): + self.read = read + self.convert_to_nhwc = convert_to_nhwc + self.pad = pad + self.conv2d = conv2d + self.convert_to_nhcwb16 = convert_to_nhcwb16 + self.write = write + + @classmethod + def from_output(cls, out): + write = out + convert_to_nhcwb16 = write.op.input_tensors[0] + conv2d = convert_to_nhcwb16.op.input_tensors[0] + pad = conv2d.op.input_tensors[0] + convert_to_nhwc = pad.op.input_tensors[0] + read = convert_to_nhwc.op.input_tensors[0] + return cls(read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write) + + def split(self, sch, axis, val): + outer, inner = sch[self.write].split(self.write.op.axis[axis], val) + sch[self.write].reorder( + outer, *[ax for ax in self.write.op.axis if ax != self.write.op.axis[axis]], inner + ) + sch[self.write].unroll(outer) + g = sch.create_group(outputs=self.convert_to_nhcwb16, inputs=self.read, include_inputs=True) + g.compute_at(sch[self.write], outer) + return outer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py new file mode 100644 index 000000000000..3ecbcd5f3cdc --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The TIR serialization specification for Arm(R) Ethos(TM)-U NPU.""" +from typing import Union +from typing import get_type_hints +from inspect import isclass + +import tvm +from tvm.relay.backend.contrib.ethosu import util + + +def create_serial_object(serialized_type, deserialized_elements): + """ + This function will create serialized type that is one of the subclasses + of tvm.relay.backend.contrib.ethosu.tir.spec.SerializableFormat + + Parameters + ---------- + serialized_type : a subclass type of SerializableFormat + + deserialized_elements : list + The list of arguments that needs to packed to create SerializableFormat objects + + Returns + ------- + The constructed object of type serialized_type + """ + + def _create_serial_object(internal_serialized_type, read_element_idx=0): + """The internal function that increments the read_element_idx + when creating nested serial objects""" + arg_len = util.get_arg_count(internal_serialized_type.__init__) - 1 + serial_init_types = get_type_hints(internal_serialized_type.__init__) + serial_init_arg_names = list(serial_init_types.keys()) + serial_init_args = [] + assert arg_len == len(serial_init_arg_names) + for si_arg_name in serial_init_arg_names: + si_arg_type = serial_init_types[si_arg_name] + if isclass(si_arg_type) and issubclass(si_arg_type, SerializableFormat): + sia, read_element_idx = _create_serial_object(si_arg_type, read_element_idx) + serial_init_args.append(sia) + else: + serial_init_args.append(deserialized_elements[read_element_idx]) + read_element_idx += 1 + return internal_serialized_type(*serial_init_args), read_element_idx + + # Just return the primary serial object + return _create_serial_object(serialized_type)[0] + + +class SerializableFormat: + """Base class to retrieve arguments on a predefined ordering""" + + def __iter__(self): + # Note class attribute definition order is preserved - see PEP 520 + for name in self.__dict__: + value = self.__getattribute__(name) + if isinstance(value, SerializableFormat): + yield from list(value) + else: + yield value + + def __getitem__(self, index): + # Note class attribute definition order is preserved - see PEP 520 + name = list(self.__dict__.keys())[index] + return self.__getattribute__(name) + + +class SerialFeatureMap(SerializableFormat): + """Specialization class to retrieve arguments of a Feature Map + (similiar to NpuFeatureMap of Vela) on a predefined ordering""" + + def __init__( + self, + data_type: str, + height: int, + width: int, + channels: int, + tile_height_0: int, + tile_height_1: int, + tile_width_0: int, + tile_address_0: tvm.tir.expr.Load, + tile_address_1: Union[tvm.tir.expr.Load, int], + tile_address_2: Union[tvm.tir.expr.Load, int], + tile_address_3: Union[tvm.tir.expr.Load, int], + scale: float, + zero_point: int, + layout: str, + stride_h: int, + stride_w: int, + stride_c: int, + ): + self.data_type = data_type + self.height = height + self.width = width + self.channels = channels + self.tile_height_0 = tile_height_0 + self.tile_height_1 = tile_height_1 + self.tile_width_0 = tile_width_0 + self.tile_address_0 = tile_address_0 + self.tile_address_1 = tile_address_1 + self.tile_address_2 = tile_address_2 + self.tile_address_3 = tile_address_3 + self.scale = scale + self.zero_point = zero_point + self.layout = layout + self.stride_h = stride_h + self.stride_w = stride_w + self.stride_c = stride_c + + +class SerialKernel(SerializableFormat): + """Specialization class to retrieve arguments of a Kernel + (similiar to NpuKernel of Vela) on a predefined ordering""" + + def __init__( + self, + width: int, + height: int, + stride_w: int, + stride_h: int, + dilation_w: int, + dilation_h: int, + ): + self.width = width + self.height = height + self.stride_w = stride_w + self.stride_h = stride_h + self.dilation_w = dilation_w + self.dilation_h = dilation_h + + +class SerialAddressRange(SerializableFormat): + """Specialization class to retrieve arguments of a AddressRange + (similiar to NpuAddressRange of Vela) on a predefined ordering""" + + def __init__(self, address: tvm.tir.expr.Load, length: int): + self.address = address + self.length = length + + +class SerialPadding(SerializableFormat): + """Specialization class to retrieve arguments of a Padding + (similiar to NpuPadding of Vela) on a predefined ordering""" + + def __init__(self, top: int, left: int, bottom: int, right: int): + self.top = top + self.left = left + self.bottom = bottom + self.right = right + + +class SerialActivation(SerializableFormat): + """Specialization class to retrieve arguments of a Activation + (similiar to NpuActivation of Vela) on a predefined ordering""" + + def __init__(self, op: str, clip_min: int, clip_max: int): + self.op = op + self.clip_min = clip_min + self.clip_max = clip_max + + +class Serial2DConvolution(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.conv2d tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + kernel: SerialKernel, + weight: SerialAddressRange, + weight_zero_point: int, + scale_bias: SerialAddressRange, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.kernel = kernel + self.weight = weight + self.weight_zero_point = weight_zero_point + self.scale_bias = scale_bias + self.padding = padding + self.activation = activation + self.upscale = upscale + + +class Serial2DDepthwise(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.depthwise2d tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + kernel: SerialKernel, + weight: SerialAddressRange, + weight_zero_point: int, + scale_bias: SerialAddressRange, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.kernel = kernel + self.weight = weight + self.weight_zero_point = weight_zero_point + self.scale_bias = scale_bias + self.padding = padding + self.activation = activation + self.upscale = upscale + + +class SerialCopy(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.copy tir extern call on a predefined ordering""" + + def __init__( + self, read_address: tvm.tir.expr.Load, length: int, write_address: tvm.tir.expr.Load + ): + self.read_address = read_address + self.length = length + self.write_address = write_address + + +class SerialPooling(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.pooling tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + pooling_type: str, + pool_shape: SerialKernel, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.pooling_type = pooling_type + self.pool_shape = pool_shape + self.padding = padding + self.activation = activation + self.upscale = upscale diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py new file mode 100644 index 000000000000..0403ce2c7e8f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the transform operators in TIR.""" +import tvm +from .spec import SerialCopy +from .utils import get_base_address, get_op_attrs + + +def get_copy_params(stmt, producers, consumers): + """Get the parameters necessary to construct a call_extern for a copy. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a copy loop nest. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialCopy + The parameters needed to construct a copy. + tvm.tir.Var + The output pointer of the copy operation. + + """ + _, body = get_op_attrs(stmt) + length = body.extent + write_store = body.body + write_base = get_base_address(write_store.index) + read_load = body.body.value + read_base = get_base_address(read_load.index) + dtype = body.body.value.dtype + return ( + SerialCopy( + read_address=tvm.tir.expr.Load(dtype, read_load.buffer_var, read_base), + length=length, + write_address=tvm.tir.expr.Load(dtype, write_store.buffer_var, write_base), + ), + write_store.buffer_var, + None, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py new file mode 100644 index 000000000000..55db62edfa5a --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Helper utility functions used by the TIR compiler""" +import tvm +from tvm import arith + + +# TODO(@mbaret): Formalise this with a specification +def get_weights_pointer(tir_extern_call): + """Get the weights pointer from a NPU extern call if it exists""" + if tir_extern_call.args[0] == "ethosu_conv2d": + return tir_extern_call.args[41].buffer_var + return None + + +# TODO(@mbaret): Formalise this with a specification +def get_scale_bias_pointer(tir_extern_call): + """Get the scale_bias pointer from a NPU extern call if it exists""" + if tir_extern_call.args[0] == "ethosu_conv2d": + return tir_extern_call.args[44].buffer_var + return None + + +def get_op_attrs(stmt): + """Iterate through nested attribute statements accumulating their values + in an attribute dictionary. + + The "pragma_" prefix is removed as a convenience. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement to begin from. + + Returns + ------- + attrs : dict of str to object + The attribute dictionary. + stmt : tvm.tir.Stmt + The body after having collected the final attribute statement. + + """ + attrs = {} + while isinstance(stmt, tvm.tir.AttrStmt): + # The pragma scheduler inserts "pragma_" before all the + # attr names, this is annoying so we get rid of it + attr = stmt.attr_key.replace("pragma_", "") + attrs[attr] = stmt.value + stmt = stmt.body + + return attrs, stmt + + +def get_strides(index, stride_vars): + """Get the striding of given vars in an indexing expression. + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index expression where the stride vars are present. + stride_vars : list of tvm.tir.Var + The vars to determine the striding of. + + Returns + ------- + strides : list of int + The striding of each stride var in the index expression + in the same order as the stride vars were given. + + """ + strides = [1] * len(stride_vars) + dmap = {} + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Var): + dmap[stmt] = arith.IntervalSet(0, 0) + + tvm.tir.stmt_functor.post_order_visit(index, _visit) + min_value = int(arith.Analyzer().int_set(index, dmap).min_value) + for var in dmap: + if var in stride_vars: + # NOTE: Doing this using a [0, 1] interval doesn't work reliably + # Seems to be a bug + dmap[var] = arith.IntervalSet(1, 1) + max_value = int(arith.Analyzer().int_set(index, dmap).max_value) + stride = int(max_value - min_value) + i = stride_vars.index(var) + strides[i] = stride + dmap[var] = arith.IntervalSet(0, 0) + + return strides + + +def get_base_address(index): + """Determine the first (base) address accessed by an index expression. + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index expression to determine the base address of. + + Returns + ------- + base_address: + The first address accessed by the index expression. + + """ + dmap = {} + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Var): + dmap[stmt] = arith.IntervalSet(0, 0) + + tvm.tir.stmt_functor.post_order_visit(index, _visit) + base_address = int(arith.Analyzer().int_set(index, dmap).min_value) + return base_address + + +def get_outer_loops(stmt, layout): + """Get the outer loops of an operator. + + Parameters + ---------- + stmt : tvm.tir.For + The outermost loop. + layout : str + The output tensor layout (NHWC or NHCWB16). + + Returns + ------- + n : tvm.tir.For + The batch loop. + h : tvm.tir.For + The height loop. + w : tvm.tir.For + The width loop. + c : tvm.tir.For + The channels loop. + b : tvm.tir.For + The brick loop. None for NHWC + body : tvm.tir.Stmt + The inner body of the loops. + + """ + if layout == "NHWC": + n = stmt + h = n.body + w = h.body + c = w.body + b = tvm.tir.For(tvm.tir.Var("b", "int32"), 0, 0, 0, tvm.tir.Evaluate(0)) + return n, h, w, c, b, c.body + if layout == "NHCWB16": + n = stmt + h = n.body + cb = h.body + w = cb.body + b = w.body + return n, h, w, cb, b, b.body + return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py new file mode 100644 index 000000000000..1f021ed6046a --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This source will contain code to convert TIR, as produced by +the Relay to TIR compilation process, to Vela API calls to +generate command stream. +""" +from typing import NamedTuple +from enum import auto +from enum import Enum +import numpy as np +import ethosu.vela.api as vapi + +import tvm +from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu.tir import spec + + +class BufferType(Enum): + """The buffer types the codegen supports""" + + constant = auto() + input_or_output = auto() + scratch = auto() + input = auto() + output = auto() + + +class BufferInfo(NamedTuple): + """A data structure to hold metadata of the buffer""" + + # If the buffer holds constants, the values will contain that otherwise None + values: np.ndarray + shape: tvm.ir.container.Array + dtype: np.dtype + btype: BufferType + + +def extract_buffer_info(mod, param_dict): + """ + This function is to read the tvm.IRModule that + contains Relay to TIR compiled IRModule. Thereafter, + this will extract the buffer information as the shape + and constant data (if any). + + Parameters + ---------- + mod : tvm.IRModule + The NPU TIR IRModule. + param_dict : dict + A dictionary containing param idx --> const numpy.NDArray + Returns + ------- + dict + a dictionary of buffer names --> BufferInfo + """ + buffer_info = dict() + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + for idx, const_data in param_dict.items(): + param = primfunc.params[idx] + buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + const_data, const_data.shape, const_data.dtype, BufferType.constant + ) + + for param in primfunc.params: + if primfunc.buffer_map[param].data not in buffer_info.keys(): + buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + None, + primfunc.buffer_map[param].shape, + primfunc.buffer_map[param].dtype, + BufferType.input_or_output, + ) + + def populate_allocate_buffer_info(stmt): + if isinstance(stmt, tvm.tir.stmt.Allocate): + allocate = stmt + buffer_info[allocate.buffer_var] = BufferInfo( + None, + allocate.extents, + allocate.dtype, + BufferType.scratch, + ) + + tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) + + return buffer_info + + +def _convert_clip_bounds(npu_op): + """ + This function will convert the min and max value + of clip activations to non quantized floats as + expected by the API. + Parameters + ---------- + npu_op : ethosu.vela.api.NpuBlockOperation + """ + clip_min_quant = npu_op.activation.min + clip_max_quant = npu_op.activation.max + clip_min_actual = ( + clip_min_quant - npu_op.ofm.quantization.zero_point + ) * npu_op.ofm.quantization.scale_f32 + clip_max_actual = ( + clip_max_quant - npu_op.ofm.quantization.zero_point + ) * npu_op.ofm.quantization.scale_f32 + npu_op.activation.min = clip_min_actual + npu_op.activation.max = clip_max_actual + + +def translate_ethosu_conv2d(tir_extern_call): + """This function will translate a tir extern_call + as produced by Relay to TIR compilation. + Parameters + ---------- + tir_extern_call : tvm.tir.Call + This should be an tir external call that has a agreed upon ordering + for TIR Compiler. See Serial2DConvolution in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. + + Returns + ------- + ethosu.vela.api.NpuConv2DOperation + The vela object containing the params of ethosu_conv2d + weights_zero_point : int + The zero point of the weights + """ + # We skip the first element as it is the extern_call function name + serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_extern_call.args[1:]) + return _create_npu_op_conv2d(serial_object) + + +def _create_npu_op_conv2d(serial_2d_convolution): + """This is a helper function to capture a list + of arguments to create Vela NpuConv2DOperation object + """ + npu_conv2d_op = vapi.NpuConv2DOperation() + npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) + npu_conv2d_op.ofm = _create_npu_feature_map(serial_2d_convolution.ofm) + npu_conv2d_op.kernel = _create_npu_kernel(serial_2d_convolution.kernel) + npu_conv2d_op.weights = [_create_npu_address_range(serial_2d_convolution.weight)] + weights_zero_point = np.int64(serial_2d_convolution.weight_zero_point.value) + npu_conv2d_op.biases = [_create_npu_address_range(serial_2d_convolution.scale_bias)] + npu_conv2d_op.padding = _create_npu_padding(serial_2d_convolution.padding) + + npu_conv2d_op.activation = _create_npu_activation(serial_2d_convolution.activation) + if ( + npu_conv2d_op.activation + and npu_conv2d_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_conv2d_op) + + npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) + target_accel_type = vela_api.get_target_accel_type() + block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) + npu_conv2d_op.block_config = block_config + weights_shape_ohwi = [ + npu_conv2d_op.ofm.shape.depth, + npu_conv2d_op.kernel.height, + npu_conv2d_op.kernel.width, + npu_conv2d_op.ifm.shape.depth, + ] + npu_conv2d_op.block_traversal = vela_api.calculate_block_traversal_mode( + is_depthwise=False, + weights_shape_ohwi=weights_shape_ohwi, + ifm_bitdepth=npu_conv2d_op.ifm.data_type.size_in_bits(), + ) + return npu_conv2d_op, weights_zero_point + + +def _create_npu_feature_map(serial_feature_map): + """This is a helper function to capture a list + of arguments to create Vela NpuFeatureMap object + """ + layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16} + datatype_map = { + "uint8": vapi.NpuDataType.UINT8, + "int8": vapi.NpuDataType.INT8, + "uint16": vapi.NpuDataType.UINT16, + "int16": vapi.NpuDataType.INT16, + "int32": vapi.NpuDataType.INT32, + } + layout = str(serial_feature_map.layout.value) + data_type = str(serial_feature_map.data_type.value) + assert layout in layout_map.keys() + assert data_type in datatype_map.keys() + nfm = vapi.NpuFeatureMap() + nfm.data_type = datatype_map[data_type] + nfm.shape = vapi.NpuShape3D( + int(serial_feature_map.height.value), + int(serial_feature_map.width.value), + int(serial_feature_map.channels.value), + ) + nfm.tiles = vapi.NpuTileBox( + int(serial_feature_map.tile_height_0.value), + int(serial_feature_map.tile_height_1.value), + int(serial_feature_map.tile_width_0.value), + [ + serial_feature_map.tile_address_0, + serial_feature_map.tile_address_1, + serial_feature_map.tile_address_2, + serial_feature_map.tile_address_3, + ], + ) + nfm.quantization = _create_npu_quantization( + serial_feature_map.scale, serial_feature_map.zero_point + ) + nfm.layout = layout_map[layout] + nfm.strides = vapi.NpuShape3D( + int(serial_feature_map.stride_h.value), + int(serial_feature_map.stride_w.value), + int(serial_feature_map.stride_c.value), + ) + return nfm + + +def _create_npu_kernel(serial_kernel): + """This is a helper function to capture a list + of arguments to create Vela NpuKernel object + """ + nknl = vapi.NpuKernel( + w=int(serial_kernel.width.value), + h=int(serial_kernel.height.value), + stride_x=int(serial_kernel.stride_w.value), + stride_y=int(serial_kernel.stride_h.value), + dilation_x=int(serial_kernel.dilation_w.value), + dilation_y=int(serial_kernel.dilation_h.value), + ) + return nknl + + +def _create_npu_address_range(serial_address_range): + """This is a helper function to capture a list + of arguments to create Vela NpuAddressRange object + """ + addr_range = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_address_range.address, + length=int(serial_address_range.length.value), + ) + return addr_range + + +def _create_npu_quantization( + scale, + zero_point, +): + """This is a helper function to capture a list + of arguments to create Vela NpuQuantization object + """ + # Scale could be an ndarray if per-channel quantization is available + if not isinstance(scale, tvm.tir.expr.Load): + if isinstance(scale.value, float): + scale = np.single(scale.value) + else: + assert isinstance(scale.value.value, float) + scale = np.single(scale.value.value) + q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value) + return q_params + + +def _create_npu_weights_zero_point( + zero_point, +): + """This is a helper function to capture the weights zero point""" + return zero_point.value + + +def _create_npu_padding(serial_padding): + """This is a helper function to capture a list + of arguments to create Vela NpuPadding object""" + padding = vapi.NpuPadding( + top=int(serial_padding.top.value), + left=int(serial_padding.left.value), + bottom=int(serial_padding.bottom.value), + right=int(serial_padding.right.value), + ) + return padding + + +def _create_npu_activation(serial_activation): + """This is a helper function to capture a list + of arguments to create Vela NpuActivation object""" + if serial_activation.op == "NONE": + return None + if ( + serial_activation.op == "CLIP" + and serial_activation.clip_min == 0 + and serial_activation.clip_max == 0 + ): + return None + op_map = { + "CLIP": vapi.NpuActivationOp.NONE_OR_RELU, + "TANH": vapi.NpuActivationOp.TANH, + "SIGMOID": vapi.NpuActivationOp.SIGMOID, + } + op = str(serial_activation.op.value) + assert op in op_map.keys() + act_op = vapi.NpuActivation(op_map[op]) + act_op.min = int(serial_activation.clip_min.value) + act_op.max = int(serial_activation.clip_max.value) + return act_op + + +def _create_npu_resampling_mode( + mode, +): + """This is a helper function to capture a list + of arguments to create Vela NpuResamplingMode object""" + mode_map = { + "NONE": vapi.NpuResamplingMode.NONE, + "NEAREST": vapi.NpuResamplingMode.NEAREST, + "TRANSPOSE": vapi.NpuResamplingMode.TRANSPOSE, + } + mode = str(mode.value) + assert mode in mode_map.keys() + return mode_map[mode] diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index e9d89d33e6f0..0919d3fe7a5f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -21,6 +21,7 @@ Refer to the description inside such functions """ +from inspect import signature from enum import Enum from typing import Union, Tuple, Dict, Optional import numpy as np # type: ignore @@ -138,6 +139,12 @@ def round_up(a: int, b: int) -> int: return ((a + b - 1) // b) * b +def get_accelerator_config(): + """Get the variant of the accelerator to compile for""" + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + return compiler_attrs.accelerator_config + + # pylint: disable=unused-argument def partition_for_ethosu( mod: tvm.ir.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None, **opts @@ -173,6 +180,13 @@ def partition_for_ethosu( return mod +def get_arg_count(func): + """Helper function to get the number of + arguments in a python function""" + sig = signature(func) + return len(sig.parameters) + + def get_dim_value(layout: str, dim: int): """This is a helper function to retrieve the value of the dimension given the shape and the layout diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 72ae18123b3d..be011bd73359 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -28,6 +28,7 @@ from ethosu.vela import api as vapi # type: ignore from tvm.relay.backend.contrib.ethosu import util # type: ignore +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs # pylint: disable=invalid-name logger = logging.getLogger("Ethos-U") @@ -111,6 +112,53 @@ def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> return max_area_depth_block_configs[0] +def encode_weights(tir_extern_call, values, accel_type): + """This is an API function to compress weights by passing + a tir_extern_call to NPU Convolution operation and values. + + Parameters + ---------- + tir_extern_call : tvm.tir.Call + tir_extern_call to NPU Convolution operation + values : numpy.ndarray + The constant flattened weight data in OHWI layout + accel_type : ethosu.vela.api.NpuAccelerator + The NPU accelerator variant + + Returns + ------- + bytearray + Compressed weights + """ + supported_ops = ["ethosu_conv2d"] + op = str(tir_extern_call.args[0].value) + assert op in supported_ops + npu_op, weights_zero_point = tirtocs.translate_ethosu_conv2d(tir_extern_call) + block_config = get_optimal_block_config(npu_op, accel_type) + # The weight layout is assumed to be flat OHWI, always. + assert len(values.shape) == 1 + shape_ohwi = ( + npu_op.ofm.shape.depth, + npu_op.kernel.height, + npu_op.kernel.width, + npu_op.ifm.shape.depth, + ) + assert values.size == np.prod(shape_ohwi) + values = np.reshape(values, shape_ohwi) + return compress_weights( + weights=values, + weights_zp=weights_zero_point, + # The weight layout is assumed to be OHWI, always. + weights_layout="OHWI", + ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(), + block_depth=block_config.depth, + dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), + accel_type=accel_type, + # TODO(@manupa-arm): change this when we support depthwise + is_depthwise=False, + ) + + def compress_weights( weights: np.ndarray, weights_zp: int, diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc new file mode 100644 index 000000000000..6a87d11d5d6a --- /dev/null +++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../op/make_op.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes to store the compiler options for Arm(R) Ethos(TM)-U NPU. */ +struct EthosUCompilerConfigNode : public tvm::AttrsNode { + String accelerator_config; + + TVM_DECLARE_ATTRS(EthosUCompilerConfigNode, "ext.attrs.EthosUCompilerConfigNode") { + TVM_ATTR_FIELD(accelerator_config) + .describe( + "The class of Arm(R) Ethos(TM)-U NPU; possible values = {ethos-u55-32, ethos-u55-64, " + "ethos-u55-128, ethos-u55-256}") + .set_default("ethos-u55-256"); + } +}; + +class EthosUCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(EthosUCompilerConfig, Attrs, EthosUCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(EthosUCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.ethosu.options", EthosUCompilerConfig); + +auto GetCompilerAttrs() { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethosu.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + return cfg; +} +TVM_REGISTER_GLOBAL("relay.ext.ethosu.get_compiler_attrs").set_body_typed(GetCompilerAttrs); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/to_te_graph.cc b/src/relay/backend/contrib/ethosu/to_te_graph.cc new file mode 100644 index 000000000000..9646c39da089 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/to_te_graph.cc @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/to_te_graph.cc + * \brief Lower a Relay function to a TE graph. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../compile_engine.h" +#include "../../utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! \brief Node container to represent a Tensor Expression graph. */ +class TEGraphNode : public Object { + public: + /* \brief The inputs to the graph */ + tvm::Array inputs; + /* \brief The outputs to the graph */ + tvm::Array outputs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + } + + static constexpr const char* _type_key = "relay.TEGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object); +}; + +class TEGraph : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode); +}; + +TVM_REGISTER_NODE_TYPE(TEGraphNode); + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()); + ICHECK_GE(pval[0], std::numeric_limits::min()); + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + +class RelayToTE : public backend::MemoizedExprTranslator> { + public: + RelayToTE() = default; + + TEGraph Lower(const Function& prim_func) { + auto graph_node = make_object(); + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + graph_node->inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + graph_node->inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + graph_node->outputs = this->VisitExpr(prim_func->body); + return TEGraph(graph_node); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + LoweredOutput lowered_out = + (*flower_call)(GetRef(call_node), inputs, tvm::Target("llvm")); + outputs = lowered_out->outputs; + + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expect output to be a tuple type"; + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } +}; + +TVM_REGISTER_GLOBAL("relay.backend.contrib.ethosu.LowerToTE") + .set_body_typed([](Function prim_func) { return RelayToTE().Lower(prim_func); }); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py new file mode 100644 index 000000000000..fc795c066cb6 --- /dev/null +++ b/tests/python/contrib/test_ethosu/infra.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module provides infrastructure to verify the correctness of +the command stream produced. + +Currently it will invoke vela to generate a vela-optimized tflite +in which the command stream is contained as a custom operator. +This class include methods to parse the custom operator to extract +the command stream and perform an equivalency check for single operator +test cases. +""" + +import numpy +from enum import IntEnum + +import tvm +from tvm import relay +import tvm.relay.backend.contrib.ethosu.op as ethosu_ops +from tvm.topi.nn.utils import get_pad_tuple + + +class AttachType(IntEnum): + kGroupRoot = 1 + kInline = 2 + kInlinedAlready = 3 + kScope = 4 + kScanUpdate = 5 + + +def generate_weights_data(shape, dtype): + size = 1 + for dim in shape: + size *= dim + return (numpy.arange(size) % 255).reshape(shape).astype(dtype) + + +def get_convolutional_args(call, include_buffers=False, remove_constants=False): + """A method to extract the arguments from conv2d or depthwise2d extern call.""" + args = call.args + conv_args = [] + remove_indices = [0] + + if remove_constants: + remove_indices += [41, 42, 44, 45] + + for i, arg in enumerate(args): + if i in remove_indices: + continue + elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + conv_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + conv_args.append(arg.index) + else: + conv_args.append(arg) + + return conv_args + + +def make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + weight_dtype="int8", +): + # conv params + weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels) + padding = get_pad_tuple(padding, kernel_shape) + + scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") + scale_bias = relay.const(scale_bias_data, dtype="uint8") + weight_data = generate_weights_data(weight_shape, "int8") + weight = relay.const(weight_data, dtype=weight_dtype) + conv = ethosu_ops.ethosu_conv2d( + ifm, + weight, + scale_bias, + lut=relay.const([], dtype="int8"), + ifm_scale=0.5, + ifm_zero_point=10, + weight_zero_point=12, + ofm_scale=0.25, + ofm_zero_point=14, + kernel_shape=kernel_shape, + ofm_channels=ofm_channels, + strides=strides, + padding=padding, + dilation=dilation, + activation=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return conv diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py new file mode 100644 index 000000000000..812f68513c31 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument + +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu import util + + +def test_compiler_attr(): + config = { + "accelerator_config": "ethos-u55-32", + } + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}): + with tvm.target.Target("c -device=micro_dev"): + assert util.get_accelerator_config() == config["accelerator_config"] + + +def test_compiler_attr_default(): + default_config = { + "accelerator_config": "ethos-u55-256", + } + with tvm.transform.PassContext(opt_level=3): + with tvm.target.Target("c -device=micro_dev"): + assert util.get_accelerator_config() == default_config["accelerator_config"] + + +if __name__ == "__main__": + test_compiler_attr() + test_compiler_attr_default() diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py new file mode 100644 index 000000000000..ae649c6beeac --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir + + +def test_lower_to_tir(): + data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") + weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") + p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") + conv = relay.nn.conv2d( + data, + weight, + kernel_size=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + multiply = relay.multiply(relay.const(-22, dtype="int32"), p2) + tile = relay.tile(multiply, reps=(1, 1, 1, 1001)) + subtract = relay.subtract(conv, tile) + func = subtract + expr = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + lower_to_tir(mod["main"]) + + +if __name__ == "__main__": + test_lower_to_tir() diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py new file mode 100644 index 000000000000..05d8d1c71618 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +from tvm import tir +from tvm import script +from tvm import relay +from tvm.script import ty +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +import pytest + +from infra import make_ethosu_conv2d + + +# fmt: off +@tvm.script.tir +class WeightStreamOnly: + def main(placeholder: ty.handle, ethosu_write: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_5, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_3, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_1, [128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([128], "uint8", "global") + placeholder_d_global = tir.allocate([32], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 128, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 128, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_7.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_weight_stream_only(): + def _planner(te_graph, const_dict, sch): + weights = te_graph.inputs[1] + bias = te_graph.inputs[2] + out = te_graph.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 2) + cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) + cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d]) + sch[cache_weights].compute_at(sch[out], co) + sch[cache_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + ifm, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func, cascader=_planner) + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = WeightStreamOnly() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = {2: 128, 3: 32, 4: 112, 5: 32, 6: 112, 7: 32, 8: 112, 9: 32} + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +# fmt: off +@tvm.script.tir +class DirectReadOnly: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([4096], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 592, 12, tir.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 160, 12, tir.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_direct_read_only(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + conv2 = make_ethosu_conv2d( + conv1, + 16, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = DirectReadOnly() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = {1: 592, 2: 160, 3: 160, 4: 80} + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +# fmt: off +@tvm.script.tir +class MixedRead: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_8 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_9 = tir.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([4096], "int8", "global") + placeholder_global = tir.allocate([80], "uint8", "global") + placeholder_d_global = tir.allocate([32], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_6.data, 0), 592, 12, tir.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_8.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_9.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_mixed_read(): + def _planner(te_graph, const_dict, sch): + weight = te_graph.inputs[4] + scale_bias = te_graph.inputs[5] + out = te_graph.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 2) + cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + sch[cache_weight].compute_at(sch[out], co) + sch[cache_scale_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + conv2 = make_ethosu_conv2d( + conv1, + 16, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func, cascader=_planner) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = MixedRead() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = { + 1: 592, + 2: 160, + 4: 80, + 5: 32, + 6: 80, + 7: 32, + 8: 80, + 9: 32, + 10: 80, + 11: 32, + } + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_extract_constants.py b/tests/python/contrib/test_ethosu/test_extract_constants.py new file mode 100644 index 000000000000..48266b54a605 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_extract_constants.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import extract_constants + +import numpy as np + + +def test_extract_constants_single(): + def _get_func(): + var_input = relay.var("data", shape=(10, 10), dtype="uint8") + const_data = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_input = relay.const(const_data, dtype="uint8") + out = relay.add(var_input, const_input) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func, const_input + + def _expected(): + var_input1 = relay.var("data", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("p1", shape=(10, 10), dtype="uint8") + out = relay.add(var_input1, var_input2) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func, const = _get_func() + new_func, const_dict = extract_constants(func) + assert tvm.ir.structural_equal(new_func, _expected()) + assert 1 in const_dict + assert (const_dict[1] == const.data.asnumpy()).all() + + +def test_extract_constants_multi(): + def _get_func(): + var_input1 = relay.var("data1", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("data2", shape=(10, 10), dtype="uint8") + const_data_1 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_2 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_3 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_4 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_input_1 = relay.const(const_data_1, dtype="uint8") + const_input_2 = relay.const(const_data_2, dtype="uint8") + const_input_3 = relay.const(const_data_3, dtype="uint8") + const_input_4 = relay.const(const_data_4, dtype="uint8") + out = relay.add(var_input1, var_input2) + out = relay.add(out, const_input_1) + out = relay.add(out, const_input_2) + out = relay.add(out, const_input_3) + out = relay.add(out, const_input_4) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func, [const_input_1, const_input_2, const_input_3, const_input_4] + + def _expected(): + var_input1 = relay.var("data1", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("data2", shape=(10, 10), dtype="uint8") + var_input3 = relay.var("p1", shape=(10, 10), dtype="uint8") + var_input4 = relay.var("p2", shape=(10, 10), dtype="uint8") + var_input5 = relay.var("p3", shape=(10, 10), dtype="uint8") + var_input6 = relay.var("p4", shape=(10, 10), dtype="uint8") + out = relay.add(var_input1, var_input2) + out = relay.add(out, var_input3) + out = relay.add(out, var_input4) + out = relay.add(out, var_input5) + out = relay.add(out, var_input6) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func, consts = _get_func() + new_func, const_dict = extract_constants(func) + assert tvm.ir.structural_equal(new_func, _expected()) + for i, const in enumerate(consts): + assert i + 2 in const_dict + assert (const_dict[i + 2] == consts[i].data.asnumpy()).all() + + +if __name__ == "__main__": + test_extract_constants_single() + test_extract_constants_multi() diff --git a/tests/python/contrib/test_ethosu/test_lower_to_te.py b/tests/python/contrib/test_ethosu/test_lower_to_te.py new file mode 100644 index 000000000000..18bde7ebd7c0 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_lower_to_te.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te +from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +import tvm.relay.backend.contrib.ethosu.op as ethosu_ops + + +def test_ethosu_conv2d(): + ifm = relay.var("ifm", shape=(1, 10, 20, 30), dtype="uint8") + weight = relay.var("weight", shape=(40, 3, 3, 30), dtype="uint8") + scale_bias = relay.var("scale_bias", shape=(40, 10), dtype="uint8") + lut = relay.var("lut", shape=(), dtype="uint8") + conv = ethosu_ops.ethosu_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale=0.5, + ifm_zero_point=10, + weight_zero_point=12, + ofm_scale=0.25, + ofm_zero_point=14, + ofm_channels=40, + padding=(1, 1, 1, 1), + kernel_shape=(3, 3), + strides=(1, 1), + dilation=(1, 1), + ) + expr = relay.Function(relay.analysis.free_vars(conv), conv) + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + lowered = lower_to_te(mod["main"]) + assert len(lowered.outputs) == 1 + assert len(lowered.inputs) == 4 + conv2d_compute = Convolution2DCompute.from_output(lowered.outputs[0]) + assert conv2d_compute.conv2d.name == "ethosu_conv2d" + input_shapes = set() + for inp in lowered.inputs: + input_shapes.add(tuple([x.value for x in inp.shape])) + assert input_shapes == {(40, 10), (1, 10, 20, 30), (40, 3, 3, 30), ()} + + +if __name__ == "__main__": + test_ethosu_conv2d() diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py new file mode 100644 index 000000000000..b8889e25fe9c --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -0,0 +1,547 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.script +from tvm.script import tir, ty +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader +from infra import make_ethosu_conv2d, get_convolutional_args + +import pytest + + +@pytest.mark.parametrize( + "trial", + [ + [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC"], + [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC"], + [(1, 8, 2, 8, 16), 18, 12, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHCWB16", "NHWC"], + [(1, 7, 9, 4), 4, 71, (3, 2), (1, 2), (2, 1), (1, 2), "CLIP", "NHWC", "NHCWB16"], + [(1, 4, 12, 9, 16), 182, 67, (2, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, 71, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHCWB16"], + [ + (1, 13, 12, 19, 16), + 182, + 67, + (1, 3), + (5, 3), + (2, 1), + (2, 1), + "CLIP", + "NHCWB16", + "NHCWB16", + ], + ], +) +def test_conv2d_single(trial): + def _get_func( + ifm_shape, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv = make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + # TODO(@mbaret) Fix the tests for these known failures + # These are anticipated to actually be correct, just a testing issue to do with + # equivalent convolutions. + known_failures = [ + [(1, 3, 12, 9, 16), 182, 67, (2, 3), (1, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 2, 12, 9, 16), 182, 67, (1, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + ] + func = _get_func(*trial) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_convolutional_args(stmt, remove_constants=True)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + ( + ifm_shape, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) = trial + dilated_kernel_h = (kernel_shape[0] - 1) * dilation[0] + 1 + dilated_kernel_w = (kernel_shape[1] - 1) * dilation[1] + 1 + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] + ifm_stride_h = ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[2] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[3] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ofm_channels if ofm_width > 1 else 1 + ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1) + + answer = [ + "int8", + ifm_shape[1], + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + ifm_channels, + ifm_shape[1], + 0, + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + 0, + 0, + 0, + 0, + 0.5, + 10, + ifm_layout, + ifm_stride_h, + ifm_stride_w, + ifm_stride_c, + "int8", + ofm_height, + ofm_width, + ofm_channels, + ofm_height, + 0, + ofm_width, + 0, + 0, + 0, + 0, + 0.25, + 14, + ofm_layout, + ofm_stride_h, + ofm_stride_w, + ofm_stride_c, + kernel_shape[1], + kernel_shape[0], + strides[1], + strides[0], + dilation[1], + dilation[0], + 12, + padding[0], + padding[1], + padding[0], + padding[1], + activation, + 10 if activation == "CLIP" else 0, + 100 if activation == "CLIP" else 0, + "NONE", + ] + assert data[0] == answer, data[0] + + +# fmt: off +@tvm.script.tir +class Conv2dDoubleCascade1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([1024], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 12), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 32), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_3, [2608], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([1536], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 48), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade3: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([2560], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, tir.load("int8", placeholder_5.data, 192), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, tir.load("int8", placeholder_5.data, 576), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, tir.load("int8", ethosu_write_1.data, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade4: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_1, [1456], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 2, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([2304], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 384), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 384), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 256), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [ + Conv2dDoubleCascade1(), + (1, 8, 8, 3), + 3, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + "NHWC", + (1, 8, 4, 8), + ], + [ + Conv2dDoubleCascade2(), + (1, 8, 8, 3), + 3, + 32, + 8, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + "NHWC", + (1, 4, 8, 8), + ], + [ + Conv2dDoubleCascade3(), + (1, 16, 16, 3), + 3, + 32, + 8, + (3, 2), + (2, 1), + (1, 2), + (1, 2), + "NHWC", + (1, 8, 4, 8), + ], + [ + Conv2dDoubleCascade4(), + (1, 8, 1, 8, 16), + 3, + 35, + 26, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + "NHCWB16", + (1, 4, 2, 8, 16), + ], + ], +) +def test_conv2d_double_cascade(trial): + def _get_func( + ifm_shape, + ifm_channels, + mid_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + ifm_channels, + mid_channels, + kernel_shape, + padding, + strides, + dilation, + "NONE", + layout, + layout, + ) + conv2 = make_ethosu_conv2d( + conv1, + mid_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + "NONE", + layout, + layout, + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params[:-1]) + mod, _ = lower_to_tir(func, cascader=total_cascader(params[-1])) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# fmt: off +@tvm.script.tir +class Conv2dInlineCopy1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, tir.load("int8", placeholder_3.data, 120), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 848, 12, tir.load("uint8", buffer_1.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineCopy2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [656], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, tir.load("int8", placeholder_3.data, 146), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 656, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [Conv2dInlineCopy1(), (1, 10, 12, 8), (0, 1, 3, 0), (1, 9, 11, 4)], + [Conv2dInlineCopy2(), (1, 7, 9, 5), (0, 3, 2, 1), (1, 6, 7, 4)], + ], +) +def test_conv2d_inline_copy(trial): + def _get_func(ifm_shape, lower, upper, ofm_channels=16): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + sliced = relay.strided_slice(ifm, lower, upper) + conv = make_ethosu_conv2d( + sliced, upper[3] - lower[3], ofm_channels, (3, 3), (1, 1), (1, 1), (1, 1) + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params) + mod, _ = lower_to_tir(func) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# fmt: off +@tvm.script.tir +class Conv2dInlineReshape1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [4, 6, 8, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 24, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape3: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape4: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [Conv2dInlineReshape1(), (4, 6, 8, 1), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape2(), (1, 4 * 6, 8), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape3(), (4 * 6 * 8, 1), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape4(), (4 * 6 * 8,), (1, 8, 6, 4), "NHWC"], + ], +) +def test_conv2d_inline_reshape(trial): + def _get_func(ifm_shape, reshaped, ifm_layout): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + ifm_reshaped = relay.reshape(ifm, reshaped) + conv = make_ethosu_conv2d( + ifm_reshaped, reshaped[3], 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", ifm_layout + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params) + mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# TODO(@mbaret) Fix this case +@pytest.mark.xfail(raises=TypeError, strict=True) +def test_conv2d_big_pad(): + def _get_func(): + ifm_shape = (1, 2, 2, 8) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv = make_ethosu_conv2d(ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), "NHWC") + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16))) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py new file mode 100644 index 000000000000..911e150593cc --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.script +from tvm.script import tir, ty +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants + +from infra import make_ethosu_conv2d + + +# fmt: off +@tvm.script.tir +class ReferenceModule: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([304], "uint8", "global") + placeholder_d_global = tir.allocate([80], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 304, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 304, 12, tir.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_copy(): + def _get_func(): + data = relay.var("data", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + data, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=copy_constants()) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = ReferenceModule() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py new file mode 100644 index 000000000000..bef665fe2444 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm import te, topi +from tvm.relay.backend.contrib.ethosu.tir.scheduler import ( + tile_nd, + schedule_pragmas, + inline_no_ops, + total_cascader, + copy_constants, + schedule_cache_reads, +) +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants +from infra import AttachType, make_ethosu_conv2d + + +class TestTEGraph: + def __init__(self, inputs, outputs): + self.inputs = inputs + self.outputs = outputs + + +def test_tile_nd(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + out = topi.nn.relu(input) + sch = te.create_schedule([out.op]) + outer_iters, inner_iters = tile_nd(sch, out, (3, 4)) + assert tuple(sch[out].leaf_iter_vars) == (*outer_iters, *inner_iters) + + +def test_schedule_pragmas(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + out = te.compute( + (12, 12), + lambda i, j: input[i, j], + attrs={ + "op": "unity", + "info": 1, + }, + ) + sch = te.create_schedule([out.op]) + sch[out].split(out.op.axis[0], 3) + schedule_pragmas(sch) + iter_var = sch[out].leaf_iter_vars[1] + assert list(sch[out].iter_var_attrs[iter_var].pragma_keys) == ["op", "info"] + assert list(sch[out].iter_var_attrs[iter_var].pragma_values) == ["unity", 1] + + +def test_schedule_pragmas_for_const(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + const = te.compute((), lambda: 2) + add = topi.add(input, const) + sch = te.create_schedule([add.op]) + schedule_pragmas(sch) + + +def test_inline_no_ops(): + input = relay.var("input", shape=(12, 12), dtype="uint8") + slice = relay.strided_slice(input, [0, 0], [6, 6]) + relu1 = relay.nn.relu(slice) + reshape = relay.reshape(relu1, (36,)) + relu2 = relay.nn.relu(reshape) + func = relay.Function(relay.analysis.free_vars(relu2), relu2) + func = run_opt_pass(func, relay.transform.InferType()) + + te_graph = lower_to_te(func) + sch = te.create_schedule([te_graph.outputs[0].op]) + inline_no_ops(te_graph, sch) + reshape_tensor = te_graph.outputs[0].op.input_tensors[0] + slice_tensor = reshape_tensor.op.input_tensors[0].op.input_tensors[0] + assert sch[reshape_tensor].attach_type == AttachType.kInline + assert sch[slice_tensor].attach_type == AttachType.kInline + + +def test_total_cascader(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + relu1 = topi.nn.relu(input) + relu2 = topi.nn.relu(relu1) + relu3 = topi.nn.relu(relu2) + sch = te.create_schedule([relu3.op]) + cascader = total_cascader((4, 4)) + cascader(TestTEGraph([input], [relu3]), {}, sch) + assert sch[relu1].attach_type == AttachType.kScope + assert sch[relu2].attach_type == AttachType.kScope + assert sch[relu3].attach_type == AttachType.kGroupRoot + # Check that the attaches are at the correct iter var + assert sch[relu1].attach_ivar == sch[relu3].leaf_iter_vars[1] + assert sch[relu2].attach_ivar == sch[relu3].leaf_iter_vars[1] + + +def test_copy_constants(): + ifm_a = relay.var("IFM_A", shape=(1, 26, 26, 32), dtype="int8") + conv_a = make_ethosu_conv2d(ifm_a, 32, 8, (3, 3), (0, 0), (1, 1), (1, 1)) + conv_b = make_ethosu_conv2d(conv_a, 8, 4, (1, 1), (0, 0), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv_b), conv_b) + func = run_opt_pass(func, relay.transform.InferType()) + + func, const_dict = extract_constants(func) + te_graph = lower_to_te(func) + + sch = te.create_schedule([te_graph.outputs[0].op]) + planner = copy_constants() + planner(te_graph, const_dict, sch) + assert len(sch.stages) == 21 + assert ".global" in sch.stages[5].op.name + assert ".global" in sch.stages[7].op.name + assert ".global" in sch.stages[15].op.name + assert ".global" in sch.stages[17].op.name + + +def test_schedule_cache_reads(): + a = te.placeholder((12, 12), dtype="uint8", name="a") + b = te.placeholder((12, 12), dtype="uint8", name="b") + add = topi.add(a, b) + sch = te.create_schedule([add.op]) + cr = sch.cache_read(b, "global", [add]) + schedule_cache_reads(sch) + assert len(sch.stages) == 4 + assert len(sch[cr].leaf_iter_vars) == 1 + iv = sch[cr].leaf_iter_vars[0] + assert list(sch[cr].iter_var_attrs[iv].pragma_keys) == ["op"] + assert list(sch[cr].iter_var_attrs[iv].pragma_values) == ["ethosu_copy"] + + +if __name__ == "__main__": + test_tile_nd() + test_schedule_pragmas() + test_schedule_pragmas_for_const() + test_inline_no_ops() + test_total_cascader() + test_copy_constants() + test_schedule_cache_reads() diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index d9b22d10e9e2..a86dd919d5ca 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -24,7 +24,9 @@ import tvm from tvm import tir from tvm.script import ty +from tvm.tir import stmt_functor from tvm.relay.backend.contrib.ethosu import vela_api +import tvm.relay.backend.contrib.ethosu.tir_to_cs_translator as tirtocs ACCEL_TYPES = [ vapi.NpuAccelerator.Ethos_U55_256, @@ -451,5 +453,104 @@ def create_mock(test_vec): verify(_test_vec, mock_obj, packed_biases) +def extract_ethosu_conv2d_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + List of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_conv2d_calls = list() + + def populate_ethosu_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_conv2d" + ): + ethosu_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls) + return ethosu_conv2d_calls + + +@pytest.mark.parametrize( + "accel", + ACCEL_TYPES, +) +def test_encode_weights(accel): + test_vecs = [ + { + # Stimulus + "tir_module": Module1(), + "param_dict": { + 1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [48], "uint8"), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + "accel_type": accel, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.PART_KERNEL_FIRST, + }, + ] + + def create_mock(test_vec): + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights" + ) as mock_enc_w: + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_find_block_configs" + ) as mock_blk_cfg: + mock_blk_cfg.return_value = [vapi.NpuShape3D(8, 8, 8)] + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"]) + buffer_info = tirtocs.extract_buffer_info( + test_vec["tir_module"], test_vec["param_dict"] + ) + for ethosu_conv2d_call in ethosu_conv2d_calls: + npu_op, _ = tirtocs.translate_ethosu_conv2d(ethosu_conv2d_call) + weights = buffer_info[npu_op.weights[0].address.buffer_var][0] + vela_api.encode_weights(ethosu_conv2d_call, weights, accel) + return mock_enc_w + + def verify(test_vec, mock_enc_w): + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"]) + buffer_info = tirtocs.extract_buffer_info(test_vec["tir_module"], test_vec["param_dict"]) + for ethosu_conv2d_call in ethosu_conv2d_calls: + npu_op, w_zero_point = tirtocs.translate_ethosu_conv2d(ethosu_conv2d_call) + weights = buffer_info[npu_op.weights[0].address.buffer_var][0] + + assert mock_enc_w.call_args[1]["accelerator"] == accel + assert ( + mock_enc_w.call_args[1]["weights_volume"].flatten() + == weights.astype(np.int64) - w_zero_point + ).all() + assert mock_enc_w.call_args[1]["dilation_xy"] == ( + npu_op.kernel.dilation_x, + npu_op.kernel.dilation_y, + ) + assert mock_enc_w.call_args[1]["dilation_xy"] == ( + npu_op.kernel.dilation_x, + npu_op.kernel.dilation_y, + ) + assert mock_enc_w.call_args[1]["ifm_bitdepth"] == npu_op.ifm.data_type.size_in_bits() + assert mock_enc_w.call_args[1]["block_traversal"] == test_vec["block_traversal"] + + for _test_vec in test_vecs: + _mock_enc_w = create_mock(_test_vec) + verify(_test_vec, _mock_enc_w) + + if __name__ == "__main__": pytest.main([__file__]) From 09046dfb08d1d730fbf784c70b214d3456fc9d5c Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 23 Aug 2021 13:58:12 +0100 Subject: [PATCH 2/7] Fix Conv2D TIR type sensitivity Change-Id: I3741f9dd8bb5952590ff8c586f6b96e5c3a03795 --- .../backend/contrib/ethosu/te/convolution.py | 2 +- .../relay/backend/contrib/ethosu/te/dma.py | 4 +- .../backend/contrib/ethosu/tir/convolution.py | 18 +++---- .../relay/backend/contrib/ethosu/tir/utils.py | 48 +++++++++++++++++++ 4 files changed, 61 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 40015ac296a6..26f7ea979219 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -140,7 +140,7 @@ def conv2d_compute( ).astype(ifm.dtype) * weight[cc, rh, rw, rc].astype(ifm.dtype) # This is a trick to load 10 elements of the scale_bias at once, not accurate maths - + (scale_bias[cc, 0] * scale_bias[cc, 9]), + + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), axis=[rh, rw, rc], ), name="ethosu_conv2d", diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index d19c8c56f7c2..bf9a018ea855 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -59,7 +59,9 @@ def _pad(*indices): not_zero.append(indices[i] < tensor.shape[i] + pad_before[i]) if not_zero: not_zero = tvm.tir.all(*not_zero) - return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8")) + return tvm.tir.if_then_else( + not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype) + ) return tensor(*index_tuple) return _pad diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 69d0e457e33b..33fbdcd2b24f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -18,7 +18,7 @@ """Extract information from the convolution operators in TIR.""" import tvm from ..vela_api import SCALE_BIAS_LENGTH -from .utils import get_outer_loops, get_op_attrs, get_base_address +from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution @@ -53,9 +53,12 @@ def get_conv2d_params(stmt, producers, consumers): rh = inner rw = rh.body rc = rw.body - compute = rc.body.value.b - input_pointer = compute.a.a.buffer_var - output_pointer = rc.body.buffer_var + # loads = [output, input, weights, scale_bias, scale_bias] + loads = get_loads(rc.body) + # stores = [output] + stores = get_stores(rc.body) + input_pointer = loads[1].buffer_var + output_pointer = stores[0].buffer_var # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) @@ -69,17 +72,14 @@ def get_conv2d_params(stmt, producers, consumers): dilation_h=int(attrs["dilation_h"]), ) # Get scale_bias info - scale_bias_mul = compute.b - if isinstance(scale_bias_mul, tvm.tir.Cast): - scale_bias_mul = scale_bias_mul.value - scale_bias_load = scale_bias_mul.a + scale_bias_load = loads[3] scale_bias_base = get_base_address(scale_bias_load.index) serial_scale_bias = SerialAddressRange( address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), length=SCALE_BIAS_LENGTH * serial_ofm[3], ) # Get weight info - weight_load = compute.a.b + weight_load = loads[2] weight_base = get_base_address(weight_load.index) serial_weight = SerialAddressRange( address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 55db62edfa5a..7d6fd3bf82d8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -172,3 +172,51 @@ def get_outer_loops(stmt, layout): b = w.body return n, h, w, cb, b, b.body return None + + +def get_loads(stmt): + """Get the Load statements. + + Parameters + ---------- + stmt : tvm.tir.Stmt + The statement to get the Loads from. + + Returns + ------- + loads : list of tvm.tir.Load + The Loads found. + + """ + loads = [] + + def _visit(s): + if isinstance(s, tvm.tir.Load): + loads.append(s) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return loads + + +def get_stores(stmt): + """Get the Store statements. + + Parameters + ---------- + stmt : tvm.tir.Stmt + The statement to get the Stores from. + + Returns + ------- + stores : list of tvm.tir.Store + The Stores found. + + """ + stores = [] + + def _visit(s): + if isinstance(s, tvm.tir.Store): + stores.append(s) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return stores From dea5853be7b2fcff0d93d962be85e61e5d09a959 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 25 Aug 2021 12:25:29 +0100 Subject: [PATCH 3/7] Arm(R) Ethos(TM)-U NPU TIR passes and TE for Conv2D *fixing tests Change-Id: Id4a4c80f72ce29b98fc8b3954a1413c1c7fda500 --- .../test_ethosu/test_encode_constants.py | 6 +++--- .../test_ethosu/test_replace_conv2d.py | 20 +++++++++---------- .../contrib/test_ethosu/test_replace_copy.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 05d8d1c71618..3a3c4b5913cc 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -34,7 +34,7 @@ class WeightStreamOnly: def main(placeholder: ty.handle, ethosu_write: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_2 = tir.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -111,7 +111,7 @@ def _get_func(): class DirectReadOnly: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) @@ -172,7 +172,7 @@ def _get_func(): class MixedRead: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_2 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index b8889e25fe9c..382260fd53e0 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -194,7 +194,7 @@ def _visit(stmt): class Conv2dDoubleCascade1: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -214,7 +214,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dDoubleCascade2: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_2 = tir.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -234,7 +234,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dDoubleCascade3: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer = tir.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -256,7 +256,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dDoubleCascade4: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_1, [1456], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = tir.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) @@ -385,7 +385,7 @@ def _get_func( class Conv2dInlineCopy1: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_3 = tir.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) @@ -399,7 +399,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dInlineCopy2: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) placeholder_3 = tir.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -442,7 +442,7 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): class Conv2dInlineReshape1: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -457,7 +457,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dInlineReshape2: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -472,7 +472,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dInlineReshape3: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_3 = tir.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) @@ -487,7 +487,7 @@ def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.han class Conv2dInlineReshape4: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_3 = tir.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 911e150593cc..afa9b8e74ca3 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -31,7 +31,7 @@ class ReferenceModule: def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = tir.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_1 = tir.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) From 36bde1e526390dd11955c9896801db0478e37cf9 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 26 Aug 2021 10:30:36 +0100 Subject: [PATCH 4/7] Fix import guards for tests Change-Id: Iaee06017bd125d3040ce42182c4ccdb80d7fc946 --- .../python/contrib/test_ethosu/test_attr_passing.py | 6 +++--- tests/python/contrib/test_ethosu/test_compiler.py | 5 ++++- .../contrib/test_ethosu/test_encode_constants.py | 4 ++-- .../contrib/test_ethosu/test_extract_constants.py | 6 ++++-- tests/python/contrib/test_ethosu/test_lower_to_te.py | 6 +++--- .../contrib/test_ethosu/test_replace_conv2d.py | 4 ++-- .../python/contrib/test_ethosu/test_replace_copy.py | 2 ++ tests/python/contrib/test_ethosu/test_scheduler.py | 12 ++++-------- 8 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py index 812f68513c31..a2fbe1888d2a 100644 --- a/tests/python/contrib/test_ethosu/test_attr_passing.py +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +import pytest +pytest.importorskip("ethosu.vela") import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu import util @@ -40,5 +41,4 @@ def test_compiler_attr_default(): if __name__ == "__main__": - test_compiler_attr() - test_compiler_attr_default() + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py index ae649c6beeac..4df6311a230c 100644 --- a/tests/python/contrib/test_ethosu/test_compiler.py +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + +pytest.importorskip("ethosu.vela") import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -42,4 +45,4 @@ def test_lower_to_tir(): if __name__ == "__main__": - test_lower_to_tir() + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 3a3c4b5913cc..0e546ae2fd24 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np +import pytest +pytest.importorskip("ethosu.vela") import tvm from tvm import tir from tvm import script @@ -24,7 +25,6 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute -import pytest from infra import make_ethosu_conv2d diff --git a/tests/python/contrib/test_ethosu/test_extract_constants.py b/tests/python/contrib/test_ethosu/test_extract_constants.py index 48266b54a605..98094d8a4ed4 100644 --- a/tests/python/contrib/test_ethosu/test_extract_constants.py +++ b/tests/python/contrib/test_ethosu/test_extract_constants.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + +pytest.importorskip("ethosu.vela") import tvm from tvm import relay from tvm.relay.testing import run_opt_pass @@ -93,5 +96,4 @@ def _expected(): if __name__ == "__main__": - test_extract_constants_single() - test_extract_constants_multi() + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_lower_to_te.py b/tests/python/contrib/test_ethosu/test_lower_to_te.py index 18bde7ebd7c0..cabd68b4e8d2 100644 --- a/tests/python/contrib/test_ethosu/test_lower_to_te.py +++ b/tests/python/contrib/test_ethosu/test_lower_to_te.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +import pytest +pytest.importorskip("ethosu.vela") import tvm from tvm import relay -from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute import tvm.relay.backend.contrib.ethosu.op as ethosu_ops @@ -60,4 +60,4 @@ def test_ethosu_conv2d(): if __name__ == "__main__": - test_ethosu_conv2d() + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 382260fd53e0..96fe56d1778e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest +pytest.importorskip("ethosu.vela") import tvm import tvm.script from tvm.script import tir, ty @@ -24,8 +26,6 @@ from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader from infra import make_ethosu_conv2d, get_convolutional_args -import pytest - @pytest.mark.parametrize( "trial", diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index afa9b8e74ca3..222dccacc906 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import pytest + +pytest.importorskip("ethosu.vela") import tvm import tvm.script from tvm.script import tir, ty diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index bef665fe2444..b07f8ea7f48b 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm +import pytest + +pytest.importorskip("ethosu.vela") from tvm import relay from tvm.relay.testing import run_opt_pass from tvm import te, topi @@ -139,10 +141,4 @@ def test_schedule_cache_reads(): if __name__ == "__main__": - test_tile_nd() - test_schedule_pragmas() - test_schedule_pragmas_for_const() - test_inline_no_ops() - test_total_cascader() - test_copy_constants() - test_schedule_cache_reads() + pytest.main([__file__]) From df9ec83abad5701015cea50aa176b258b73dd473 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 9 Sep 2021 16:22:31 +0100 Subject: [PATCH 5/7] Fix typing failures with ignores Change-Id: I81513f112a42b93cfdd3bcaf8e8852dd60ffe9e9 --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 9 +++++---- python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py | 2 +- .../relay/backend/contrib/ethosu/tir_to_cs_translator.py | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 75a2b5b3362b..e2bd845aa7ea 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -16,7 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-argument """The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" -import numpy as np +from typing import Dict +import numpy as np # type: ignore import tvm from tvm.relay.backend.contrib.ethosu import vela_api @@ -216,7 +217,7 @@ def DivideConstants(const_dict): The purpose of this pass is to transform the IR into a form we can apply constant encoding to (which will compress weights and encode biases).""" - buffer_to_const = {} + buffer_to_const = {} # type: ignore new_buffers = [] new_consts = [] keep_buffers = set() @@ -253,7 +254,7 @@ def _ftransform(f, mod, ctx): new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"]) # Both the params and buffer map need updating for the newly introduced buffers - new_params = [] + new_params = [] # type: ignore new_buffer_map = {} for i, param in enumerate(f.params): buffer = f.buffer_map[param] @@ -299,7 +300,7 @@ def EncodeConstants(const_dict): pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} - accel_type = vela_api.get_target_accel_type() + accel_type = vela_api.get_target_accel_type() # type: ignore def _align_scale_bias(tir_extern_call, bias): """Align the scale_bias to 16 bytes.""" diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index fd52e1821cb6..5d9027bf2078 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -127,7 +127,7 @@ def copy_constants(): """ def _planner(te_graph, const_dict, sch): - planned = set() + planned = set() # type: ignore def _visit(tensor, reader): if tensor is not planned: diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 1f021ed6046a..ce9abcbd683d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -21,8 +21,8 @@ from typing import NamedTuple from enum import auto from enum import Enum -import numpy as np -import ethosu.vela.api as vapi +import numpy as np # type: ignore +import ethosu.vela.api as vapi # type: ignore import tvm from tvm.relay.backend.contrib.ethosu import vela_api @@ -165,7 +165,7 @@ def _create_npu_op_conv2d(serial_2d_convolution): _convert_clip_bounds(npu_conv2d_op) npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) - target_accel_type = vela_api.get_target_accel_type() + target_accel_type = vela_api.get_target_accel_type() # type: ignore block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) npu_conv2d_op.block_config = block_config weights_shape_ohwi = [ From 2f165fc72126e1724d58b4b7fc12374962283215 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 9 Sep 2021 17:12:11 +0100 Subject: [PATCH 6/7] Remove unused import Change-Id: I6596b62ab56e4ca8b31ef08293686f53f38454d2 --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index e2bd845aa7ea..1af44962c141 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name, unused-argument """The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" -from typing import Dict import numpy as np # type: ignore import tvm From 7f0e9bc3d1838a5cb3bee9ddddf66b742e054747 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 10 Sep 2021 12:05:25 +0100 Subject: [PATCH 7/7] Reintroduce get_target_accel_type Change-Id: I0aaf83fe0204c0db435692e9b92dee6e9d6997fe --- .../tvm/relay/backend/contrib/ethosu/vela_api.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index be011bd73359..5009c3157c77 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -356,3 +356,17 @@ def _calculate_hw_bias_scales( hw_bias_scales = [_quantize_scale(bs) for bs in bias_scales] return hw_bias_scales + + +def get_target_accel_type(): + """This is a helper function to convert cli accelerator type str argument + to NpuAccelerator""" + npu_accel_str_map = { + "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, + "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, + "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, + "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, + } + accel_type_str = util.get_accelerator_config() + assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" + return npu_accel_str_map[accel_type_str]