From 1abac94dbfa69bcc8c3e7e6d35e19f25edaa1c0d Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Sat, 27 Nov 2021 18:40:36 +0000 Subject: [PATCH 1/2] [microNPU] Add support for SPLIT and SPLIT_V Both, SPLIT and SPLIT_V get lowered to relay.split and in the legalization the Relay split gets turned into strided slices. This patch adds the pattern and legalizer to enable offloading the TFLite's splits to the NPU. --- .../relay/backend/contrib/ethosu/legalize.py | 20 +++ python/tvm/relay/op/contrib/ethosu.py | 43 +++++ .../contrib/test_ethosu/test_codegen.py | 78 +++++++++ .../contrib/test_ethosu/test_legalize.py | 158 ++++++++++++++++++ 4 files changed, 299 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index b2264f32611e..60dc14426753 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -109,6 +109,25 @@ def callback( return relay.Tuple(strided_slices) +class PartitionedSplitRewriter(DFPatternCallback): + """This pass brings the split out of the partitioned function""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.SplitParams.composite_name}) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + split_input = post.args[0] + split_params = ethosu_patterns.SplitParams(post.op.body) + indices_or_sections = split_params.indices_or_sections + axis = split_params.axis + return relay.op.split(split_input, indices_or_sections, axis=axis).astuple() + + @ir.transform.module_pass(opt_level=1) class LegalizeSplit: """This is the pass that wraps SplitRewriter""" @@ -117,6 +136,7 @@ def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): + func = rewrite(PartitionedSplitRewriter(), func) func = rewrite(SplitRewriter(), func) mod.update_func(global_var, func) return mod diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a7d3da3200b5..73007cffe726 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1107,6 +1107,44 @@ def concat_pattern(): return concat +class SplitParams: + """ + This class will parse a call to a ethos-u.split composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.split" + + def __init__(self, func_body): + self.split = func_body + self.input = TensorParams(func_body.args[0]) + self.axis = func_body.attrs.axis + self.indices_or_sections = self.convert_indices_or_sections( + func_body.attrs.indices_or_sections + ) + + def convert_indices_or_sections(self, indices_or_sections): + # split_v + if isinstance(indices_or_sections, tvm.ir.container.Array): + values = [i.value for i in indices_or_sections] + # split + else: + values = indices_or_sections.value + return values + + def is_valid(self): + """Checks whether split has compatible attributes with the hardware""" + if not check_valid_dtypes([self.input], supported_dtypes=[np.int8]): + return False + return True + + +def split_pattern(): + "Create the pattern for split" + split = is_op("split")(wildcard()) + return split + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1187,6 +1225,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal sigmoid_pattern(), lambda pat: SigmoidParams(pat).is_valid(), ), + ( + SplitParams.composite_name, + split_pattern(), + lambda pat: SplitParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 21e86c866512..a33b818a6580 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -856,5 +856,83 @@ def sigmoid_function(x): _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type) +# This codegen test checks both, split and split_v +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((4, 6, 8), 2, 0), + ((50,), 25, 0), + ((5, 11), 1, 1), + ((13,), (13,), 0), + ((22, 7), (4, -1), 1), + ], +) +def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def get_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), num_or_size_splits, axis + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = get_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 946aa951679b..7edbdc2c360a 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -1350,5 +1350,163 @@ def representative_dataset(): assert tuple(func_body.args[1].checked_type.shape) == (256,) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), 3, 2), + ((4, 6, 8), 2, 0), + ((5, 15), 3, 1), + ((3, 7), 1, 1), + ((100,), 25, 0), + ], +) +def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = num_or_size_splits == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by number of equal chunks + assert split.attrs.indices_or_sections == num_or_size_splits + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((10, 18, 4), (1, 4, 3, 2), 0), + ((22, 7), (4, -1), 1), + ((25,), (25,), 0), + ], +) +def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + # TF split gets converted into TFLite's split_v + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = len(num_or_size_splits) == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by the size of sections, so converting num_or_size_splits + # into the indices where the tensor is split at since this is how split is represented + # in Relay + split_sections = [] if single_output_split else [num_or_size_splits[0]] + for split_size in num_or_size_splits[1:-1]: + sec = split_sections[-1] + split_size + split_sections.append(sec) + assert list(split.attrs.indices_or_sections) == split_sections + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) From 5c10153df29d1ffec1b629592a0a5c219b2bc834 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Wed, 8 Dec 2021 15:29:54 +0000 Subject: [PATCH 2/2] Rebase the tests --- .../contrib/test_ethosu/test_codegen.py | 66 ++----------------- 1 file changed, 5 insertions(+), 61 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index a33b818a6580..4db20f9b3d0a 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -870,68 +870,12 @@ def sigmoid_function(x): ], ) def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis): - dtype = "int8" - - def get_tflite_graph(): - class Model(tf.Module): - @tf.function - def tf_function(self, x, num_or_size_splits, axis): - op = tf.split(x, num_or_size_splits, axis=axis) - return op - - model = Model() - - concrete_func = model.tf_function.get_concrete_function( - tf.TensorSpec(ifm_shape, dtype=tf.float32), num_or_size_splits, axis - ) - - # Convert the model - def representative_dataset(): - for _ in range(100): - data = np.random.rand(*tuple(ifm_shape)) - yield [data.astype(np.float32)] - - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.int8 - converter.inference_output_type = tf.int8 - tflite_model = converter.convert() - return tflite_model - - tflite_graph = get_tflite_graph() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - - relay_module, params = relay.frontend.from_tflite( - tflite_model, - shape_dict={"input": ifm_shape}, - dtype_dict={"input": dtype}, - ) - mod = partition_for_ethosu(relay_module, params) - - # Generate reference data - input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) - - compiled_models = infra.build_source( - mod, - input_data, - output_data, - accel_type, - ) - - # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] - - # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) + @tf.function + def split_func(x): + op = tf.split(x, num_or_size_splits, axis=axis) + return op - infra.print_payload(cmms) - infra.verify_source(compiled_models, accel_type) + _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type) if __name__ == "__main__":