diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index bd9a7d5ba0d1..5d1e75b03043 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -700,15 +700,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
         clip = None
         requantize = None
 
-        if is_quantized_operation:
-            if str(current_call.op.name) == "clip":
-                clip = current_call
-                current_call = clip.args[0]
-        else:
-            if str(current_call.op.name) == "qnn.requantize":
-                requantize = current_call
-                clip = current_call.args[0]
-                current_call = clip.args[0]
+        if str(current_call.op.name) == "clip":
+            clip = current_call
+            current_call = clip.args[0]
+        elif str(current_call.op.name) == "qnn.requantize":
+            requantize = current_call
+            clip = current_call.args[0]
+            current_call = clip.args[0]
         binary_op = current_call
 
         layout = "NHWC"
@@ -941,21 +939,40 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MIN with different scales is not supported on the NPU
+        # (see the NPU_SET_OFM_SCALE register description:
+        # https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
+# This pattern covers the case when requantize uses different scales, so
+# minimum + clip + qnn.requantize cannot be offloaded to the NPU as a single
+# operation due to hardware constraints.
+# It is instead offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
 def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
-    This function creates the pattern for minimum with optional fused RELU activation.
+    This function creates the pattern for minimum with optional fused RELU activation without
+    requantize.
     """
     minimum = is_op("minimum")(wildcard(), wildcard())
     optional_min_clip = is_op("clip")(minimum)
-    optional_min_clip = is_op("qnn.requantize")(
-        optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return minimum | optional_min_clip
 
 
+def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with fused RELU activation and requantize.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class MaxParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -979,21 +996,40 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MAX with different scales is not supported on the NPU
+        # (see the NPU_SET_OFM_SCALE register description:
+        # https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
+# This pattern covers the case when requantize uses different scales, so
+# maximum + clip + qnn.requantize cannot be offloaded to the NPU as a single
+# operation due to hardware constraints.
+# It is instead offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
 def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
-    This function creates the pattern for maximum with optional fused RELU activation.
+    This function creates the pattern for maximum with optional fused RELU activation without
+    requantize.
     """
     maximum = is_op("maximum")(wildcard(), wildcard())
     optional_max_clip = is_op("clip")(maximum)
-    optional_max_clip = is_op("qnn.requantize")(
-        optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return maximum | optional_max_clip
 
 
+def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with fused RELU activation and requantize.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class ShlParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1913,11 +1949,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_mul_pattern(),
             lambda pat: MulParams(pat).is_valid(),
         ),
+        (
+            MinParams.composite_name,
+            minimum_clip_requantize_pattern(),
+            lambda pat: MinParams(pat).is_valid(),
+        ),
         (
             MinParams.composite_name,
             minimum_pattern(),
             lambda pat: MinParams(pat).is_valid(),
         ),
+        (
+            MaxParams.composite_name,
+            maximum_clip_requantize_pattern(),
+            lambda pat: MaxParams(pat).is_valid(),
+        ),
         (
             MaxParams.composite_name,
             maximum_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index dc54ef071d19..05ba7467b309 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1191,6 +1191,29 @@ def conv2d_relu6(x):
     )
 
 
+# Specific case when the operation cannot be offloaded to the NPU as a single binary
+# elementwise operation: MIN/MAX cannot be fused with requantize when the scales differ.
+@pytest.mark.parametrize("operation", [tf.math.minimum, tf.math.maximum])
+def test_tflite_min_max_relu_n1_to_1(operation):
+    np.random.seed(0)
+    accel_type = "ethos-u55-128"
+    ifm_shape = (1, 12, 16, 8)
+
+    @tf.function
+    def min_max_relu_n1_to_1(lhs, rhs):
+        op = operation(lhs, rhs)
+        # This pattern is folded into a single RELU_N1_TO_1 op by the TFLite converter.
+        return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))
+
+    infra.compare_tvm_with_tflite(
+        min_max_relu_n1_to_1,
+        [ifm_shape, ifm_shape],
+        accel_type,
+        enable_cascader=True,
+        ranges=[(-1, 1), (0, 2)],
+    )
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
 @pytest.mark.parametrize("ofm_channels", [32, 64])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 5ddc7565f20c..5bc31dacb59d 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -53,6 +53,13 @@ def partition_ethosu_by_table(mod, pattern_table):
     return mod
 
 
+def relu_n1_to_1(x):
+    """
+    This pattern is folded into a single RELU_N1_TO_1 op by the TFLite converter.
+ """ + return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0)) + + def test_split_indices_legalize(): def create_graph(axis): x = relay.var("x", shape=(1, 50, 50, 3)) @@ -881,7 +888,7 @@ def verify(ext_func): ([1, 4, 4], [4, 1], False), ], ) -@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +@pytest.mark.parametrize("activation_function", [None, tf.nn.relu]) def test_tflite_binary_elemwise_legalize( operator_type, ifm_shape, @@ -906,8 +913,8 @@ def tf_function(self, x, y): op = tf.math.minimum(x, y) elif operator_type == "MAX": op = tf.math.maximum(x, y) - if activation_function == "RELU": - op = tf.nn.relu(op) + if activation_function: + op = activation_function(op) return op model = Model() @@ -938,9 +945,13 @@ def verify(ext_func): op = ext_func.body has_reshaped_output = False + has_separate_requantize = False shapes_padded = [[1] * (4 - len(s)) + s for s in shapes] out_padded = [1] * (4 - len(out_shape)) + out_shape - if op.op.name != "contrib.ethosu.binary_elementwise": + if op.op.name == "contrib.ethosu.identity": + op = op.args[0] + has_separate_requantize = True + if op.op.name == "reshape": has_reshaped_output = True op = op.args[0] @@ -951,20 +962,30 @@ def verify(ext_func): assert op.checked_type.dtype == dtype assert op.attrs.operator_type == operator_type assert op.attrs.reversed_operands == reversed_operands - if activation_function == "RELU": + if activation_function != None: assert str(op.attrs.activation) == "CLIP" if operator_type in ["MIN", "MAX"]: - # MIN and MAX with an activation must have a requantize operation - # baked into the output. To check the extra requantize node was - # picked up by the pattern, we can make sure the quantization - # information is not default. - assert float(op.attrs.ifm_scale) != 1.0 - assert int(op.attrs.ifm_zero_point) != 0 - assert float(op.attrs.ifm2_scale) != 1.0 - assert int(op.attrs.ifm2_zero_point) != 0 - assert float(op.attrs.ofm_scale) != 1.0 - assert int(op.attrs.ofm_zero_point) != 0 + if has_separate_requantize: + # In case when requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints + # there should be default quantization values since requantize is separate operation. + assert float(op.attrs.ifm_scale) == 1.0 + assert int(op.attrs.ifm_zero_point) == 0 + assert float(op.attrs.ifm2_scale) == 1.0 + assert int(op.attrs.ifm2_zero_point) == 0 + assert float(op.attrs.ofm_scale) == 1.0 + assert int(op.attrs.ofm_zero_point) == 0 + else: + # MIN and MAX with an activation must have a requantize operation + # baked into the output. To check the extra requantize node was + # picked up by the pattern, we can make sure the quantization + # information is not default. 
+                    assert float(op.attrs.ifm_scale) != 1.0
+                    assert int(op.attrs.ifm_zero_point) != 0
+                    assert float(op.attrs.ifm2_scale) != 1.0
+                    assert int(op.attrs.ifm2_zero_point) != 0
+                    assert float(op.attrs.ofm_scale) != 1.0
+                    assert int(op.attrs.ofm_zero_point) != 0
 
         if has_reshaped_output:
             assert list(ext_func.body.checked_type.shape) == out_shape
@@ -997,22 +1018,42 @@ def verify(ext_func):
             ),
         ]
     elif operator_type == "MIN":
-        rewriter = legalize.MinRewriter()
+        rewriter = [legalize.MinRewriter(), legalize.RequantizeRewriter()]
         pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_clip_requantize_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
             (
                 ethosu.MinParams.composite_name,
                 ethosu.minimum_pattern(),
                 lambda pat: ethosu.MinParams(pat).is_valid(),
            ),
+            (
+                ethosu.RequantizeParams.composite_name,
+                ethosu.requantize_pattern(),
+                lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+            ),
         ]
     elif operator_type == "MAX":
-        rewriter = legalize.MaxRewriter()
+        rewriter = [legalize.MaxRewriter(), legalize.RequantizeRewriter()]
         pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_clip_requantize_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
             (
                 ethosu.MaxParams.composite_name,
                 ethosu.maximum_pattern(),
                 lambda pat: ethosu.MaxParams(pat).is_valid(),
             ),
+            (
+                ethosu.RequantizeParams.composite_name,
+                ethosu.requantize_pattern(),
+                lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+            ),
         ]
 
     tflite_graph = create_tflite_graph()
@@ -1031,6 +1072,12 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+# Checks the case when requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints.
+def test_tflite_max_relu_n1_to_1_legalize():
+    ifm_shape = [1, 4, 8, 16]
+    test_tflite_binary_elemwise_legalize("MAX", ifm_shape, ifm_shape, False, relu_n1_to_1)
+
+
 def test_binary_add_from_constant_scalar():
     dtype = "uint8"
     ifm_shape = (1, 4, 4, 8)
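Note (not part of the diff): the standalone sketch below illustrates the behaviour the new patterns are expected to have, assuming a TVM build where the ethosu contrib module above is importable (it requires the Vela package). A minimum -> clip -> qnn.requantize chain should be matched only by the new minimum_clip_requantize_pattern, while the slimmed-down minimum_pattern no longer matches the trailing requantize:

# Hedged sketch: checks which pattern matches a minimum + clip + requantize chain.
# All pattern names are taken from the diff above; shapes/scales are illustrative.
from tvm import relay
from tvm.relay.op.contrib import ethosu

x = relay.var("x", shape=(1, 4, 4, 8), dtype="int8")
y = relay.var("y", shape=(1, 4, 4, 8), dtype="int8")

expr = relay.minimum(x, y)
expr = relay.clip(expr, a_min=-128.0, a_max=127.0)
expr = relay.qnn.op.requantize(
    expr,
    input_scale=relay.const(0.0039, "float32"),
    input_zero_point=relay.const(0, "int32"),
    # A different output scale is exactly the case the NPU cannot fuse.
    output_scale=relay.const(0.0078, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)

# The new dedicated pattern matches the full chain ...
assert ethosu.minimum_clip_requantize_pattern().match(expr)
# ... while minimum_pattern no longer swallows the requantize.
assert not ethosu.minimum_pattern().match(expr)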
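The new tests also rely on the TFLite converter folding maximum(-1.0, minimum(x, 1.0)) into a single RELU_N1_TO_1 operator. A quick standalone check of that assumption (hypothetical snippet, TensorFlow 2.x assumed, not part of the diff):

# Hedged sketch: converts the clamp-to-[-1, 1] pattern and lets you confirm the
# RELU_N1_TO_1 folding that the tests above depend on.
import tensorflow as tf

@tf.function
def relu_n1_to_1(x):
    # Clamp to [-1, 1]; TFLite recognises this as RELU_N1_TO_1.
    return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))

concrete = relu_n1_to_1.get_concrete_function(tf.TensorSpec([1, 4, 4, 8], tf.float32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
tflite_model = converter.convert()
# Inspecting the converted model (e.g. with Netron) should show a single
# RELU_N1_TO_1 operator instead of separate MAXIMUM/MINIMUM ops.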