diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 307d44810c79..a4f3636edbbe 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -250,11 +250,15 @@ def wrapper(outs, *args, **kwargs): def get_workload(outs, task_name=None): """Retrieve the workload from outputs""" + visited = set() def traverse(tensors): """traverse all ops to find attached workload""" for t in tensors: op = t.op + if op in visited: + continue + visited.add(op) wkl = traverse(op.input_tensors) if wkl is not None: return wkl diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index c4948d54dc26..be77a81e4eb5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -21,3 +21,4 @@ from . import codegen from . import vela_api from . import tir_to_cs_translator +from . import softmax_rewriter diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 5aaa1417ae4d..3e69b409a3a9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -32,6 +32,7 @@ from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.backend.contrib.ethosu.softmax_rewriter import SoftmaxRewriter from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore @@ -1479,6 +1480,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function: LeakyReLURewriter(), MeanRewriter(), SumRewriter(), + SoftmaxRewriter(), ConcatRewriter(), SigmoidRewriter(), RequantizeRewriter(), diff --git a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py new file mode 100644 index 000000000000..16067fed951d --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py @@ -0,0 +1,516 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""SoftmaxRewriter for legalization Softmax operation.""" +import math + +import numpy as np +from ethosu.vela import fp_math, scaling + +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu import op as ethosu_ops +from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard +from tvm.relay.op.contrib import ethosu as ethosu_patterns + + +class SoftmaxRewriter(DFPatternCallback): + """This rewriting converts Softmax operation into a sequence of operations as in Vela.""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.params_class = ethosu_patterns.SoftMaxParams + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.SoftMaxParams.composite_name}) + )(None) + + def generate_exp_table(self, input_scale): + """Generate a LUT table for exponential function. + + Parameters + ---------- + input_scale : float + The scale for input. + + Returns + ------- + lut : tvm.relay.expr.Constant + LUT table for exponential function. + """ + beta = 1.0 + integer_bits = 5 + total_signed_bits = 31 + # Calculate scaling + real_beta = min( + np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), + np.double((1 << 31) - 1.0), + ) + scale, shift = scaling.quantise_scale(real_beta) + shift = 31 - shift + diff_min = -1.0 * math.floor( + 1.0 + * ((1 << integer_bits) - 1) + * (1 << (total_signed_bits - integer_bits)) + / (1 << shift) + ) + # Generate the exp LUT + lut = [] + for x in range(256): + input_diff = x - 255 + if input_diff >= diff_min: + rescale = fp_math.saturating_rounding_mul32(input_diff * (1 << shift), scale) + lut.append(fp_math.exp_on_negative_values(rescale)) + else: + lut.append(0) + res = np.array(lut, dtype="int32") + return relay.const(res) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + + ifm = post.args[0] + ifm_dtype = ifm.checked_type.dtype + bhw = np.prod(params.ifm.shape[:-1]) + depth = params.ifm.shape[-1] + + # The calculation of Softmax is similar to that in Vela + # https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.7.0/ethosu/vela/softmax.py#230 + # PASS 0 - Depthwise Maxpool + # reshape for depthwise maxpool + ifm = relay.reshape(ifm, (1, bhw, depth, 1)) + lut = relay.const([], dtype="int32") + depthwise_maxpool = ethosu_ops.ethosu_pooling( + ifm=ifm, + lut=lut, + pooling_type="MAX", + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ofm_scale=0.0, + ofm_zero_point=int(params.ofm.q_params.zero_point), + pool_shape=(1, depth), + ofm_channels=1, + ofm_dtype=ifm_dtype, + ) + + # PASS 1 - Sub+LUT(exp) + # move all data along the height axis, except channels + ifm = relay.reshape(ifm, (1, bhw, 1, depth)) + exp_lut = self.generate_exp_table(float(params.ifm.q_params.scale_f32)) + ifm_exp = ethosu_ops.ethosu_binary_elementwise( + ifm=ifm, + ifm2=depthwise_maxpool, + lut=exp_lut, + operator_type="SUB", + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ifm2_scale=0.0, + ifm2_zero_point=int(params.ifm.q_params.zero_point), + ofm_scale=1.0, + ofm_zero_point=127, + ifm_channels=depth, + ifm2_channels=1, + reversed_operands=False, + ofm_dtype="int32", + activation="LUT", + ) + + # PASS 2 - SHR + shr_const = relay.const(np.full([1, 1, 1, 1], 12, dtype="int32")) + shr = ethosu_ops.ethosu_binary_elementwise( + ifm=ifm_exp, + ifm2=shr_const, + 
+            lut=lut,
+            operator_type="SHR",
+            ifm_scale=1.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=params.ifm.shape[-1],
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+            rounding_mode="NATURAL",
+        )
+
+        # PASS 3 - Reduce sum
+        sum_of_exp = ethosu_ops.ethosu_pooling(
+            ifm=shr,
+            lut=lut,
+            pooling_type="SUM",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            pool_shape=(1, 1),
+            ofm_channels=1,
+            upscale="NONE",
+            ofm_dtype="int32",
+        )
+
+        # PASS 4 - CLZ
+        headroom_plus_one = ethosu_ops.ethosu_unary_elementwise(
+            ifm=sum_of_exp,
+            lut=lut,
+            operator_type="CLZ",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ofm_channels=1,
+        )
+
+        # PASS 5 - Sub
+        headroom_offset_const = relay.const(np.full([1, bhw, 1, 1], 35, dtype="int32"))
+        right_shift = ethosu_ops.ethosu_binary_elementwise(
+            ifm=headroom_offset_const,
+            ifm2=headroom_plus_one,
+            lut=lut,
+            operator_type="SUB",
+            ifm_scale=1.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=1.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+        )
+
+        # PASS 6 - Sub
+        one_const = relay.const(np.full([1, 1, 1, 1], 1, dtype="int32"))
+        headroom = ethosu_ops.ethosu_binary_elementwise(
+            ifm=headroom_plus_one,
+            ifm2=one_const,
+            lut=lut,
+            operator_type="SUB",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+        )
+
+        # PASS 7 - SHL
+        shifted_sum = ethosu_ops.ethosu_binary_elementwise(
+            ifm=sum_of_exp,
+            ifm2=headroom,
+            lut=lut,
+            operator_type="SHL",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=depth,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+        )
+
+        # PASS 8 - Sub
+        shifted_one_const = relay.const(np.full([1, 1, 1, 1], 1 << 30, dtype="int32"))
+        shifted_sum_minus_one = ethosu_ops.ethosu_binary_elementwise(
+            ifm=shifted_sum,
+            ifm2=shifted_one_const,
+            lut=lut,
+            operator_type="SUB",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+        )
+
+        # PASS 9 - SHL
+        shifted_sum_minus_one = ethosu_ops.ethosu_binary_elementwise(
+            ifm=shifted_sum_minus_one,
+            ifm2=one_const,
+            lut=lut,
+            operator_type="SHL",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+        )
+
+        # PASS 10 - Add
+        f0_one_const = relay.const(np.full([1, 1, 1, 1], (1 << 31) - 1, dtype="int32"))
+        half_denominator = ethosu_ops.ethosu_binary_elementwise(
+            ifm=shifted_sum_minus_one,
+            ifm2=f0_one_const,
+            lut=lut,
+            operator_type="ADD",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=1.0,
+            ofm_zero_point=0,
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+            use_rescale=True,
+            rescale_scale=1,
+            rescale_shift=1,
+        )
+
+        # PASS 11 - Mul
+        neg_32_over_17_const = relay.const(np.full([1, 1, 1, 1], -1010580540, dtype="int32"))
+        rescaled = ethosu_ops.ethosu_binary_elementwise(
+            ifm=half_denominator,
+            ifm2=neg_32_over_17_const,
+            lut=lut,
+            operator_type="MUL",
+            ifm_scale=1.0,
+            ifm_zero_point=0,
+            ifm2_scale=1.0,
+            ifm2_zero_point=0,
+            ofm_scale=2.0,
+            ofm_zero_point=0,
+            ifm_channels=depth,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128 * 2,
+            clip_max=127 * 2,
+        )
+
+        # PASS 12 - Add
+        const_48_over_17_const = relay.const(np.full([1, 1, 1, 1], 1515870810, dtype="int32"))
+        rescale_w_offset = ethosu_ops.ethosu_binary_elementwise(
+            ifm=rescaled,
+            ifm2=const_48_over_17_const,
+            lut=lut,
+            operator_type="ADD",
+            ifm_scale=2.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=1.0,
+            ofm_zero_point=0,
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+        )
+
+        nr_x = rescale_w_offset
+        f2_one_const = relay.const(np.full([1, bhw, 1, 1], 1 << 29, dtype="int32"))
+        four_const = relay.const(np.full([1, 1, 1, 1], 4, dtype="int32"))
+        for _ in range(3):
+            # PASS 13, 18, 23 - Mul
+            half_denominator_times_x = ethosu_ops.ethosu_binary_elementwise(
+                ifm=nr_x,
+                ifm2=half_denominator,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=1.0,
+                ifm_zero_point=0,
+                ifm2_scale=1.0,
+                ifm2_zero_point=0,
+                ofm_scale=2.0,
+                ofm_zero_point=0,
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int32",
+                activation="CLIP",
+                clip_min=-128 * 2,
+                clip_max=127 * 2,
+            )
+
+            # PASS 14, 19, 24 - Sub
+            one_minus_half_denomin_times_x = ethosu_ops.ethosu_binary_elementwise(
+                ifm=f2_one_const,
+                ifm2=half_denominator_times_x,
+                lut=lut,
+                operator_type="SUB",
+                ifm_scale=0.0,
+                ifm_zero_point=0,
+                ifm2_scale=2.0,
+                ifm2_zero_point=0,
+                ofm_scale=1.0,
+                ofm_zero_point=0,
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int32",
+            )
+
+            # PASS 15, 20, 25 - Mul
+            to_rescale = ethosu_ops.ethosu_binary_elementwise(
+                ifm=nr_x,
+                ifm2=one_minus_half_denomin_times_x,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=1.0,
+                ifm_zero_point=0,
+                ifm2_scale=1.0,
+                ifm2_zero_point=0,
+                ofm_scale=2.0,
+                ofm_zero_point=0,
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int32",
+                activation="CLIP",
+                clip_min=-128 * 2,
+                clip_max=127 * 2,
+            )
+
+            # PASS 16, 21, 26 - Mul
+            to_add = ethosu_ops.ethosu_binary_elementwise(
+                ifm=to_rescale,
+                ifm2=four_const,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=2.0,
+                ifm_zero_point=0,
+                ifm2_scale=0.0,
+                ifm2_zero_point=0,
+                ofm_scale=0.0,
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int32",
+                activation="CLIP",
+                clip_min=-128,
+                clip_max=127,
+            )
+
+            # PASS 17, 22, 27 - Add
+            nr_x = ethosu_ops.ethosu_binary_elementwise(
+                ifm=nr_x,
+                ifm2=to_add,
+                lut=lut,
+                operator_type="ADD",
+                ifm_scale=1.0,
+                ifm_zero_point=0,
+                ifm2_scale=0.0,
+                ifm2_zero_point=0,
+                ofm_scale=1.0,
+                ofm_zero_point=0,
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int32",
+            )
+
+        # PASS 28 - Mul
+        two_const = relay.const(np.full([1, 1, 1, 1], 2, dtype="int32"))
+        scale_factor = ethosu_ops.ethosu_binary_elementwise(
+            ifm=nr_x,
+            ifm2=two_const,
+            lut=lut,
+            operator_type="MUL",
+            ifm_scale=1.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=1.0,
+            ofm_zero_point=0,
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128,
+            clip_max=127,
+        )
+
+        # PASS 29 - Mul
+        scaled_exp = ethosu_ops.ethosu_binary_elementwise(
+            ifm=ifm_exp,
+            ifm2=scale_factor,
+            lut=lut,
+            operator_type="MUL",
+            ifm_scale=1.0,
+            ifm_zero_point=0,
+            ifm2_scale=1.0,
+            ifm2_zero_point=0,
+            ofm_scale=2.0,
+            ofm_zero_point=0,
+            ifm_channels=depth,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=-128 * 2,
+            clip_max=127 * 2,
+        )
+
+        # PASS 30 - SHR
+        shr30_op = ethosu_ops.ethosu_binary_elementwise(
+            ifm=scaled_exp,
+            ifm2=right_shift,
+            lut=lut,
+            operator_type="SHR",
+            ifm_scale=2.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=depth,
+            ifm2_channels=1,
+            reversed_operands=False,
+            rounding_mode="NATURAL",
+            ofm_dtype=ifm_dtype,
+        )
+
+        reshape = relay.reshape(shr30_op, params.ofm.shape)
+        return reshape
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py
index 50bbd36d9800..dde3133b56fa 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py
@@ -94,7 +94,8 @@ def unary_elementwise_compute(
     assert ofm_layout in {"NHWC", "NHCWB16"}
 
     # Changing the ifm and ofm scale to conform with that expected by Vela API
-    ofm_scale = ifm_scale / ofm_scale
+    if ofm_scale != 0:
+        ofm_scale = ifm_scale / ofm_scale
     ifm_scale = 1.0
 
     # Compute operation for the IFM DMA pipeline
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index c721efb4710a..0f6105277f46 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -250,7 +250,9 @@ def _visit(stmt):
                     # Note by convention the arg after a constant read is the length of the read
                     length = int(stmt.args[i + 1])
                     # If it's anything other than a full read, create a new buffer
-                    if (offset != 0 or flattened_const_shape != length) and not is_u65_conv2d:
+                    if (
+                        offset != 0 or (flattened_const_shape != length and length > 0)
+                    ) and not is_u65_conv2d:
                         out_channels = const.shape[0]
                         offset_channels = int((offset * out_channels) / flattened_const_shape)
                         length_channels = int((length * out_channels) / flattened_const_shape)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 50268f5f874f..e2ebfd0d1cd3 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -582,12 +582,16 @@ def _convert_clip_bounds(npu_op: vapi.NpuBlockOperation):
     """
     clip_min_quant = npu_op.activation.min
     clip_max_quant = npu_op.activation.max
-    clip_min_actual = (
-        clip_min_quant - npu_op.ofm.quantization.zero_point
-    ) * npu_op.ofm.quantization.scale_f32
-    clip_max_actual = (
-        clip_max_quant - npu_op.ofm.quantization.zero_point
-    ) * npu_op.ofm.quantization.scale_f32
+    if npu_op.ofm.quantization.scale_f32:
+        clip_min_actual = (
+            clip_min_quant - npu_op.ofm.quantization.zero_point
+        ) * npu_op.ofm.quantization.scale_f32
+        clip_max_actual = (
+            clip_max_quant - npu_op.ofm.quantization.zero_point
+        ) * npu_op.ofm.quantization.scale_f32
+    else:
+        clip_min_actual = clip_min_quant
+        clip_max_actual = clip_max_quant
     npu_op.activation.min = clip_min_actual
     npu_op.activation.max = clip_max_actual
 
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 8ec06d3a923e..744c15987bbe 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -2006,6 +2006,63 @@ def pad_pattern():
     return pattern
 
 
+class SoftMaxParams:
+    """
+    This class will parse a call to an ethos-u.softmax composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.softmax"
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
+        from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
+
+        quantize = func_body
+        softmax_op = quantize.args[0]
+        dequantize = softmax_op.args[0]
+
+        layout = "NHWC"
+
+        self.ifm = TensorParams(
+            dequantize.args[DequantizeArgs.IFM.value],
+            layout,
+            dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            layout,
+            quantize.args[QuantizeArgs.OFM_SCALE.value],
+            quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
+
+        self.operator_type = "SOFTMAX"
+
+    def is_valid(self):
+        """Checks whether Softmax has compatible attributes with the hardware."""
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        if self.ifm.dtype != self.ofm.dtype:
+            return False
+        if not check_dimensions(self.ifm):
+            return False
+        if self.ifm.shape != self.ofm.shape:
+            return False
+        return True
+
+
+def softmax_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for Softmax.
+    """
+    pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    pattern = is_op("nn.softmax")(pattern)
+    pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
+    return pattern
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -2110,6 +2167,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             sum_pattern(),
             lambda pat: SumParams(pat).is_valid(),
         ),
+        (
+            SoftMaxParams.composite_name,
+            softmax_pattern(),
+            lambda pat: SoftMaxParams(pat).is_valid(),
+        ),
         (
             LeakyReLUParams.composite_name,
             leaky_relu_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 14441d8e9313..86f64d648320 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -315,6 +315,24 @@ def pooling(x):
     infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
 
 
+@pytest.mark.parametrize(
+    "accel_type",
+    ["ethos-u55-256", "ethos-u65-256"],
+)
+@pytest.mark.parametrize("ifm_shape", [[1, 148, 29], [4, 148, 29], [1, 12], [8, 12]])
+def test_ethosu_softmax(
+    accel_type,
+    ifm_shape,
+):
+    np.random.seed(0)
+
+    @tf.function
+    def softmax(x):
+        return tf.nn.softmax(x)
+
+    infra.compare_tvm_with_tflite(softmax, [ifm_shape], accel_type)
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
 @pytest.mark.parametrize(
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 6330930fa5f8..d1d0befcee70 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -3076,5 +3076,125 @@ def representative_dataset():
     assert tuple(func_body.args[1].checked_type.shape) == (256,)
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 12), (1, 12, 32)])
+def test_tflite_softmax(ifm_shape):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        @tf.function
+        def softmax(x):
+            return tf.nn.softmax(x)
+
+        concrete_func = softmax.get_concrete_function(tf.TensorSpec(ifm_shape, dtype=tf.float32))
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        out_op = ext_func.body
+        ops = []
+        # List of expected operations and their operator type, where applicable
+        expected_ops = [
+            ("reshape", None),
+            ("reshape", None),
+            ("contrib.ethosu.pooling", "MAX"),
+            ("contrib.ethosu.binary_elementwise", "SUB"),
+            ("contrib.ethosu.binary_elementwise", "SHR"),
+            ("contrib.ethosu.pooling", "SUM"),
+            ("contrib.ethosu.unary_elementwise", "CLZ"),
+            ("contrib.ethosu.binary_elementwise", "SUB"),
+            ("contrib.ethosu.binary_elementwise", "SHL"),
+            ("contrib.ethosu.binary_elementwise", "SUB"),
+            ("contrib.ethosu.binary_elementwise", "SHL"),
+            ("contrib.ethosu.binary_elementwise", "ADD"),
+            ("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "ADD"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "SUB"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "ADD"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "SUB"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "ADD"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "SUB"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "ADD"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "MUL"), + ("contrib.ethosu.binary_elementwise", "SUB"), + ("contrib.ethosu.binary_elementwise", "SHR"), + ("reshape", None), + ] + + def get_op_type(op): + if hasattr(op.attrs, "pooling_type"): + return op.attrs.pooling_type + elif hasattr(op.attrs, "operator_type"): + return op.attrs.operator_type + return None + + def _visit(stmt): + if isinstance(stmt, relay.expr.Call): + ops.append(stmt) + + relay.analysis.post_order_visit(out_op, _visit) + + # check IFM + ifm = ops[0].args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + + # check OFM + ofm = out_op.checked_type + assert list(ofm.shape) == list(ifm_shape) + assert ofm.dtype == dtype + + # check operations + + ops = [(op.op.name, get_op_type(op)) for op in ops] + assert expected_ops == ops + + softmax_pattern_table = [ + ( + ethosu.SoftMaxParams.composite_name, + ethosu.softmax_pattern(), + lambda pat: ethosu.SoftMaxParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, softmax_pattern_table) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.SoftmaxRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + mod = relay.transform.InferType()(mod) + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": tvm.testing.main()