From 76787453d6cb54b21700b7eaf152624cdfae8033 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 11 Oct 2021 16:03:48 +0100 Subject: [PATCH 1/3] [microNPU] Add the infrastructure for lookup table and TANH Some activation functions like TANH and SIGMOID are implemented by calculating the values based on the QNN parameters and recording the values into a lookup table (LUT). This patch adds the LUT functionality alongside with the TANH activation function and the tests. Change-Id: Ibe49759fd02724af869826663ff0babd352e5894 --- .../relay/backend/contrib/ethosu/codegen.py | 104 ++++++++++ .../relay/backend/contrib/ethosu/legalize.py | 76 ++++++++ .../backend/contrib/ethosu/op/op_attrs.py | 39 ++++ .../backend/contrib/ethosu/te/convolution.py | 9 +- .../backend/contrib/ethosu/te/depthwise.py | 9 +- .../backend/contrib/ethosu/te/identity.py | 13 +- .../backend/contrib/ethosu/te/pooling.py | 11 +- .../backend/contrib/ethosu/tir/convolution.py | 2 +- .../backend/contrib/ethosu/tir/identity.py | 7 +- .../backend/contrib/ethosu/tir/pooling.py | 11 +- .../backend/contrib/ethosu/tir/scheduler.py | 46 ++++- .../contrib/ethosu/tir_to_cs_translator.py | 49 ++++- python/tvm/relay/op/contrib/ethosu.py | 30 +++ tests/python/contrib/test_ethosu/infra.py | 3 +- .../contrib/test_ethosu/test_codegen.py | 67 +++++++ .../contrib/test_ethosu/test_legalize.py | 54 +++++- .../contrib/test_ethosu/test_lookup_table.py | 181 ++++++++++++++++++ .../contrib/test_ethosu/test_lut_optimizer.py | 95 +++++++++ .../test_ethosu/test_replace_conv2d.py | 38 ++-- .../test_replace_depthwise_conv2d.py | 2 +- .../contrib/test_ethosu/test_scheduler.py | 28 ++- 21 files changed, 832 insertions(+), 42 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py create mode 100644 tests/python/contrib/test_ethosu/test_lookup_table.py create mode 100644 tests/python/contrib/test_ethosu/test_lut_optimizer.py diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 5fe51b4cbda0..8f193d48ad6b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -22,6 +22,109 @@ from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.expr_functor import ExprMutator +from tvm.ir.transform import Pass + +# pylint: disable=unused-import +from tvm.relay.backend.contrib.ethosu.op import op_attrs +from tvm.relay.backend.contrib.ethosu import op + + +class OptimizeLUTs(ExprMutator): + """A pass to merge an identity operator with a LUT based activation function with + a preceding operator provided that operator can do a table lookup for the activation + in the hardware""" + + def __init__(self): + super().__init__() + self.lut_ops = { + "contrib.ethosu.conv2d": op.ethosu_conv2d, + "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d, + "contrib.ethosu.pooling": op.ethosu_pooling, + } + + def create_op_with_lut(self, call): + """Extract the parameters and attributes from the NPU operator and create + a new operator with LUT. + ---------- + call : tvm.relay.expr.Call + The current call node being visited. + Returns + ------- + tvm.relay.expr.Call + The new operator with LUT. + """ + identity = call + ethosu_op = call.args[0] + lut = identity.args[1] + activation = identity.attrs.activation + + new_attrs = dict(ethosu_op.attrs) + new_attrs["activation"] = activation + + # Assume that LUT is always the last argument + new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)] + new_args.append(lut) + assert ethosu_op.op.name in self.lut_ops.keys() + + return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs) + + def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: + """Recursively visit call nodes in the input graph and if an ethosu.identity + operator with LUT is found and the preceding operator has a LUT attribute, create + a new NPU operator. + Parameters + ---------- + call : tvm.relay.expr.Call + The current call node being visited. + Returns + ------- + tvm.relay.expr.Call + The input call node in the case the current call node does + not refer to an Op. Else, a new call node with a new operator. + """ + new_call = call + lut_activations = ["TANH", "LUT"] + + if ( + call.op.name == "contrib.ethosu.identity" + and call.attrs.activation in lut_activations + and isinstance(call.args[0], tvm.relay.expr.Call) + ): + producer_op = call.args[0] + # Check if the producer can do a LUT operation + if producer_op.op.name in self.lut_ops.keys(): + # Check the producer doesn't already have a LUT + has_lut = producer_op.attrs.activation in lut_activations + if not has_lut: + new_call = self.create_op_with_lut(call) + + new_call = super().visit_call(new_call) + + return new_call + + +@relay.transform.function_pass(opt_level=1, name="LutOptimizer") +class LUTsOptimizer(Pass): + """Register LutOptimizer as a relay pass.""" + + def transform_function( + self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ + ) -> tvm.IRModule: + """Visit relay nodes in the given module. + Parameters + ---------- + func : tvm.relay.function.Function + The function to apply the layout optimization pass to. + mod : tvm.IRModule + The module to apply the layout optimization pass to. + Returns + ------- + mod : tvm.IRModule + New module with augmented layouts. + """ + assert len(mod.functions.items()) == 1, "Module can only contain one function." + return OptimizeLUTs().visit(func) @tvm._ffi.register_func("relay.ext.ethos-u") @@ -74,6 +177,7 @@ def _compile(ext_func): mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) + mod = LUTsOptimizer()(mod) mod = relay.transform.InferType()(mod) # We are currently using copy_constants scheduler In the long run, # this should be a single intelligent and a composite scheduler diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 8f2dddbf88a6..3d4f8b71cfbb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" from typing import List, Type +import math import numpy as np # type: ignore @@ -123,6 +124,80 @@ def __call__(self, *args, **kwargs): pass +def round_away_zero(f): + r = -0.5 if (f < 0) else 0.5 + return np.trunc(f + r) + + +def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp): + """Method to calculate the values of the tanh lookup table""" + lut_values = list() + # Only int8 is currently supported + dtype = np.int8 + qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max + for x in range(qmin, qmax + 1): + x_real = ifm_scale * (x - ifm_zp) + out_real = math.tanh(x_real) + lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale)) + lut_result = min(qmax, max(qmin, lut_result)) + lut_values.append(lut_result) + + return lut_values + + +class TanhRewriter(DFPatternCallback): + """This pass adds tanh as a LUT to the identity operator""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name}) + )(wildcard()) + + def callback(self, pre, post, node_map): + id_input = post.args[0] + + quantize_args = post.op.body.args + output_scale = float(quantize_args[1].data.asnumpy()) + output_zp = int(quantize_args[2].data.asnumpy()) + + dequantize_args = quantize_args[0].args[0].args + input_scale = float(dequantize_args[1].data.asnumpy()) + input_zp = int(dequantize_args[2].data.asnumpy()) + + lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp) + lut = relay.const(lut_values, dtype="uint8") + + # We baked the requantization into the LUT, so we don't requantize the identity operator + identity = ethosu_ops.ethosu_identity( + ifm=id_input, + lut=lut, + ifm_scale=input_scale, + ifm_zero_point=input_zp, + ofm_scale=input_scale, + ofm_zero_point=input_zp, + activation="TANH", + ) + + return identity + + +@ir.transform.module_pass(opt_level=1) +class LegalizeTanh: + """This is the pass that wraps TanhRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(TanhRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + class Conv2DRewriter(DFPatternCallback): """Convert conv2d related composite functions into ethosu_conv2d operators""" @@ -915,6 +990,7 @@ def transform_module( mod = LegalizeMax()(mod) mod = LegalizeShl()(mod) mod = LegalizeAbs()(mod) + mod = LegalizeTanh()(mod) mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py new file mode 100644 index 000000000000..e38a3dfd97de --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The attributes node used for EthosU Relay operators.""" +from tvm.ir import Attrs +import tvm._ffi + + +@tvm._ffi.register_object("relay.attrs.EthosuConv2DAttrs") +class EthosuConv2DAttrs(Attrs): + """Attributes for contrib.ethosu.conv2d.""" + + +@tvm._ffi.register_object("relay.attrs.EthosuIdentityAttrs") +class EthosuIdentityAttrs(Attrs): + """Attributes for contrib.ethosu.identity.""" + + +@tvm._ffi.register_object("relay.attrs.EthosuDepthwiseConv2DAttrs") +class EthosuDepthwiseConv2DAttrs(Attrs): + """Attributes for contrib.ethosu.depthwise_conv2d.""" + + +@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs") +class EthosuPooling2DAttrs(Attrs): + """Attributes for contrib.ethosu.pooling.""" diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 26785649457c..242c6feaa195 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -140,6 +140,13 @@ def conv2d_compute( "dilation_w": dilation_w, } + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + conv2d_attrs["lut"] = lut + conv = te.compute( (1, ofm_height, ofm_width, ofm_channels), lambda nn, hh, ww, cc: te.sum( @@ -148,7 +155,7 @@ def conv2d_compute( ).astype(ifm.dtype) * weight[cc, rh, rw, rc].astype(ifm.dtype) # This is a trick to load 10 elements of the scale_bias at once, not accurate maths - + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), + + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype), axis=[rh, rw, rc], ), name="ethosu_conv2d", diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index 664a3f489fb5..05b2993f5857 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -136,6 +136,13 @@ def depthwise_conv2d_compute( "dilation_w": dilation_w, } + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + depthwise_conv2d_attrs["lut"] = lut + depthwise = te.compute( (1, ofm_height, ofm_width, channels), lambda nn, hh, ww, cc: te.sum( @@ -144,7 +151,7 @@ def depthwise_conv2d_compute( ).astype(ifm.dtype) * weight[cc, rh, rw, 0].astype(ifm.dtype) # This is a trick to load 10 elements of the scale_bias at once, not accurate maths - + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), + + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype), axis=[rh, rw], ), name="ethosu_depthwise_conv2d", diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py index f26179422b4b..574fc661599f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py @@ -58,14 +58,21 @@ def identity_compute( The Output Feature Map tensor. """ - dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale) + id_attrs = {"op": "ethosu_identity", "activation": activation} + + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + id_attrs["lut"] = lut identity = te.compute( ifm.shape, - lambda *i: dmaed_ifm(*i).astype(ifm.dtype), + lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype), name="ethosu_identity", - attrs={"op": "ethosu_identity", "activation": activation}, + attrs=id_attrs, ) dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index bf35479d7556..2ab0844b1622 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -123,10 +123,19 @@ def pooling_compute( "upscale": upscale, } + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + pooling_attrs["lut"] = lut + pooling = te.compute( (1, ofm_height, ofm_width, ofm_channels), lambda nn, hh, ww, cc: te.max( - dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype), + (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype( + ifm.dtype + ), axis=[rh, rw], ), name="ethosu_pooling", diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 5e8ea002783f..254f92a30c32 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -53,7 +53,7 @@ def get_conv2d_params(stmt, producers, consumers): rh = inner rw = rh.body rc = rw.body - # loads = [output, input, weights, scale_bias, scale_bias] + # loads = [output, input, weights, scale_bias, scale_bias, LUT, LUT] loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 7a81a702f019..23fc31efbfac 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -19,7 +19,7 @@ from typing import Dict, Tuple import tvm from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap -from .utils import get_op_attrs, get_base_address, get_strides +from .utils import get_op_attrs, get_base_address, get_strides, get_loads def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]: @@ -123,7 +123,10 @@ def get_identity_params( while hasattr(stmt, "body"): stmt = stmt.body - input_pointer = stmt.value.buffer_var + # loads = [input, LUT, LUT] + loads = get_loads(stmt) + + input_pointer = loads[0].buffer_var output_pointer = stmt.buffer_var read = producers[input_pointer] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 33dcb36fbbb6..b19ec034e7d4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -18,7 +18,7 @@ """Extract information from the pooling operators in TIR.""" from typing import Dict, Tuple import tvm -from .utils import get_outer_loops, get_op_attrs +from .utils import get_outer_loops, get_op_attrs, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialActivation, SerialPooling @@ -55,9 +55,12 @@ def get_pooling_params( _, _, _, _, _, inner = get_outer_loops(body, "NHWC") rh = inner rw = rh.body - compute = rw.body.value.b - input_pointer = compute.buffer_var - output_pointer = rw.body.buffer_var + # loads = [output, input, LUT, LUT] + loads = get_loads(rw.body) + # stores = [output] + stores = get_stores(rw.body) + input_pointer = loads[1].buffer_var + output_pointer = stores[0].buffer_var # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 7f892d0c602a..e4dcfcd670aa 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -42,6 +42,8 @@ def schedule(cached_func, const_dict, cascader=None): if cascader: cascader(cached_func, const_dict, s) inline_no_ops(cached_func, s) + copy_luts()(cached_func, const_dict, s) + inline_no_ops(cached_func, s) schedule_pragmas(s) schedule_cache_reads(s) return s @@ -129,20 +131,54 @@ def copy_constants(): def _planner(cached_func, const_dict, sch): planned = set() # type: ignore - def _visit(tensor, reader): + def _visit(tensor, reader, lut): if tensor is not planned: planned.add(tensor) - if isinstance(tensor.op, tvm.te.PlaceholderOp): + if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut: index = list(cached_func.inputs).index(tensor) if index in const_dict: sch.cache_read(tensor, "global", [reader]) elif isinstance(tensor.op, tvm.te.ComputeOp): + if "lut" in tensor.op.attrs.keys(): + lut = tensor.op.attrs["lut"] for input_tensor in tensor.op.input_tensors: - _visit(input_tensor, tensor) + _visit(input_tensor, tensor, lut) for output_tensor in cached_func.outputs: - _visit(output_tensor, None) + _visit(output_tensor, None, None) + + return _planner + + +def copy_luts(): + """A scheduler that copies LUTs to SHRAM. + + Returns + ------- + planner : callable + The planning function. + """ + + def _planner(te_graph, const_dict, sch): + planned = set() # type: ignore + + def _visit(tensor, reader, lut): + if tensor is not planned: + planned.add(tensor) + if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor == lut: + index = list(te_graph.inputs).index(tensor) + if index in const_dict: + sch.cache_read(tensor, "local", [reader]) + + elif isinstance(tensor.op, tvm.te.ComputeOp): + if "lut" in tensor.op.attrs.keys(): + lut = tensor.op.attrs["lut"] + for input_tensor in tensor.op.input_tensors: + _visit(input_tensor, tensor, lut) + + for output_tensor in te_graph.outputs: + _visit(output_tensor, None, None) return _planner @@ -165,7 +201,7 @@ def _add_pragmas(stage, ax): if "op" in [attr for attr, val in stage.op.attrs.items()]: stage.pragma(ax, "op", stage.op.attrs["op"]) for attr, val in stage.op.attrs.items(): - if attr != "op": + if attr not in ("op", "lut"): stage.pragma(ax, str(attr), val) for stage in sch.stages: diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 4e84febe5e48..e1af7f1534e2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -39,6 +39,7 @@ class BufferType(Enum): scratch = auto() input = auto() output = auto() + shram = auto() _REGION_MAP = { @@ -46,6 +47,7 @@ class BufferType(Enum): BufferType.scratch: 1, BufferType.input: 3, BufferType.output: 4, + BufferType.shram: int((1 << 8) | (3 << 0)), } @@ -59,6 +61,25 @@ class BufferInfo(NamedTuple): btype: BufferType +class AcceleratorArchConfig: + def __init__(self, total_shram_banks): + self.shram_bank_size = 1024 + self.total_shram_banks = total_shram_banks + self.shram_size_bytes = self.shram_bank_size * self.total_shram_banks + self.lut_size_bytes = 2048 + self.lut_start_address = self.shram_size_bytes - self.lut_size_bytes + + +def get_accelerator_arch_config(accel_type): + accel_config_str_map = { + "ethos-u55-256": AcceleratorArchConfig(48), + "ethos-u55-128": AcceleratorArchConfig(24), + "ethos-u55-64": AcceleratorArchConfig(16), + "ethos-u55-32": AcceleratorArchConfig(16), + } + return accel_config_str_map[accel_type] + + def translate(tir_module, params): """This will take an tir module for the NPU and compile to command stream @@ -168,11 +189,20 @@ def extract_buffer_info( def populate_allocate_buffer_info(stmt): if isinstance(stmt, tvm.tir.stmt.Allocate): allocate = stmt + if "placeholder" in allocate.buffer_var.name: + storage_scope = allocate.buffer_var.name.split(".")[-1] + else: + storage_scope = "global" + + if storage_scope == "local": + buffer_type = BufferType.shram + else: + buffer_type = BufferType.scratch buffer_info[allocate.buffer_var] = BufferInfo( None, allocate.extents, allocate.dtype, - BufferType.scratch, + buffer_type, ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) @@ -279,6 +309,11 @@ def classify_io(buffer): assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) + elif info.btype == BufferType.shram: + accl_config = util.get_accelerator_config() + arch_config = get_accelerator_arch_config(accl_config) + address = arch_config.lut_start_address + buffer_addresses[_buffer] = (address, info.btype) else: assert info.btype == BufferType.scratch address = scratch_size @@ -597,14 +632,18 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu return None op_map = { "CLIP": vapi.NpuActivationOp.NONE_OR_RELU, - "TANH": vapi.NpuActivationOp.TANH, - "SIGMOID": vapi.NpuActivationOp.SIGMOID, + "TANH": vapi.NpuActivationOp.TABLE_LOOKUP, + "SIGMOID": vapi.NpuActivationOp.TABLE_LOOKUP, + "LUT": vapi.NpuActivationOp.TABLE_LOOKUP, } op = str(serial_activation.op.value) assert op in op_map.keys() act_op = vapi.NpuActivation(op_map[op]) - act_op.min = int(serial_activation.clip_min) - act_op.max = int(serial_activation.clip_max) + if serial_activation.op == "CLIP": + act_op.min = int(serial_activation.clip_min.value) + act_op.max = int(serial_activation.clip_max.value) + if op_map[op] == vapi.NpuActivationOp.TABLE_LOOKUP: + act_op.lookup_table_index = 0 return act_op diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 73de3329c45f..73d94e8ca3bd 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -915,6 +915,35 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return pattern +class TanhParams: + """ + This class will parse a call to a ethos-u.tanh composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.tanh" + + def __init__(self, func_body: Call): + self.ofm = TensorParams(func_body) + self.ifm = TensorParams(func_body.args[0].args[0].args[0]) + + def is_valid(self): + """ + This function checks whether reshape has compatible attributes with the NPU + """ + if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): + return False + return True + + +def tanh_pattern(): + """Create pattern for tanh""" + dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + tanh = is_op("tanh")(dequant) + quant = is_op("qnn.quantize")(tanh, is_constant(), is_constant()) + return quant + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -983,6 +1012,7 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal abs_pattern(), lambda pat: AbsParams(pat).is_valid(), ), + (TanhParams.composite_name, tanh_pattern(), lambda pat: TanhParams(pat).is_valid()), ] diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 5f339267e0b8..7842de5d9ac9 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -411,6 +411,7 @@ def make_ethosu_conv2d( padding, strides, dilation, + lut=relay.const([], dtype="int8"), activation="NONE", ifm_layout="NHWC", ofm_layout="NHWC", @@ -430,7 +431,7 @@ def make_ethosu_conv2d( ifm, weight, scale_bias, - lut=relay.const([], dtype="int8"), + lut=lut, ifm_scale=0.5, ifm_zero_point=10, weight_zero_point=12, diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b6cf873cb6f3..e20ab41cb576 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1003,5 +1003,72 @@ def clz_comp(n): infra.verify_source(compiled_model, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_tflite_tanh(accel_type): + dtype = "int8" + ifm_shape = [1, 115, 32, 7] + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tanh_function(self, x): + op = tf.nn.tanh(x) + return op + + model = Model() + concrete_func = model.tanh_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index dbe11cd2d7ad..64bdae5c1b8b 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -541,7 +541,6 @@ def verify(ext_func): lambda pat: ethosu.AvgPool2DParams(pat).is_valid(), ), ] - tflite_graph = create_tflite_graph() tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) @@ -1007,5 +1006,58 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +def test_tflite_tanh_legalize(): + dtype = "int8" + ifm_shape = (1, 241, 132, 7) + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tanh_func(self, x): + op = tf.math.tanh(x) + return op + + model = Model() + concrete_func = model.tanh_func.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod = ethosu.partition_for_ethosu(mod, params) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.TanhRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + mod = relay.transform.InferType()(mod) + + func_body = mod["tvmgen_default_ethos_u_main_0"].body + assert func_body.op.name == "contrib.ethosu.identity" + assert func_body.attrs.activation == "TANH" + assert tuple(func_body.args[0].checked_type.shape) == (ifm_shape) + assert tuple(func_body.args[1].checked_type.shape) == (256,) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py new file mode 100644 index 000000000000..67870bd6472e --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_lookup_table.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +import numpy as np +import tflite.Model + +import tvm +import tensorflow as tf +from tvm import relay +from tvm.relay.op.contrib.ethosu import partition_for_ethosu +from tvm.relay.build_module import bind_params_by_name # type: ignore + +from . import infra + + +ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_tflite_lut_activations(accel_type): + + dtype = "int8" + ifm_shape = (1, 55, 55, 3) + + def create_tflite_graph(): + tf.config.run_functions_eagerly(True) + + class Model(tf.Module): + @tf.function + def tf_func(self, x): + weight_shape = (3, 3, ifm_shape[3], 4) + weight = tf.constant( + np.random.uniform(low=0, high=0.3, size=weight_shape), dtype=tf.float32 + ) + # The input strides to the TensorFlow API needs to be of shape 1x4 + op = tf.nn.conv2d(x, weight, strides=(1, 2, 2, 1), padding="SAME", dilations=(1, 1)) + op = tf.nn.tanh(op) + op = tf.nn.tanh(op) + + weight_shape2 = (2, 3, 4, 1) + weight2 = tf.constant( + np.random.uniform(low=0, high=0.3, size=weight_shape2), dtype=tf.float32 + ) + op = tf.nn.depthwise_conv2d( + op, weight2, strides=(1, 1, 1, 1), padding="VALID", dilations=(2, 2) + ) + op = tf.nn.tanh(op) + op = tf.nn.max_pool(op, (1, 1), strides=(1, 1, 1, 1), padding="SAME") + op = tf.nn.tanh(op) + return op + + model = Model() + concrete_func = model.tf_func.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = 0.7 * np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + infra.print_payload(cmms) + + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_random_lut(accel_type): + + dtype = "int8" + ifm_shape = (1, 55, 55, 3) + + lut_data = np.random.randint(-128, high=127, size=[256]) + lut_data_map = {idx: lut_data[idx + 128] for idx in range(-128, 128)} + + in_data = np.random.randint(-128, high=127, size=ifm_shape, dtype=dtype) + out_data = np.array([lut_data_map[i] for i in in_data.ravel()]).reshape(ifm_shape).astype(dtype) + + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm0 = relay.var("ifm0", shape=ifm_shape, dtype=dtype) + lut1 = relay.var("lut1", shape=(256,), dtype="uint8") + + identity = infra.make_ethosu_identity(ifm0, lut=lut1, activation="LUT") + glb_ethosu = relay.GlobalVar("tvmgen_default_ethos_u_main_0") + + func = ( + relay.Function([ifm0, lut1], identity) + .with_attr("Inline", 1) + .with_attr("Compiler", "ethos-u") + .with_attr("global_symbol", "tvmgen_default_ethos_u_main_0") + .with_attr("Primitive", 1) + ) + + params = {"lut1": tvm.nd.array(lut_data.astype("uint8"))} + func = bind_params_by_name(func, params) + + mod = tvm.IRModule() + mod[glb_ethosu] = func + mod = relay.transform.InferType()(mod) + + call = relay.Call(glb_ethosu, [ifm]) + mod["main"] = relay.Function([ifm], call) + mod = relay.transform.InferType()(mod) + + compiled_models = infra.build_source( + mod, + {"ifm": in_data}, + out_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + infra.print_payload(cmms) + + infra.verify_source(compiled_models, accel_type) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py new file mode 100644 index 000000000000..8b406d15cfc7 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test the pass that removes unnecssary identity operation if the identity +uses LUT and the preceding operator is LUT capable and doesn't already have a LUT. +""" +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer +from . import infra + + +def test_merge_lut_into_conv(): + """If an operator that has a LUT attribute is followed by an identity operator + with LUT, we can merge the two operataors.""" + + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + lut1 = relay.const([i for i in range(256)], dtype="int8") + lut2 = relay.const([i for i in reversed(range(256))], dtype="int8") + + def before(): + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH") + conv2 = infra.make_ethosu_conv2d(id1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1)) + id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + conv1 = infra.make_ethosu_conv2d( + ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH" + ) + conv2 = infra.make_ethosu_conv2d( + conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="TANH" + ) + + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod + + mod = LUTsOptimizer()(before()) + + assert tvm.ir.structural_equal(mod, after()) + + +def test_multiple_luts(): + """Test that when an operation already has a LUT, we don't overwrite that LUT""" + + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + lut1 = relay.const([i for i in range(256)], dtype="int8") + lut2 = relay.const([i for i in reversed(range(256))], dtype="int8") + + def before(): + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH") + id2 = infra.make_ethosu_identity(id1, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + conv1 = infra.make_ethosu_conv2d( + ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH" + ) + id2 = infra.make_ethosu_identity(conv1, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod + + mod = LUTsOptimizer()(before()) + + assert tvm.ir.structural_equal(mod, after()) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 7992f421a5bd..1d3afec30cbc 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize( "trial", [ - [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC", "TFL"], + [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"], [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TRUNCATE"], [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "TFL"], @@ -124,12 +124,10 @@ def _get_func( padding, strides, dilation, - activation, - ifm_layout, - ofm_layout, - "int8", - "uint8", - rounding_mode, + activation=activation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + rounding_mode=rounding_mode, ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) @@ -409,9 +407,9 @@ def _get_func( padding, strides, dilation, - "NONE", - layout, - layout, + activation="NONE", + ifm_layout=layout, + ofm_layout=layout, ) conv2 = make_ethosu_conv2d( conv1, @@ -421,9 +419,9 @@ def _get_func( padding, strides, dilation, - "NONE", - layout, - layout, + activation="NONE", + ifm_layout=layout, + ofm_layout=layout, ) func = relay.Function(relay.analysis.free_vars(conv2), conv2) func = run_opt_pass(func, relay.transform.InferType()) @@ -577,7 +575,15 @@ def _get_func(ifm_shape, reshaped, ifm_layout): ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") ifm_reshaped = relay.reshape(ifm, reshaped) conv = make_ethosu_conv2d( - ifm_reshaped, reshaped[3], 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", ifm_layout + ifm_reshaped, + reshaped[3], + 16, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + activation="NONE", + ifm_layout=ifm_layout, ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) @@ -598,7 +604,9 @@ def test_conv2d_big_pad(): def _get_func(): ifm_shape = (1, 2, 2, 8) ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") - conv = make_ethosu_conv2d(ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), "NHWC") + conv = make_ethosu_conv2d( + ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), ifm_layout="NHWC" + ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) return func diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py index cf2ac147759c..afd632cf355e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py @@ -30,7 +30,7 @@ "trial", [ [(1, 8, 8, 3), 3, (3, 2), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], - [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC", "NATURAL"], + [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"], [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "TRUNCATE"], [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "NATURAL"], diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index b04059011e8e..cd84449c4a1b 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -27,9 +27,10 @@ total_cascader, copy_constants, schedule_cache_reads, + copy_luts, ) from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants -from .infra import AttachType, make_ethosu_conv2d +from .infra import AttachType, make_ethosu_conv2d, make_ethosu_identity class TestTEGraph: @@ -126,6 +127,31 @@ def test_copy_constants(): assert ".global" in sch.stages[17].op.name +# This test makes sure that constants and LUTs have a correct storage scope +def test_copy_luts(): + ifm_shape = (1, 33, 33, 11) + ifm = relay.var("IFM", shape=ifm_shape, dtype="int8") + lut = relay.const([i for i in range(256)], dtype="int8") + conv = make_ethosu_conv2d( + ifm, ifm_shape[3], 8, (3, 3), (0, 0), (1, 1), (1, 1), lut=lut, activation="TANH" + ) + identity = make_ethosu_identity(conv, lut=lut, activation="TANH") + func = relay.Function(relay.analysis.free_vars(identity), identity) + func = run_opt_pass(func, relay.transform.InferType()) + + func, const_dict = extract_constants(func) + te_graph = lower_to_te(func) + + sch = te.create_schedule([te_graph.outputs[0].op]) + copy_constants()(te_graph, const_dict, sch) + copy_luts()(te_graph, const_dict, sch) + assert len(sch.stages) == 17 + assert ".global" in sch.stages[5].op.name + assert ".global" in sch.stages[7].op.name + assert ".local" in sch.stages[9].op.name + assert ".local" in sch.stages[10].op.name + + def test_schedule_cache_reads(): a = te.placeholder((12, 12), dtype="uint8", name="a") b = te.placeholder((12, 12), dtype="uint8", name="b") From 46b460225e01bb58e3ed75346f0e91b217d307bd Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 26 Nov 2021 11:35:08 +0000 Subject: [PATCH 2/3] Responding to the reviews --- .../relay/backend/contrib/ethosu/codegen.py | 20 ++++++++++++------- .../relay/backend/contrib/ethosu/legalize.py | 8 ++------ .../contrib/test_ethosu/test_lookup_table.py | 2 -- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 8f193d48ad6b..1f331822e1ac 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -46,9 +46,12 @@ def __init__(self): def create_op_with_lut(self, call): """Extract the parameters and attributes from the NPU operator and create a new operator with LUT. + + Parameters ---------- call : tvm.relay.expr.Call The current call node being visited. + Returns ------- tvm.relay.expr.Call @@ -63,8 +66,7 @@ def create_op_with_lut(self, call): new_attrs["activation"] = activation # Assume that LUT is always the last argument - new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)] - new_args.append(lut) + new_args = ethosu_op.args[:-1] + [lut] assert ethosu_op.op.name in self.lut_ops.keys() return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs) @@ -73,10 +75,12 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: """Recursively visit call nodes in the input graph and if an ethosu.identity operator with LUT is found and the preceding operator has a LUT attribute, create a new NPU operator. + Parameters ---------- call : tvm.relay.expr.Call The current call node being visited. + Returns ------- tvm.relay.expr.Call @@ -104,24 +108,26 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@relay.transform.function_pass(opt_level=1, name="LutOptimizer") +@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") class LUTsOptimizer(Pass): - """Register LutOptimizer as a relay pass.""" + """Register LUTsOptimizer as a relay pass.""" def transform_function( self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ ) -> tvm.IRModule: """Visit relay nodes in the given module. + Parameters ---------- func : tvm.relay.function.Function - The function to apply the layout optimization pass to. + The function to apply the optimization pass for multiple LUTs to. mod : tvm.IRModule - The module to apply the layout optimization pass to. + The module to apply the optimization pass for multiple LUTs to. + Returns ------- mod : tvm.IRModule - New module with augmented layouts. + New module with optimized LUTs. """ assert len(mod.functions.items()) == 1, "Module can only contain one function." return OptimizeLUTs().visit(func) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 3d4f8b71cfbb..5613d613f984 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -32,6 +32,7 @@ from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu import util from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore @@ -124,11 +125,6 @@ def __call__(self, *args, **kwargs): pass -def round_away_zero(f): - r = -0.5 if (f < 0) else 0.5 - return np.trunc(f + r) - - def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp): """Method to calculate the values of the tanh lookup table""" lut_values = list() @@ -138,7 +134,7 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp): for x in range(qmin, qmax + 1): x_real = ifm_scale * (x - ifm_zp) out_real = math.tanh(x_real) - lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale)) + lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale)) lut_result = min(qmax, max(qmin, lut_result)) lut_values.append(lut_result) diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py index 67870bd6472e..d32b441fd2eb 100644 --- a/tests/python/contrib/test_ethosu/test_lookup_table.py +++ b/tests/python/contrib/test_ethosu/test_lookup_table.py @@ -40,8 +40,6 @@ def test_tflite_lut_activations(accel_type): ifm_shape = (1, 55, 55, 3) def create_tflite_graph(): - tf.config.run_functions_eagerly(True) - class Model(tf.Module): @tf.function def tf_func(self, x): From 2218102785fc3a5e3dc65ffac00f352594c2ba15 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 26 Nov 2021 13:47:35 +0000 Subject: [PATCH 3/3] Checking that tvm.relay.expr.Call.op exists --- python/tvm/relay/backend/contrib/ethosu/codegen.py | 12 ++++++------ .../tvm/relay/backend/contrib/ethosu/op/op_attrs.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 1f331822e1ac..e51f1702773b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -90,14 +90,14 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: new_call = call lut_activations = ["TANH", "LUT"] - if ( - call.op.name == "contrib.ethosu.identity" - and call.attrs.activation in lut_activations - and isinstance(call.args[0], tvm.relay.expr.Call) - ): + if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call): producer_op = call.args[0] # Check if the producer can do a LUT operation - if producer_op.op.name in self.lut_ops.keys(): + if ( + producer_op.op.name in self.lut_ops.keys() + and call.op.name == "contrib.ethosu.identity" + and call.attrs.activation in lut_activations + ): # Check the producer doesn't already have a LUT has_lut = producer_op.attrs.activation in lut_activations if not has_lut: diff --git a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py index e38a3dfd97de..a52736fe3964 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The attributes node used for EthosU Relay operators.""" +"""The attributes node used for Arm(R) Ethos(TM)-U NPU Relay operators.""" from tvm.ir import Attrs import tvm._ffi