From 226cd6900ff36026f1ea63b7d0c5760034d99373 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 21 Sep 2021 20:06:16 +0000 Subject: [PATCH 1/3] [microNPU] Support binary elementwise with non-4D inputs Reshapes non-4D inputs to become 4D, then reshapes the output back to the non-4D input shape. Change-Id: I680ac06841aa1323435bcd09b6996fc57117cd84 --- .../relay/backend/contrib/ethosu/legalize.py | 67 +++++++++++++++++-- python/tvm/relay/op/contrib/ethosu.py | 18 +++-- .../contrib/test_ethosu/test_codegen.py | 1 + .../contrib/test_ethosu/test_legalize.py | 20 +++++- 4 files changed, 92 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 8095cb184f5b..415ae516204d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -426,6 +426,58 @@ def __init__( self.params_class = params_class self.pattern = pattern + @staticmethod + def reshape_input(inputs: List[ethosu_patterns.TensorParams]) -> List[tvm.relay.Expr]: + """Reshape the inputs so that the following binary elementwise + operator receives 4-dimensional inputs. + + Parameters + ---------- + inputs: List[ethosu_patterns.TensorParams] + The inputs to reshape. + + Returns + ------- + reshaped_inputs: List[tvm.relay.Expr] + The new reshaped inputs. + """ + reshaped_inputs = [] + for i in inputs: + in_shape = i.shape + if len(in_shape) < 4: + pad_size = 4 - len(in_shape) + new_shape = ([1] * pad_size) + in_shape + new_call = relay.reshape(i.tensor, new_shape) + reshaped_inputs.append(new_call) + else: + reshaped_inputs.append(i.tensor) + return reshaped_inputs + + @staticmethod + def reshape_output(output: tvm.relay.Expr, ifm_input_shape: List[int]) -> tvm.relay.Expr: + """Reshape the output back to the original dimensionality. + Since the NPU must have the brodcastable tensor as the + second operand, the original shape of the first ifm must + be the output shape. + + Parameters + ---------- + output: tvm.relay.Expr + The output to reshape. + + ifm_input_shape: List[int] + The shape of the non-reshaped ifm tensor. + + Returns + ------- + reshaped_output: tvm.relay.Expr + The reshaped output expression. + """ + if len(ifm_input_shape) == 4: + return output + reshaped_output = relay.reshape(output, ifm_input_shape) + return reshaped_output + def callback( self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map ) -> tvm.relay.Expr: @@ -451,9 +503,12 @@ def callback( # We don't yet support activation functions that need to get legalized to LUTs. lut = relay.const([], dtype="int8") - return ethosu_ops.ethosu_binary_elementwise( - ifm=params.ifm.tensor, - ifm2=params.ifm2.tensor, + inputs = [params.ifm, params.ifm2] + inputs = self.reshape_input(inputs) + + ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise( + ifm=inputs[0], + ifm2=inputs[1], lut=lut, operator_type=params.operator_type, ifm_scale=float(params.ifm.q_params.scale_f32), @@ -462,8 +517,8 @@ def callback( ifm2_zero_point=int(params.ifm2.q_params.zero_point), ofm_scale=float(params.ofm.q_params.scale_f32), ofm_zero_point=int(params.ofm.q_params.zero_point), - ifm_channels=params.ifm.shape[3], - ifm2_channels=params.ifm2.shape[3], + ifm_channels=params.ifm.shape[-1], + ifm2_channels=params.ifm2.shape[-1], reversed_operands=params.reversed_operands, ofm_dtype=params.ofm.dtype, activation=activation, @@ -473,6 +528,8 @@ def callback( ifm2_layout=str(params.ifm2.layout), ofm_layout=str(params.ofm.layout), ) + output = self.reshape_output(ethosu_binary_elementwise, params.ifm.shape) + return output class AddRewriter(BinaryElementwiseRewriter): diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 8b4ee21d2892..03d2fee36c1e 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -514,11 +514,12 @@ def __init__(self, func_body: Call, operator_type: str, has_quantization_paramet self.activation = clip self.operator_type = operator_type - def can_broadcast(x, y): - for i in range(1, 4): - if x.shape[i] == y.shape[i] or y.shape[i] == 1: - continue + def can_broadcast(ifm, ifm2): + if len(ifm.shape) < len(ifm2.shape): return False + for m, n in zip(ifm.shape[::-1], ifm2.shape[::-1]): + if m != n and m == 1: + return False return True if can_broadcast(self.ifm, self.ifm2): @@ -537,9 +538,14 @@ def is_valid(self): """ if np.dtype(self.ofm) == np.int32 and self.activation is not None: return False - if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4: + # Due to identity operator requiring ifm != int32 for now + if np.dtype(self.ifm) == np.int32 and len(self.ifm.shape) < 4 or len(self.ifm2.shape) < 4: return False - if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1: + if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4: + return False + if len(self.ifm.shape) == 4 and self.ifm.shape[0] != 1: + return False + if len(self.ifm2.shape) == 4 and self.ifm2.shape[0] != 1: return False if not self.valid_broadcast: return False diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 93af66da8194..6b27d468cb64 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -349,6 +349,7 @@ def representative_dataset(): ([1, 2, 3, 4], [1, 2, 3, 4]), ([1, 2, 3, 4], [1, 1, 1, 1]), ([1, 1, 1, 1], [1, 2, 3, 4]), + ([1, 4, 4], [4, 1]), ], ) @pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 8c3e4e31c1ca..7ab3962358fa 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -565,6 +565,9 @@ def verify(ext_func): ([1, 2, 3, 4], [1, 2, 3, 4], False), ([1, 2, 3, 4], [1, 1, 3, 1], False), ([1, 1, 3, 1], [1, 2, 3, 4], True), + ([4], [4], False), + ([4], [1, 2, 3, 4], True), + ([1, 4, 4], [4, 1], False), ], ) @pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) @@ -621,16 +624,27 @@ def verify(ext_func): shapes = [ifm_shape, ifm2_shape] ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) op = ext_func.body - assert list(op.args[0].checked_type.shape) == shapes[ifm_index] - assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + + has_reshaped_output = False + shapes_padded = [[1] * (4 - len(s)) + s for s in shapes] + out_padded = [1] * (4 - len(out_shape)) + out_shape + if op.op.name != "contrib.ethosu.binary_elementwise": + has_reshaped_output = True + op = op.args[0] + + assert list(op.args[0].checked_type.shape) == shapes_padded[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes_padded[ifm2_index] assert op.args[0].checked_type.dtype == dtype - assert list(op.checked_type.shape) == out_shape + assert list(op.checked_type.shape) == out_padded assert op.checked_type.dtype == dtype assert op.attrs.operator_type == operator_type assert op.attrs.reversed_operands == reversed_operands if activation_function == "RELU": assert str(op.attrs.activation) == "CLIP" + if has_reshaped_output: + assert list(ext_func.body.checked_type.shape) == out_shape + if operator_type == "ADD": rewriter = legalize.AddRewriter() pattern_table = [ From 2f89ed43535649e0eb8f6475b018c1254b765010 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 18 Nov 2021 09:22:51 +0000 Subject: [PATCH 2/3] fix type hint Change-Id: I5bae1bd11fa0c82e3ffd882fddfcb925dff259cb --- python/tvm/relay/backend/contrib/ethosu/legalize.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 415ae516204d..75484773702c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -427,13 +427,15 @@ def __init__( self.pattern = pattern @staticmethod - def reshape_input(inputs: List[ethosu_patterns.TensorParams]) -> List[tvm.relay.Expr]: + def reshape_input( + inputs: List["TensorParams"], + ) -> List[tvm.relay.Expr]: """Reshape the inputs so that the following binary elementwise operator receives 4-dimensional inputs. Parameters ---------- - inputs: List[ethosu_patterns.TensorParams] + inputs: List[TensorParams] The inputs to reshape. Returns From 8e4a1522a445dd4043e72954648d724b9c9b58ad Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Sun, 21 Nov 2021 21:33:11 +0000 Subject: [PATCH 3/3] address comments Change-Id: I6167cf73b2722902212717c5243cd19edc3489b7 --- python/tvm/relay/op/contrib/ethosu.py | 4 +- .../contrib/test_ethosu/test_codegen.py | 78 +++++++++++++++++++ .../contrib/test_ethosu/test_legalize.py | 1 + 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 03d2fee36c1e..a2916e46dbb9 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -538,8 +538,8 @@ def is_valid(self): """ if np.dtype(self.ofm) == np.int32 and self.activation is not None: return False - # Due to identity operator requiring ifm != int32 for now - if np.dtype(self.ifm) == np.int32 and len(self.ifm.shape) < 4 or len(self.ifm2.shape) < 4: + # Due to identity operator requiring ofm != int32 for now + if np.dtype(self.ofm) == np.int32 and len(self.ofm.shape) < 4: return False if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4: return False diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 6b27d468cb64..81bcbe6b7c5c 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -436,6 +436,84 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([4], [4]), + ([4], [1, 2, 3, 4]), + ([1, 4, 4], [4, 1]), + ], +) +def test_binary_add_with_non_4d_shapes( + accel_type, + ifm_shape, + ifm2_shape, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, lhs, rhs): + return tf.math.add(lhs, rhs) + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape}, + dtype_dict={"ifm": dtype, "ifm2": dtype}, + ) + mod = partition_for_ethosu(mod, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) def test_binary_add_from_constant_scalar(accel_type): dtype = "uint8" diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 7ab3962358fa..8612b90adbe3 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -565,6 +565,7 @@ def verify(ext_func): ([1, 2, 3, 4], [1, 2, 3, 4], False), ([1, 2, 3, 4], [1, 1, 3, 1], False), ([1, 1, 3, 1], [1, 2, 3, 4], True), + ([1, 4, 4], [4, 1], False), ([4], [4], False), ([4], [1, 2, 3, 4], True), ([1, 4, 4], [4, 1], False),