diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 6be03a6883fa..64c6cefb8b58 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1445,6 +1445,70 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class ExpandDimsRewriter(DFPatternCallback):
+    """Legalize expand dims to a reshape operator."""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ExpandDimsParams.composite_name})
+        )(None)
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ExpandDimsParams(post.op.body)
+        return relay.op.reshape(post.args[0], newshape=params.output.shape)
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeExpandDims:
+    """This is the pass that wraps ExpandDimsRewriter."""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(ExpandDimsRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class SqueezeRewriter(DFPatternCallback):
+    """Legalize squeeze to a reshape operator."""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.SqueezeParams.composite_name})
+        )(None)
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.SqueezeParams(post.op.body)
+        return relay.op.reshape(post.args[0], newshape=params.output.shape)
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeSqueeze:
+    """This is the pass that wraps SqueezeRewriter."""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(SqueezeRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1477,6 +1541,8 @@ def transform_module(
         mod = LegalizeSigmoid()(mod)
         mod = LegalizeRequantize()(mod)
         mod = LegalizeResize2d()(mod)
+        mod = LegalizeExpandDims()(mod)
+        mod = LegalizeSqueeze()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
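The two rewriters above rely on the fact that expand_dims and squeeze never move data: each only edits the shape metadata, so replacing it with a reshape to the inferred output shape is semantics-preserving. A minimal NumPy sketch of that invariant (illustrative only; the helper names are not part of the patch):

    import numpy as np

    def expand_dims_as_reshape(data, axis):
        # Insert a unit dimension at `axis`, the reshape that
        # ExpandDimsRewriter emits with newshape=params.output.shape.
        new_shape = list(data.shape)
        new_shape.insert(axis, 1)
        return data.reshape(new_shape)

    def squeeze_as_reshape(data, axis=None):
        # Drop the unit dimension at `axis` (or all unit dimensions when
        # axis is None), the reshape that SqueezeRewriter emits.
        if axis is None:
            new_shape = [d for d in data.shape if d != 1]
        else:
            new_shape = [d for i, d in enumerate(data.shape) if i != axis]
        return data.reshape(new_shape)

    x = np.arange(9, dtype="int8").reshape(1, 3, 3)
    assert np.array_equal(expand_dims_as_reshape(x, 2), np.expand_dims(x, 2))
    assert np.array_equal(squeeze_as_reshape(x, 0), np.squeeze(x, 0))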
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 6df4611acffa..fa11fca88390 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1214,19 +1214,22 @@ class ConcatParams:
 
     def __init__(self, func_body):
         self.concat = func_body
+        self.is_qnn_variant = self.concat.op.name == "qnn.concatenate"
         self.input_tensors = [TensorParams(tensor) for tensor in list(func_body.args[0])]
-        self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
-        self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]
         self.axis = func_body.attrs.axis
 
+        if self.is_qnn_variant:
+            self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
+            self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]
+
     def is_valid(self):
         """Checks whether Concatenate has compatible attributes with the hardware"""
         if not check_valid_dtypes(self.input_tensors, supported_dtypes=[np.int8]):
             return False
         # Check that the scales and zero points of input tensors are the same
-        if not all(self.input_scales == self.input_scales[0]):
+        if self.is_qnn_variant and not all(self.input_scales == self.input_scales[0]):
             return False
-        if not all(self.input_zero_points == self.input_zero_points[0]):
+        if self.is_qnn_variant and not all(self.input_zero_points == self.input_zero_points[0]):
             return False
 
         input_dim = len(self.input_tensors[0].shape)
@@ -1244,6 +1247,8 @@ def is_valid(self):
         output_shape = self.concat.checked_type.shape
         if len(output_shape) != input_dim:
             return False
+        if len(output_shape) > 3 and output_shape[0] != 1:
+            return False
 
         return True
 
@@ -1252,8 +1257,11 @@ def concat_pattern():
     tensors = is_tuple(None)
     scales = is_tuple(None)
     zero_points = is_tuple(None)
-    concat = is_op("qnn.concatenate")(tensors, scales, zero_points, is_constant(), is_constant())
-    return concat
+    qnn_concat = is_op("qnn.concatenate")(
+        tensors, scales, zero_points, is_constant(), is_constant()
+    )
+    concat = is_op("concatenate")(tensors)
+    return concat | qnn_concat
 
 
 class SplitParams:
@@ -1433,6 +1441,60 @@ def resize2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return quant | is_op("image.resize2d")(wildcard()).has_attr({"method": "nearest_neighbor"})
 
 
+class ExpandDimsParams:
+    """
+    This class will parse a call to an ethos-u.expand_dims composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.expand_dims"
+
+    def __init__(self, func_body):
+        self.expand_dims = func_body
+        self.input = TensorParams(func_body.args[0])
+        self.output = TensorParams(func_body)
+
+    def is_valid(self):
+        """Checks whether expand_dims has compatible attributes with the hardware."""
+        if not check_dimensions(self.input) or not check_dimensions(self.output):
+            return False
+        if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def expand_dims_pattern():
+    """Create the pattern for expand_dims."""
+    return is_op("expand_dims")(wildcard())
+
+
+class SqueezeParams:
+    """
+    This class will parse a call to an ethos-u.squeeze composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.squeeze"
+
+    def __init__(self, func_body):
+        self.squeeze = func_body
+        self.input = TensorParams(func_body.args[0])
+        self.output = TensorParams(func_body)
+
+    def is_valid(self):
+        """Checks whether squeeze has compatible attributes with the hardware."""
+        if not check_dimensions(self.output):
+            return False
+        if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def squeeze_pattern():
+    """Create the pattern for squeeze."""
+    return is_op("squeeze")(wildcard())
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1533,6 +1595,16 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             resize2d_pattern(),
             lambda pat: Resize2dParams(pat).is_valid(),
         ),
+        (
+            ExpandDimsParams.composite_name,
+            expand_dims_pattern(),
+            lambda pat: ExpandDimsParams(pat).is_valid(),
+        ),
+        (
+            SqueezeParams.composite_name,
+            squeeze_pattern(),
+            lambda pat: SqueezeParams(pat).is_valid(),
+        ),
     ]
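concat_pattern now also matches a plain concatenate, and ConcatParams only checks scales and zero points for the qnn variant. This matters for quantized TFLite PACK, which reaches the partitioner as an expand_dims on each input followed by a regular concatenate carrying no per-input quantization parameters, as exercised by test_tflite_pack below. A NumPy sketch of that decomposition (illustrative only):

    import numpy as np

    xs = [np.full((5, 4), i, dtype="int8") for i in range(2)]
    # stack/pack along `axis` is concatenation after giving each input a
    # unit dimension on that axis, the expand_dims + concatenate form
    # matched by the patterns above.
    packed = np.concatenate([np.expand_dims(x, axis=1) for x in xs], axis=1)
    assert np.array_equal(packed, np.stack(xs, axis=1))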
+ """ + + composite_name = "ethos-u.squeeze" + + def __init__(self, func_body): + self.squeeze = func_body + self.input = TensorParams(func_body.args[0]) + self.output = TensorParams(func_body) + + def is_valid(self): + """Checks whether squeeze has compatible attributes with the hardware.""" + if not check_dimensions(self.output): + return False + if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]): + return False + return True + + +def squeeze_pattern(): + """Create the pattern for squeeze.""" + return is_op("squeeze")(wildcard()) + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1533,6 +1595,16 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal resize2d_pattern(), lambda pat: Resize2dParams(pat).is_valid(), ), + ( + ExpandDimsParams.composite_name, + expand_dims_pattern(), + lambda pat: ExpandDimsParams(pat).is_valid(), + ), + ( + SqueezeParams.composite_name, + squeeze_pattern(), + lambda pat: SqueezeParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index ebcd1a0ba1fc..27fdc17d4ea2 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1020,6 +1020,28 @@ def create_model(): _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)]) +def test_tflite_expand_dims(accel_type, ifm_shape, axis): + @tf.function + def expand_dims_func(x): + return tf.expand_dims(x, axis=axis) + + _compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)] +) +def test_tflite_squeeze(accel_type, ifm_shape, axis): + @tf.function + def squeeze_func(x): + return tf.squeeze(x, axis=axis) + + _compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize( "ifm_shape,size", @@ -1100,5 +1122,39 @@ def conv2d_transpose(x): _compare_tvm_with_tflite(conv2d_transpose, [ifm_shape], accel_type=accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shapes,axis", + [ + ([(1, 2, 2), (1, 2, 2), (1, 2, 2)], 2), + ([(5, 4), (5, 4)], 1), + ([(1,), (1,)], 0), + ([(3, 1), (3, 1), (3, 1), (3, 1)], 0), + ], +) +def test_tflite_pack(accel_type, ifm_shapes, axis): + @tf.function + def pack_func(*inputs): + return tf.stack(inputs, axis=axis) + + # TODO(lhutton1) For now output is not bit exact with TFLite. + # This is because TFLite reference kernels are not being used. + # For this, TFLite will need upgrading to 2.6. 
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index ab304a7b0c2b..8af342a00dfd 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1604,6 +1604,147 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)])
+def test_tflite_expand_dims(ifm_shape, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                return tf.expand_dims(x, axis=axis)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        expected_shape = list(ifm_shape)
+        expected_shape.insert(axis, 1)
+
+        # Check IFM
+        assert list(op.args[0].checked_type.shape) == list(ifm_shape)
+        assert op.args[0].checked_type.dtype == dtype
+
+        # Check OFM
+        assert list(op.checked_type.shape) == expected_shape
+        assert op.checked_type.dtype == dtype
+
+        # Check op
+        assert op.op.name == "reshape"
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = ethosu.partition_for_ethosu(mod)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.ExpandDimsRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+@pytest.mark.parametrize(
+    "ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)]
+)
+def test_tflite_squeeze(ifm_shape, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                return tf.squeeze(x, axis=axis)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        # The expected shape drops `axis`, or every unit dimension when axis is None
+        if isinstance(axis, int):
+            expected_shape = list(ifm_shape[:axis] + ifm_shape[axis + 1 :])
+        else:
+            expected_shape = [x for x in ifm_shape if x != 1]
+
+        # Check IFM
+        assert list(op.args[0].checked_type.shape) == list(ifm_shape)
+        assert op.args[0].checked_type.dtype == dtype
+
+        # Check OFM
+        assert list(op.checked_type.shape) == list(expected_shape)
+        assert op.checked_type.dtype == dtype
+
+        # Check op
+        assert op.op.name == "reshape"
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = ethosu.partition_for_ethosu(mod)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.SqueezeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize(
     "ifm_shape,size",
     [
@@ -1903,5 +2044,220 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize(
+    "ifm_shapes,axis",
+    [
+        ([(1, 2, 2), (1, 2, 2), (1, 2, 2)], 2),
+        ([(5, 4), (5, 4)], 1),
+        ([(1,), (1,)], 0),
+        ([(3, 1), (3, 1), (3, 1), (3, 1)], 0),
+    ],
+)
+def test_tflite_pack(ifm_shapes, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, inputs, axis):
+                return tf.stack(inputs, axis=axis)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            [tf.TensorSpec(shape, tf.float32) for shape in ifm_shapes], axis
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                datas = [np.random.rand(*shape) for shape in ifm_shapes]
+                yield [data.astype(np.float32) for data in datas]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        num_inputs = len(ifm_shapes)
+        ifm_shape = list(ifm_shapes[0])
+        op = ext_func.body
+
+        after_reshape = ifm_shape[:axis] + [1] + ifm_shape[axis:]
+        out_shape = ifm_shape[:axis] + [num_inputs] + ifm_shape[axis:]
+
+        assert op.op.name == "concatenate"
+
+        # Check shapes after expand_dims (legalized as reshape)
+        for i in range(len(ifm_shapes)):
+            assert list(op.args[0][i].checked_type.shape) == after_reshape
+            assert op.args[0][i].checked_type.dtype == dtype
+
+        # Check output
+        assert list(op.checked_type.shape) == out_shape
+        assert op.checked_type.dtype == dtype
+
+    pack_pattern_table = [
+        (
+            ethosu.ConcatParams.composite_name,
+            ethosu.concat_pattern(),
+            lambda pat: ethosu.ConcatParams(pat).is_valid(),
+        ),
+        (
+            ethosu.ExpandDimsParams.composite_name,
+            ethosu.expand_dims_pattern(),
+            lambda pat: ethosu.ExpandDimsParams(pat).is_valid(),
+        ),
+        (
+            ethosu.ReshapeParams.composite_name,
+            ethosu.reshape_pattern(),
+            lambda pat: ethosu.ReshapeParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(ifm_shapes)},
+        dtype_dict={("ifm" + str(i)): dtype for i, _ in enumerate(ifm_shapes)},
+    )
+    mod = partition_ethosu_by_table(relay_module, pack_pattern_table)
+
+    seq = [
+        legalize.ConcatRewriter(),
+        legalize.ExpandDimsRewriter(),
+        legalize.ReshapeRewriter(),
+        legalize.NoOpRewriter(),
+    ]
+    for legalizer in seq:
+        mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+            legalizer, mod["tvmgen_default_ethos_u_main_0"]
+        )
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
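To make the shape bookkeeping in verify() above concrete: with ifm_shapes = [(5, 4), (5, 4)] and axis=1, each input is first reshaped to carry a unit axis, and the concatenation then restores the stacked size on that axis. A worked instance (illustrative only):

    ifm_shape, axis, num_inputs = [5, 4], 1, 2
    after_reshape = ifm_shape[:axis] + [1] + ifm_shape[axis:]       # [5, 1, 4]
    out_shape = ifm_shape[:axis] + [num_inputs] + ifm_shape[axis:]  # [5, 2, 4]
    assert (after_reshape, out_shape) == ([5, 1, 4], [5, 2, 4])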
+@pytest.mark.parametrize(
+    "ifm_shape,axis",
+    [[(1, 2, 3, 4), 1], [(2, 3), 1], [(5, 6, 7), 2]],
+)
+def test_tflite_unpack(ifm_shape, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, axis):
+                return tf.unstack(x, axis=axis)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32), axis
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        outputs = ext_func.body.args[0].fields
+        shape = list(ifm_shape)
+        unpacked_shape = shape[:axis] + shape[axis + 1 :]
+        split_shape = shape[:axis] + [1] + shape[axis + 1 :]
+
+        assert len(outputs) == shape[axis]
+
+        for i, output in enumerate(outputs):
+            expr = output.args[0].args[0]
+            expr = expr.tuple_value[expr.index]
+            expr = expr.args[0]
+
+            # Checking expected unpacked output shape.
+            # Squeeze is legalized to a reshape.
+            assert expr.op.name == "reshape"
+            assert list(expr.checked_type.shape) == unpacked_shape
+            assert output.checked_type.dtype == dtype
+
+            expr = expr.args[0]
+            expr = expr.tuple_value[expr.index]
+            expr = expr.args[0]
+
+            # Check input is split correctly
+            assert list(expr.args[0].checked_type.shape) == shape
+            assert list(expr.checked_type.shape) == split_shape
+            assert expr.checked_type.dtype == dtype
+
+            # Check split attrs
+            begin_shape = [0] * len(ifm_shape)
+            begin_shape[axis] = i
+            assert list(expr.attrs.begin) == begin_shape
+            end_shape = shape[:axis] + [i + 1] + shape[axis + 1 :]
+            assert list(expr.attrs.end) == end_shape
+            assert list(expr.attrs.strides) == [1]
+
+    unpack_pattern_table = [
+        (
+            ethosu.SplitParams.composite_name,
+            ethosu.split_pattern(),
+            lambda pat: ethosu.SplitParams(pat).is_valid(),
+        ),
+        (
+            ethosu.SqueezeParams.composite_name,
+            ethosu.squeeze_pattern(),
+            lambda pat: ethosu.SqueezeParams(pat).is_valid(),
+        ),
+        (
+            ethosu.ReshapeParams.composite_name,
+            ethosu.reshape_pattern(),
+            lambda pat: ethosu.ReshapeParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = partition_ethosu_by_table(mod, unpack_pattern_table)
+
+    seq = [
+        legalize.PartitionedSplitRewriter(),
+        legalize.SplitRewriter(),
+        legalize.SqueezeRewriter(),
+        legalize.ReshapeRewriter(),
+        legalize.NoOpRewriter(),
+    ]
+    for legalizer in seq:
+        mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+            legalizer, mod["tvmgen_default_ethos_u_main_0"]
+        )
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
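Taken together, the production path these tests mirror is: import the quantized TFLite model, partition the supported subgraphs for the NPU, then run the full legalization pipeline, which now includes the expand_dims and squeeze passes. A condensed sketch, assuming `mod` has already been imported via relay.frontend.from_tflite as in the tests above:

    from tvm.relay.backend.contrib.ethosu import legalize
    from tvm.relay.op.contrib import ethosu

    mod = ethosu.partition_for_ethosu(mod)  # offload supported subgraphs
    mod = legalize.LegalizeEthosU()(mod)    # runs LegalizeExpandDims and LegalizeSqueeze, among others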