diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 002cb4b6be9b..22f248ba012c 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -89,7 +89,7 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
             not refer to an Op. Else, a new call node with a new operator.
         """
         new_call = call
-        lut_activations = ["TANH", "LUT"]
+        lut_activations = ["TANH", "LUT", "SIGMOID"]

         if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
             producer_op = call.args[0]
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index f8beb7f7464e..b2264f32611e 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
 """A set of passes to legalize some of operations for the NPU"""
-from typing import List, Type
+from typing import List, Type, Callable
 import math

 import numpy as np  # type: ignore
@@ -125,15 +125,17 @@ def __call__(self, *args, **kwargs):
         pass


-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(
+    ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]
+) -> List[int]:
+    """Method to calculate the values of the lookup table based on the calculation function"""
     lut_values = list()
     # Only int8 is currently supported
     dtype = np.int8
     qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
     for x in range(qmin, qmax + 1):
         x_real = ifm_scale * (x - ifm_zp)
-        out_real = math.tanh(x_real)
+        out_real = func(x_real)
         lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
         lut_result = min(qmax, max(qmin, lut_result))
         lut_values.append(lut_result)
@@ -141,16 +143,18 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
     return lut_values


-class TanhRewriter(DFPatternCallback):
-    """This pass adds tanh as a LUT to the identity operator"""
+class LutActivationRewriter(DFPatternCallback):
+    """A class to create an identity operator with the LUT"""

-    def __init__(self):
+    def __init__(
+        self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]
+    ):
         super().__init__(require_type=True, rewrite_once=True)
-        self.pattern = (
-            wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
-        )(wildcard())
+        self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
+        self.activation_type = activation_type
+        self.calc_func = calc_func

-    def callback(self, pre, post, node_map):
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
         id_input = post.args[0]

         quantize_args = post.op.body.args
@@ -161,7 +165,9 @@ def callback(self, pre, post, node_map):
         input_scale = float(dequantize_args[1].data.asnumpy())
         input_zp = int(dequantize_args[2].data.asnumpy())

-        lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
+        lut_values = get_lut_from_func(
+            input_scale, input_zp, output_scale, output_zp, self.calc_func
+        )
         lut = relay.const(lut_values, dtype="uint8")

         # We baked the requantization into the LUT, so we don't requantize the identity operator
@@ -172,12 +178,21 @@ def callback(self, pre, post, node_map):
             ifm_zero_point=input_zp,
             ofm_scale=input_scale,
             ofm_zero_point=input_zp,
-            activation="TANH",
+            activation=self.activation_type,
         )

         return identity


+class TanhRewriter(LutActivationRewriter):
+    """This pass adds tanh as a LUT to the identity operator"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.TanhParams, activation_type="TANH", calc_func=math.tanh
+        )
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeTanh:
     """This is the pass that wraps TanhRewriter"""
@@ -194,6 +209,48 @@ def __call__(self, *args, **kwargs):
         pass


+def sigmoid_calc_func(x: float) -> float:
+    """Function to calculate the values for sigmoid"""
+    # These limits are inherited from TFLite
+    upper_limit = 8.0
+    lower_limit = -8.0
+
+    if x <= lower_limit:
+        y = 0.0
+    elif x >= upper_limit:
+        y = 1.0
+    else:
+        y = 1 / (1 + math.exp(-x))
+    return y
+
+
+class SigmoidRewriter(LutActivationRewriter):
+    """This pass adds sigmoid as a LUT for identity op"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.SigmoidParams,
+            activation_type="SIGMOID",
+            calc_func=sigmoid_calc_func,
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeSigmoid:
+    """This is the pass that wraps SigmoidRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(SigmoidRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -1196,6 +1253,7 @@ def transform_module(
         mod = LegalizeTanh()(mod)
         mod = LegalizeMean()(mod)
         mod = LegalizeConcat()(mod)
+        mod = LegalizeSigmoid()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
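Note: as a rough standalone sketch (not part of the patch), the snippet below shows what get_lut_from_func and sigmoid_calc_func above compute for one assumed set of int8 quantization parameters; round_away_zero here is a simplified stand-in for the helper in tvm.relay.backend.contrib.ethosu.util.

import math

import numpy as np


def round_away_zero(f):
    # Simplified stand-in for util.round_away_zero (round half away from zero)
    return math.floor(f + 0.5) if f >= 0 else math.ceil(f - 0.5)


def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
    # One LUT entry per representable int8 input value
    qmin, qmax = np.iinfo(np.int8).min, np.iinfo(np.int8).max
    lut_values = []
    for x in range(qmin, qmax + 1):
        out_real = func(ifm_scale * (x - ifm_zp))
        lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale))
        lut_values.append(min(qmax, max(qmin, lut_result)))
    return lut_values


def sigmoid_calc_func(x):
    # Same saturation limits as in the patch (inherited from TFLite)
    if x <= -8.0:
        return 0.0
    if x >= 8.0:
        return 1.0
    return 1 / (1 + math.exp(-x))


# Assumed example quantization: ifm scale 0.1 / zero point 0, and the TFLite-style
# sigmoid output quantization of scale 1/256 with zero point -128.
lut = get_lut_from_func(0.1, 0, 1.0 / 256, -128, sigmoid_calc_func)
assert len(lut) == 256
assert lut[0] == -128 and lut[-1] == 127  # saturates at the int8 extremes
assert all(a <= b for a, b in zip(lut, lut[1:]))  # monotonically non-decreasing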
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
index 242c6feaa195..6e50c6ff3b0b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -140,11 +140,13 @@ def conv2d_compute(
         "dilation_w": dilation_w,
     }

+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         conv2d_attrs["lut"] = lut

     conv = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
index c9a88e803c3d..f54f2f3654e2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -139,11 +139,13 @@ def depthwise_conv2d_compute(
         "dilation_w": dilation_w,
     }

+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         depthwise_conv2d_attrs["lut"] = lut

     depthwise = te.compute(
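Note: the has_lut / lut_expr change repeated in these TE compute functions relies on the "trick" the comments describe: referencing lut[0] and lut[255] makes the otherwise unused LUT tensor an input of the compute op, so it is carried through the TE graph. Below is a minimal sketch of that mechanism with made-up placeholder shapes, not the NPU codegen itself.

from tvm import te

ifm = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm")
lut = te.placeholder((256,), dtype="uint8", name="lut")

has_lut = True
# The numeric value of lut_expr is irrelevant here; touching lut[0] and lut[255]
# is enough for the compute to record the LUT tensor as one of its inputs.
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

identity = te.compute(
    ifm.shape,
    lambda *i: (ifm(*i) + lut_expr).astype(ifm.dtype),
    name="identity_with_lut",
)

# Both the feature map and the LUT now appear among the operator's inputs.
assert {t.name for t in identity.op.input_tensors} == {"ifm", "lut"}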
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
index 574fc661599f..271ca1542fc5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
@@ -61,11 +61,13 @@ def identity_compute(
     dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
     id_attrs = {"op": "ethosu_identity", "activation": activation}

+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         id_attrs["lut"] = lut

     identity = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index 2ab0844b1622..e98a72db7f02 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -123,11 +123,13 @@ def pooling_compute(
         "upscale": upscale,
     }

+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         pooling_attrs["lut"] = lut

     pooling = te.compute(
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index bf9e3f8cc977..a7d3da3200b5 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern


-class TanhParams:
+class LutActivationParams:
     """
-    This class will parse a call to a ethos-u.tanh composite function
-    and extract the parameter information.
+    A parent class for LUT based activation functions that extract the input and
+    output tensors and check whether they are valid.
     """

-    composite_name = "ethos-u.tanh"
-
     def __init__(self, func_body: Call):
         self.ofm = TensorParams(func_body)
         self.ifm = TensorParams(func_body.args[0].args[0].args[0])

     def is_valid(self):
         """
-        This function checks whether reshape has compatible attributes with the NPU
+        This function checks whether activation has compatible attributes with the NPU
         """
         if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
             return False
         return True


+class TanhParams(LutActivationParams):
+
+    composite_name = "ethos-u.tanh"
+
+
 def tanh_pattern():
     """Create pattern for tanh"""
     dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
@@ -947,6 +950,23 @@ def tanh_pattern():
     return quant


+class SigmoidParams(LutActivationParams):
+    """
+    This class will parse a call to a ethos-u.sigmoid composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.sigmoid"
+
+
+def sigmoid_pattern():
+    """Create pattern for sigmoid"""
+    dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    sigmoid = is_op("sigmoid")(dequant)
+    quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
+    return quant
+
+
 class MeanParams:
     """
     This class will parse a call to ethosu.mean composite function
@@ -1162,6 +1182,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             lambda pat: MeanParams(pat).is_valid(),
         ),
         (ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()),
+        (
+            SigmoidParams.composite_name,
+            sigmoid_pattern(),
+            lambda pat: SigmoidParams(pat).is_valid(),
+        ),
     ]
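Note: a small self-contained check (not from the patch) of the kind of graph sigmoid_pattern() is intended to match; the shape and quantization parameters below are assumed purely for illustration.

from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard


def sigmoid_pattern():
    # Same structure as the pattern registered above
    dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
    sigmoid = is_op("sigmoid")(dequant)
    quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
    return quant


# Assumed example graph: an int8 tensor dequantized, passed through sigmoid and
# re-quantized, which is roughly what the TFLite frontend produces for a quantized sigmoid.
x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
deq = relay.qnn.op.dequantize(x, relay.const(0.0039, "float32"), relay.const(-128, "int32"))
sig = relay.sigmoid(deq)
quant = relay.qnn.op.quantize(
    sig, relay.const(1.0 / 256, "float32"), relay.const(-128, "int32"), out_dtype="int8"
)

assert sigmoid_pattern().match(quant)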
+ """ + + composite_name = "ethos-u.sigmoid" + + +def sigmoid_pattern(): + """Create pattern for sigmoid""" + dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + sigmoid = is_op("sigmoid")(dequant) + quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant()) + return quant + + class MeanParams: """ This class will parse a call to ethosu.mean composite function @@ -1162,6 +1182,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal lambda pat: MeanParams(pat).is_valid(), ), (ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()), + ( + SigmoidParams.composite_name, + sigmoid_pattern(), + lambda pat: SigmoidParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 0e55487ae85f..21e86c866512 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -815,66 +815,14 @@ def clz_comp(n): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) def test_tflite_tanh(accel_type): - dtype = "int8" ifm_shape = [1, 115, 32, 7] - def create_tflite_graph(): - class Model(tf.Module): - @tf.function - def tanh_function(self, x): - op = tf.nn.tanh(x) - return op - - model = Model() - concrete_func = model.tanh_function.get_concrete_function( - tf.TensorSpec(ifm_shape, dtype=tf.float32) - ) - - # Convert the model - def representative_dataset(): - for _ in range(100): - data = np.random.rand(*tuple(ifm_shape)) - yield [data.astype(np.float32)] - - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.int8 - converter.inference_output_type = tf.int8 - tflite_model = converter.convert() - return tflite_model - - tflite_graph = create_tflite_graph() - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - - relay_module, params = relay.frontend.from_tflite( - tflite_model, - shape_dict={"input": ifm_shape}, - dtype_dict={"input": dtype}, - ) - mod = partition_for_ethosu(relay_module, params) - - # Generate reference data - input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + @tf.function + def tanh_func(x): + op = tf.nn.tanh(x) + return op - compiled_models = infra.build_source( - mod, - input_data, - output_data, - accel_type, - ) - - # Assumes only two runtime.Modules are created -- i.e. 
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 59bcf13849ea..946aa951679b 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1297,5 +1297,58 @@ def verify(ext_func):
     ]


+def test_tflite_sigmoid_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 237, 91, 7)
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def sigmoid_func(self, x):
+                op = tf.math.sigmoid(x)
+                return op
+
+        model = Model()
+        concrete_func = model.sigmoid_func.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_output_type = tf.int8
+        converter.inference_input_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.SigmoidRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "SIGMOID"
+    assert tuple(func_body.args[0].checked_type.shape) == (ifm_shape)
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py
index 9485b4f69520..90a51c53fdb3 100644
--- a/tests/python/contrib/test_ethosu/test_lookup_table.py
+++ b/tests/python/contrib/test_ethosu/test_lookup_table.py
@@ -59,7 +59,7 @@ def tf_func(self, x):
             op = tf.nn.depthwise_conv2d(
                 op, weight2, strides=(1, 1, 1, 1), padding="VALID", dilations=(2, 2)
             )
-            op = tf.nn.tanh(op)
+            op = tf.nn.sigmoid(op)
             op = tf.nn.max_pool(op, (1, 1), strides=(1, 1, 1, 1), padding="SAME")
             op = tf.nn.tanh(op)
             return op
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index 8b406d15cfc7..16835ce94ed7 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -39,7 +39,7 @@ def before():
         conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
         id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH")
         conv2 = infra.make_ethosu_conv2d(id1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1))
-        id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="TANH")
+        id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="SIGMOID")
         func = relay.Function(relay.analysis.free_vars(id2), id2)
         mod = tvm.IRModule.from_expr(func)

@@ -50,7 +50,7 @@ def after():
             ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH"
         )
         conv2 = infra.make_ethosu_conv2d(
-            conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="TANH"
+            conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="SIGMOID"
         )

         func = relay.Function(relay.analysis.free_vars(conv2), conv2)
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 2f2cd7a483db..7b09fb255663 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -32,7 +32,7 @@
         [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
         [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
         [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TRUNCATE"],
-        [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "TFL"],
+        [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "TFL"],
         [
             (1, 8, 2, 8, 16),
             18,
diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
index afd632cf355e..edbfb4939b11 100644
--- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
@@ -33,7 +33,7 @@
         [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
         [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "TRUNCATE"],
         [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
-        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "NATURAL"],
+        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "NATURAL"],
         [
             (1, 8, 2, 8, 16),
             18,