From 8c0ea735f0440198665e342c75dca33d213aba25 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Tue, 22 Feb 2022 11:26:48 +0000 Subject: [PATCH 1/9] [microNPU] Add support for TFLite FULLY_CONNECTED This is primarily a legalization to an NPU Conv2d operator. The legalization target is Conv2d with 1 1 I O (HWIO) --- .../relay/backend/contrib/ethosu/legalize.py | 83 ++++++++++++++ .../tvm/relay/backend/contrib/ethosu/util.py | 14 +++ python/tvm/relay/op/contrib/ethosu.py | 104 ++++++++++++++++++ 3 files changed, 201 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 704168eb34a5..40b91fe0eef5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs): pass +class FullyConnectedRewriter(DFPatternCallback): + """Legalize Fully Connected (with bias and clip) to an EthosU operator""" + + def __init__(self): + super().__init__(require_type=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.FullyConnectedParams.composite_name}) + )(wildcard()) + + def callback(self, pre, post, node_map): + params = ethosu_patterns.FullyConnectedParams(post.op.body) + params.ifm.tensor = post.args[0] + activation_map = {"clip": "CLIP"} + + # IFM reshapes + ifm = post.args[0] + if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1: + ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1])) + + # Weight transformations + weights_values = params.weights.values + weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2)) + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + scale_bias = vela_api.pack_biases( + biases=params.biases.tensor.data.asnumpy(), + ifm_scale=params.ifm.q_params.scale_f32, + ifm_dtype=np.dtype(params.ifm.dtype), + weight_scales=params.weights.q_params.scale_f32, + ofm_scale=params.ofm.q_params.scale_f32, + is_activation_tanh_or_sigmoid=False, + ) + ethosu_fc = ethosu_ops.ethosu_conv2d( + ifm=ifm, + weight=relay.const(weights_values_ohwi, params.weights.values.dtype), + scale_bias=relay.const(scale_bias, "uint8"), + lut=relay.const([], dtype="int8"), + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + weight_zero_point=int(params.weights.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + kernel_shape=[1, 1], + ofm_channels=params.weights.shape[0], + strides=(1, 1), + padding=(0, 0, 0, 0), + dilation=(1, 1), + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + ) + + if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == params.ofm.shape[2] == 1: + ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape) + return ethosu_fc + + +@ir.transform.module_pass(opt_level=1) +class LegalizeFullyConnected: + """This is the pass that wraps the AddRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(FullyConnectedRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, 
**kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -1615,6 +1697,7 @@ def transform_module( mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) + mod = LegalizeFullyConnected()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 8234dc047fd8..dffc237e791c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -129,6 +129,20 @@ class DequantizeArgs(Enum): IFM_ZERO_POINT = 2 +class QDenseArgs(Enum): + """ + This is a helper enum to access the correct index of + qnn.dense arguments + """ + + IFM = 0 + WEIGHTS = 1 + IFM_ZERO_POINT = 2 + WEIGHTS_ZERO_POINT = 3 + IFM_SCALE = 4 + WEIGHTS_SCALE = 5 + + def is_composite_func(func: relay.Function, name: str) -> bool: """ This method checks whether the call is to diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index d83c5d38f429..0a14c95b30d0 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1537,6 +1537,105 @@ def squeeze_pattern(): return is_op("squeeze")(wildcard()) +class FullyConnectedParams: + """ + This class will parse a call to an ethos-u.fully_connected composite + function and extract the parameter information. + """ + + composite_name = "ethosu.fully_connected" + activation_map = {"clip": "CLIP"} + + @requires_vela + def __init__(self, func_body): + from tvm.relay.backend.contrib.ethosu.util import QDenseArgs # type: ignore + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs + from tvm.relay.backend.contrib.ethosu.util import RequantArgs + + activation = None + if str(func_body.op) in self.activation_map.keys(): + activation = func_body + requantize_op = activation.args[0] + else: + requantize_op = func_body + + bias_add = requantize_op.args[0] + qnn_dense = bias_add.args[0] + + # We consider the weights & biases as params as they should be constant + self.weights = TensorParams( + qnn_dense.args[QDenseArgs.weights.value], + "OI", + qnn_dense.args[QDenseArgs.weights_scale.value], + qnn_dense.args[QDenseArgs.weights_zero_point.value], + ) + self.biases = TensorParams( + bias_add.args[BiasAddArgs.BIASES.value], + None, + requantize_op.args[RequantArgs.IFM_SCALE.value], + requantize_op.args[RequantArgs.IFM_ZERO_POINT.value], + ) + self.ifm = TensorParams( + qnn_dense.args[QDenseArgs.ifm.value], + None, + qnn_dense.args[QDenseArgs.ifm_scale.value], + qnn_dense.args[QDenseArgs.ifm_zero_point.value], + ) + self.ofm = TensorParams( + func_body, + None, + requantize_op.args[RequantArgs.OFM_SCALE.value], + requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], + ) + + self.activation = activation + + def is_valid(self): + """ + Checks whether Fully Connected has compatible attributes with HW + """ + + def check_weights_fc(weights): + """Checks whether weight tensor is compatible with HW""" + weights_limit = 127 * 65536 + # A saturation upper bound check for accumulators + weights.values = weights.values - weights.q_params.zero_point + axis = 1 + sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis)) + if not sum_weights <= weights_limit: + return False + return True + + if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]): + return False + if not check_weights_fc(self.weights): + 
return False + if not check_bias(self.biases): + return False + if not check_batch_size(self.ifm): + return False + # Check input shape + if len(self.ifm.shape) < 2: + return False + if not np.all(np.array(self.ifm.shape[:-1]) == 1): + # As we reshape the ifm from + # [n0, n1, ... , n_m] to [n0 * n1 * ... * n_{m-1}, n_m] + # all except the last dims need to be 1. + return False + return True + +def qnn_fc_pattern(): + dense = is_op("qnn.dense")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ) + bias_add = is_op("nn.bias_add")(dense, is_constant()) + req = is_op("qnn.requantize")( + dense | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + optional_clip = req.optional(is_op("clip")) + return optional_clip + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1652,6 +1751,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal squeeze_pattern(), lambda pat: SqueezeParams(pat).is_valid(), ), + ( + FullyConnectedParams.composite_name, + qnn_fc_pattern(), + lambda pat: FullyConnectedParams(pat).is_valid(), + ), ] From dab5c5e8f626170c3ac347f5bf84907b8a3a9f09 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Tue, 22 Feb 2022 12:17:14 +0000 Subject: [PATCH 2/9] [microNPU] Add support for TFLite FULLY_CONNECTED Test TVM runtime against TFLite for codegen and operator legalization. --- .../contrib/test_ethosu/test_codegen.py | 19 +++++ .../contrib/test_ethosu/test_legalize.py | 79 +++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 8ac9bafcdb09..324cc6c687ec 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1167,5 +1167,24 @@ def leaky_relu_func(x): _compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("units", [32, 64]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) +def test_tflite_fully_connected( + accel_type, + units, + use_bias, + activation_function, +): + @tf.function + def fully_connected(): + return tf.keras.layers.Dense( + units=units, activation=activation_function, use_bias=use_bias, + ) + + _compare_tvm_with_tflite(fully_connected, (1, 3, units, 1), accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 83165d7d54fb..da88edc0722b 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2346,5 +2346,84 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize("units", [32, 64]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) +def test_tflite_fully_connected( + units, + use_bias, + activation_function, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def fully_connected(self, x): + return tf.keras.layers.Dense( + units=units, activation=activation_function, use_bias=use_bias, + )(x) + + model = Model() + concrete_func = model.fully_connected.get_concrete_function( + tf.TensorSpec([1, 3, units, 1], 
dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple([1, 3, units, 1])) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + op = ext_func.body + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert list([1, 3, units, 1]) == list([1, 3, units, 1]) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + if activation_function == "RELU": + assert str(op.attrs.activation) == "CLIP" + + dense_pattern_table = [ + ( + ethosu.FullyConnectedParams.composite_name, + ethosu.qnn_fc_pattern(), + lambda pat: ethosu.FullyConnectedParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": [1, 3, units, 1]}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, dense_pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + if __name__ == "__main__": pytest.main([__file__]) From ef1d576a4e2ad78126b06fd3ac07bdc992bd9160 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Tue, 22 Feb 2022 13:10:02 +0000 Subject: [PATCH 3/9] [microNPU] Add support for TFLite FULLY_CONNECTED Fix linting --- python/tvm/relay/op/contrib/ethosu.py | 3 ++- tests/python/contrib/test_ethosu/test_codegen.py | 4 +++- tests/python/contrib/test_ethosu/test_legalize.py | 7 +++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 0a14c95b30d0..d43280680197 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1561,7 +1561,7 @@ def __init__(self, func_body): bias_add = requantize_op.args[0] qnn_dense = bias_add.args[0] - + # We consider the weights & biases as params as they should be constant self.weights = TensorParams( qnn_dense.args[QDenseArgs.weights.value], @@ -1624,6 +1624,7 @@ def check_weights_fc(weights): return False return True + def qnn_fc_pattern(): dense = is_op("qnn.dense")( wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 324cc6c687ec..b2f9b0fecc99 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1180,7 +1180,9 @@ def test_tflite_fully_connected( @tf.function def fully_connected(): return tf.keras.layers.Dense( - units=units, activation=activation_function, use_bias=use_bias, + units=units, + activation=activation_function, + use_bias=use_bias, ) _compare_tvm_with_tflite(fully_connected, (1, 3, 
units, 1), accel_type) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index da88edc0722b..88562afe1b82 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2361,8 +2361,10 @@ class Model(tf.Module): @tf.function def fully_connected(self, x): return tf.keras.layers.Dense( - units=units, activation=activation_function, use_bias=use_bias, - )(x) + units=units, + activation=activation_function, + use_bias=use_bias, + )(x) model = Model() concrete_func = model.fully_connected.get_concrete_function( @@ -2425,5 +2427,6 @@ def verify(ext_func): ) verify(mod["tvmgen_default_ethos_u_main_0"]) + if __name__ == "__main__": pytest.main([__file__]) From a4741b96158d7e62ab4bc348e2bb7f40ca5e8685 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Tue, 1 Mar 2022 11:57:19 +0000 Subject: [PATCH 4/9] [microNPU] Add support for TFLite FULLY_CONNECTED Address comments, update codegen test, fix linting. --- .../relay/backend/contrib/ethosu/legalize.py | 31 +++++++----- python/tvm/relay/op/contrib/ethosu.py | 50 +++++++++++-------- .../contrib/test_ethosu/test_codegen.py | 29 ++++++++--- 3 files changed, 70 insertions(+), 40 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 40b91fe0eef5..0a2dd1a22c77 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1589,26 +1589,33 @@ def __init__(self): def callback(self, pre, post, node_map): params = ethosu_patterns.FullyConnectedParams(post.op.body) params.ifm.tensor = post.args[0] - activation_map = {"clip": "CLIP"} + activation = None # IFM reshapes ifm = post.args[0] if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1: - ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1])) + ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1])) # Weight transformations weights_values = params.weights.values weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2)) if params.activation: - activation = activation_map[params.activation.op.name] - clip_min = int(params.activation.attrs.a_min) - clip_max = int(params.activation.attrs.a_max) + activation = ethosu_patterns.FullyConnectedParams.activation_map[ + params.activation.op.name + ] + if params.activation: + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) else: - activation = "NONE" clip_min = 0 clip_max = 0 + bias_values = ( + params.biases.tensor.data.asnumpy() + if params.biases + else np.zeros((params.ifm.shape[-1])) + ) scale_bias = vela_api.pack_biases( - biases=params.biases.tensor.data.asnumpy(), + biases=bias_values, ifm_scale=params.ifm.q_params.scale_f32, ifm_dtype=np.dtype(params.ifm.dtype), weight_scales=params.weights.q_params.scale_f32, @@ -1627,9 +1634,9 @@ def callback(self, pre, post, node_map): ofm_zero_point=int(params.ofm.q_params.zero_point), kernel_shape=[1, 1], ofm_channels=params.weights.shape[0], - strides=(1, 1), - padding=(0, 0, 0, 0), - dilation=(1, 1), + strides=params.strides, + padding=params.padding, + dilation=params.dilation, activation=activation, clip_min=clip_min, clip_max=clip_max, @@ -1645,7 +1652,7 @@ def callback(self, pre, post, node_map): @ir.transform.module_pass(opt_level=1) class LegalizeFullyConnected: - """This is the pass that wraps the AddRewriter""" + """This is the pass that wraps the 
FullyConnectedRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext @@ -1696,8 +1703,8 @@ def transform_module( mod = LegalizeSqueeze()(mod) mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) - mod = LegalizeNoOps()(mod) mod = LegalizeFullyConnected()(mod) + mod = LegalizeNoOps()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index d43280680197..c668fa8a0643 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1543,7 +1543,7 @@ class FullyConnectedParams: function and extract the parameter information. """ - composite_name = "ethosu.fully_connected" + composite_name = "ethos-u.fully_connected" activation_map = {"clip": "CLIP"} @requires_vela @@ -1552,34 +1552,42 @@ def __init__(self, func_body): from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs from tvm.relay.backend.contrib.ethosu.util import RequantArgs - activation = None + self.activation = None if str(func_body.op) in self.activation_map.keys(): activation = func_body requantize_op = activation.args[0] else: requantize_op = func_body - bias_add = requantize_op.args[0] - qnn_dense = bias_add.args[0] + call = func_body.args[0] + if str(requantize_op.op) == "nn.bias_add": + bias_add = call + else: + bias_add = None + qnn_dense = call - # We consider the weights & biases as params as they should be constant + # weights & biases are params as they should be constant self.weights = TensorParams( - qnn_dense.args[QDenseArgs.weights.value], - "OI", - qnn_dense.args[QDenseArgs.weights_scale.value], - qnn_dense.args[QDenseArgs.weights_zero_point.value], - ) - self.biases = TensorParams( - bias_add.args[BiasAddArgs.BIASES.value], + qnn_dense.args[QDenseArgs.WEIGHTS.value], None, - requantize_op.args[RequantArgs.IFM_SCALE.value], - requantize_op.args[RequantArgs.IFM_ZERO_POINT.value], + qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value], + qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value], + ) + self.biases = ( + TensorParams( + bias_add.args[BiasAddArgs.BIASES.value], + None, + requantize_op.args[RequantArgs.IFM_SCALE.value], + requantize_op.args[RequantArgs.IFM_ZERO_POINT.value], + ) + if bias_add + else None ) self.ifm = TensorParams( - qnn_dense.args[QDenseArgs.ifm.value], + qnn_dense.args[QDenseArgs.IFM.value], None, - qnn_dense.args[QDenseArgs.ifm_scale.value], - qnn_dense.args[QDenseArgs.ifm_zero_point.value], + qnn_dense.args[QDenseArgs.IFM_SCALE.value], + qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value], ) self.ofm = TensorParams( func_body, @@ -1588,9 +1596,11 @@ def __init__(self, func_body): requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) - self.activation = activation + self.strides = (1, 1) + self.dilation = (1, 1) + self.padding = (0, 0, 0, 0) - def is_valid(self): + def is_valid(self) -> bool: """ Checks whether Fully Connected has compatible attributes with HW """ @@ -1606,7 +1616,7 @@ def check_weights_fc(weights): return False return True - if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]): + if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): return False if not check_weights_fc(self.weights): return False diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b2f9b0fecc99..67a7be570dec 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ 
b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1168,24 +1168,37 @@ def leaky_relu_func(x): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) -@pytest.mark.parametrize("units", [32, 64]) +@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)]) +@pytest.mark.parametrize("ofm_channels", [32, 64]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) def test_tflite_fully_connected( accel_type, - units, + ifm_shape, + ofm_channels, use_bias, activation_function, ): @tf.function - def fully_connected(): - return tf.keras.layers.Dense( - units=units, - activation=activation_function, - use_bias=use_bias, + def fully_connected(x): + bias_shape = ofm_channels + bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32) + w = tf.constant( + np.random.uniform(size=[ifm_shape[1], ofm_channels]), + dtype=tf.float32, ) + if use_bias: + w = tf.nn.bias_add(w, bias) + if activation_function: + w = tf.nn.relu(w) + return tf.matmul(x, w) - _compare_tvm_with_tflite(fully_connected, (1, 3, units, 1), accel_type) + # TODO(dchauhan-arm) For now output is not bit exact with TFLite. + # This is because TFLite reference kernels are not being used. + # For this, TFLite will need upgrading to 2.6. + _compare_tvm_with_tflite( + fully_connected, [(ofm_channels, ifm_shape[1])], accel_type, output_tolerance=1 + ) if __name__ == "__main__": From ae6827c807df5bf00b9a36177767ae8e7a5a39d8 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 2 Mar 2022 17:37:06 +0000 Subject: [PATCH 5/9] [microNPU] Add support for TFLite FULLY_CONNECTED Address more comments, ensure qnn.dense is lowered to NPU, fix linting --- .../relay/backend/contrib/ethosu/legalize.py | 19 +++---- python/tvm/relay/op/contrib/ethosu.py | 40 +++++++-------- .../contrib/test_ethosu/test_codegen.py | 9 ++-- .../contrib/test_ethosu/test_legalize.py | 50 +++++++++---------- 4 files changed, 56 insertions(+), 62 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 0a2dd1a22c77..18271c16a071 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1589,7 +1589,6 @@ def __init__(self): def callback(self, pre, post, node_map): params = ethosu_patterns.FullyConnectedParams(post.op.body) params.ifm.tensor = post.args[0] - activation = None # IFM reshapes ifm = post.args[0] @@ -1600,19 +1599,15 @@ def callback(self, pre, post, node_map): weights_values = params.weights.values weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2)) if params.activation: - activation = ethosu_patterns.FullyConnectedParams.activation_map[ - params.activation.op.name - ] - if params.activation: - clip_min = int(params.activation.attrs.a_min) - clip_max = int(params.activation.attrs.a_max) + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) else: clip_min = 0 clip_max = 0 bias_values = ( params.biases.tensor.data.asnumpy() if params.biases - else np.zeros((params.ifm.shape[-1])) + else np.zeros((params.ofm.shape[-1])) ) scale_bias = vela_api.pack_biases( biases=bias_values, @@ -1634,10 +1629,10 @@ def callback(self, pre, post, node_map): ofm_zero_point=int(params.ofm.q_params.zero_point), kernel_shape=[1, 1], ofm_channels=params.weights.shape[0], - strides=params.strides, - padding=params.padding, - dilation=params.dilation, - activation=activation, + strides=(1, 1), + padding=(0, 0, 0, 0), + 
dilation=(1, 1), + activation=params.activation, clip_min=clip_min, clip_max=clip_max, upscale="NONE", diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index c668fa8a0643..d75c5aed9a9c 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -23,7 +23,8 @@ import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant, Call # type: ignore +from tvm.relay.expr import Constant, Call +from tvm.relay.op.contrib.arm_compute_lib import qnn_dense # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore @@ -1103,7 +1104,10 @@ def is_valid(self): """ if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): return False - return True + return True # optional_bias_add = ( + + # is_op("nn.bias_add")(dense, is_constant()) | dense + # ) class TanhParams(LutActivationParams): @@ -1544,7 +1548,6 @@ class FullyConnectedParams: """ composite_name = "ethos-u.fully_connected" - activation_map = {"clip": "CLIP"} @requires_vela def __init__(self, func_body): @@ -1553,18 +1556,19 @@ def __init__(self, func_body): from tvm.relay.backend.contrib.ethosu.util import RequantArgs self.activation = None - if str(func_body.op) in self.activation_map.keys(): - activation = func_body - requantize_op = activation.args[0] + if str(func_body.op) == "clip": + self.activation = func_body + requantize_op = self.activation.args[0] else: requantize_op = func_body - call = func_body.args[0] - if str(requantize_op.op) == "nn.bias_add": + call = requantize_op.args[0] + if str(requantize_op.args[0].op) == "nn.bias_add": bias_add = call + qnn_dense = call else: bias_add = None - qnn_dense = call + qnn_dense = call # weights & biases are params as they should be constant self.weights = TensorParams( @@ -1596,10 +1600,6 @@ def __init__(self, func_body): requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) - self.strides = (1, 1) - self.dilation = (1, 1) - self.padding = (0, 0, 0, 0) - def is_valid(self) -> bool: """ Checks whether Fully Connected has compatible attributes with HW @@ -1639,9 +1639,9 @@ def qnn_fc_pattern(): dense = is_op("qnn.dense")( wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() ) - bias_add = is_op("nn.bias_add")(dense, is_constant()) + optional_bias_add = is_op("nn.bias_add")(dense, is_constant()) | dense req = is_op("qnn.requantize")( - dense | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + dense | optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant() ) optional_clip = req.optional(is_op("clip")) return optional_clip @@ -1665,6 +1665,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_conv2d_transpose_pattern(), lambda pat: QnnConv2DTransposeParams(pat).is_valid(), ), + ( + FullyConnectedParams.composite_name, + qnn_fc_pattern(), + lambda pat: FullyConnectedParams(pat).is_valid(), + ), ( MaxPool2DParams.composite_name, qnn_maxpool2d_pattern(), @@ -1762,11 +1767,6 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal squeeze_pattern(), lambda pat: SqueezeParams(pat).is_valid(), ), - ( - FullyConnectedParams.composite_name, - qnn_fc_pattern(), - lambda pat: FullyConnectedParams(pat).is_valid(), - ), ] diff --git 
a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 67a7be570dec..a4e6a50af6df 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1181,17 +1181,18 @@ def test_tflite_fully_connected( ): @tf.function def fully_connected(x): - bias_shape = ofm_channels + bias_shape = ifm_shape[1] bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32) w = tf.constant( np.random.uniform(size=[ifm_shape[1], ofm_channels]), dtype=tf.float32, ) if use_bias: - w = tf.nn.bias_add(w, bias) + x = tf.nn.bias_add(x, bias) if activation_function: - w = tf.nn.relu(w) - return tf.matmul(x, w) + x = tf.nn.relu(x) + x = tf.matmul(x, w) + return x # TODO(dchauhan-arm) For now output is not bit exact with TFLite. # This is because TFLite reference kernels are not being used. diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 88562afe1b82..9fc2fe3bab6f 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2346,11 +2346,13 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) -@pytest.mark.parametrize("units", [32, 64]) +@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)]) +@pytest.mark.parametrize("ofm_channels", [32, 64]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) def test_tflite_fully_connected( - units, + ifm_shape, + ofm_channels, use_bias, activation_function, ): @@ -2360,21 +2362,27 @@ def create_tflite_graph(): class Model(tf.Module): @tf.function def fully_connected(self, x): - return tf.keras.layers.Dense( - units=units, - activation=activation_function, - use_bias=use_bias, - )(x) + bias_shape = ofm_channels + bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32) + w = tf.constant( + np.random.uniform(size=[ifm_shape[1], ofm_channels]), + dtype=tf.float32, + ) + x = tf.matmul(x, w) + if use_bias: + x = tf.nn.bias_add(x, bias) + if activation_function: + x = tf.nn.relu(x) + return x model = Model() concrete_func = model.fully_connected.get_concrete_function( - tf.TensorSpec([1, 3, units, 1], dtype=tf.float32) + tf.TensorSpec(ifm_shape, dtype=tf.float32) ) - # Convert the model def representative_dataset(): for _ in range(100): - data = np.random.rand(*tuple([1, 3, units, 1])) + data = np.random.rand(*tuple(ifm_shape)) yield [data.astype(np.float32)] converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) @@ -2388,21 +2396,10 @@ def representative_dataset(): def verify(ext_func): op = ext_func.body - ofm_channels = op.attrs.ofm_channels - - # check IFM - ifm = op.args[0].checked_type - assert list([1, 3, units, 1]) == list([1, 3, units, 1]) - assert str(ifm.dtype) == dtype - assert ifm.shape[3] == ofm_channels - - # Check that scale_bias matches weight tensor - assert list(op.args[2].checked_type.shape)[0] == ofm_channels - if activation_function == "RELU": assert str(op.attrs.activation) == "CLIP" - dense_pattern_table = [ + fc_pattern_table = [ ( ethosu.FullyConnectedParams.composite_name, ethosu.qnn_fc_pattern(), @@ -2413,18 +2410,19 @@ def verify(ext_func): tflite_graph = create_tflite_graph() tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - mod, params = relay.frontend.from_tflite( + mod, fc_params = relay.frontend.from_tflite( tflite_model, - shape_dict={"input": [1, 3, units, 1]}, 
+ shape_dict={"input": ifm_shape}, dtype_dict={"input": dtype}, ) - mod["main"] = bind_params_by_name(mod["main"], params) - mod = partition_ethosu_by_table(mod, dense_pattern_table) + mod["main"] = bind_params_by_name(mod["main"], fc_params) + mod = partition_ethosu_by_table(mod, fc_pattern_table) mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"] ) + verify(mod["tvmgen_default_ethos_u_main_0"]) From 18bd546821ab0fde0d27c8d42ed08bd27215d14f Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Thu, 3 Mar 2022 17:33:50 +0000 Subject: [PATCH 6/9] [microNPU] Add support for TFLite FULLY_CONNECTED Fix linting, update legalization test and codegen test for completeness. --- .../relay/backend/contrib/ethosu/legalize.py | 4 +- python/tvm/relay/op/contrib/ethosu.py | 3 +- .../contrib/test_ethosu/test_codegen.py | 11 ++--- .../contrib/test_ethosu/test_legalize.py | 40 ++++++++++++++++++- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 18271c16a071..3cef1fcd9ecf 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1599,9 +1599,11 @@ def callback(self, pre, post, node_map): weights_values = params.weights.values weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2)) if params.activation: + activation = "CLIP" clip_min = int(params.activation.attrs.a_min) clip_max = int(params.activation.attrs.a_max) else: + activation = "NONE" clip_min = 0 clip_max = 0 bias_values = ( @@ -1632,7 +1634,7 @@ def callback(self, pre, post, node_map): strides=(1, 1), padding=(0, 0, 0, 0), dilation=(1, 1), - activation=params.activation, + activation=activation, clip_min=clip_min, clip_max=clip_max, upscale="NONE", diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index d75c5aed9a9c..5805928f5a82 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -24,7 +24,6 @@ import tvm # type: ignore from tvm import relay from tvm.relay.expr import Constant, Call -from tvm.relay.op.contrib.arm_compute_lib import qnn_dense # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore @@ -1565,7 +1564,7 @@ def __init__(self, func_body): call = requantize_op.args[0] if str(requantize_op.args[0].op) == "nn.bias_add": bias_add = call - qnn_dense = call + qnn_dense = call.args[0] else: bias_add = None qnn_dense = call diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index a4e6a50af6df..7bdb1ff1af3d 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1181,25 +1181,20 @@ def test_tflite_fully_connected( ): @tf.function def fully_connected(x): - bias_shape = ifm_shape[1] + bias_shape = ofm_channels bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32) w = tf.constant( np.random.uniform(size=[ifm_shape[1], ofm_channels]), dtype=tf.float32, ) + x = tf.matmul(x, w) if use_bias: x = tf.nn.bias_add(x, bias) if activation_function: x = tf.nn.relu(x) - x = tf.matmul(x, w) return x - # TODO(dchauhan-arm) For now output is not bit 
exact with TFLite. - # This is because TFLite reference kernels are not being used. - # For this, TFLite will need upgrading to 2.6. - _compare_tvm_with_tflite( - fully_connected, [(ofm_channels, ifm_shape[1])], accel_type, output_tolerance=1 - ) + _compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 9fc2fe3bab6f..bfe7feaa8947 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -18,6 +18,8 @@ import pytest +from tvm.relay.analysis.analysis import extract_fused_functions + pytest.importorskip("ethosu.vela") import math @@ -2395,7 +2397,43 @@ def representative_dataset(): return tflite_model def verify(ext_func): - op = ext_func.body + op = ext_func.body.args[0] + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape) + assert str(ifm.dtype) == dtype + + # check OFM + ofm = op.checked_type + assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels] + # assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + # assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == 1 + assert weights_ohwi.shape[2] == 1 + assert weights_ohwi.shape[3] == ifm_shape[1] + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + # expected_padding = infra.compute_padding_shape( + # ifm_shape, + # expected_ofm_shape, + # (0, 0, 0, 0), + # (1, 1), + # (1, 1), + # (1, 1), + # ) + assert list(op.attrs.padding) == [0, 0, 0, 0] + assert list(op.attrs.strides) == [1, 1] + assert list(op.attrs.dilation) == [1, 1] if activation_function == "RELU": assert str(op.attrs.activation) == "CLIP" From 06aa4d967c2ebda459e634fbaca1ec006b387fff Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Fri, 4 Mar 2022 10:45:33 +0000 Subject: [PATCH 7/9] [microNPU] Add support for TFLite FULLY_CONNECTED Address comments, fix linting. Certain legalization test assertions were updated. 
Co-authored-by: Rishabh Jain --- .../relay/backend/contrib/ethosu/legalize.py | 2 +- python/tvm/relay/op/contrib/ethosu.py | 17 ++++++--------- .../contrib/test_ethosu/test_legalize.py | 21 +++---------------- 3 files changed, 10 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index cb05df3d89e0..3fdcdb6c24b5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1585,7 +1585,7 @@ def __call__(self, *args, **kwargs): class FullyConnectedRewriter(DFPatternCallback): - """Legalize Fully Connected (with bias and clip) to an EthosU operator""" + """Legalize Fully Connected (with bias and clip) to an NPU operator""" def __init__(self): super().__init__(require_type=True) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 5805928f5a82..0893be4bb84a 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -23,7 +23,7 @@ import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant, Call +from tvm.relay.expr import Constant, Call # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore @@ -1103,10 +1103,7 @@ def is_valid(self): """ if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): return False - return True # optional_bias_add = ( - - # is_op("nn.bias_add")(dense, is_constant()) | dense - # ) + return True class TanhParams(LutActivationParams): @@ -1624,12 +1621,10 @@ def check_weights_fc(weights): if not check_batch_size(self.ifm): return False # Check input shape - if len(self.ifm.shape) < 2: + if not len(self.ifm.shape) == 2: return False - if not np.all(np.array(self.ifm.shape[:-1]) == 1): - # As we reshape the ifm from - # [n0, n1, ... , n_m] to [n0 * n1 * ... * n_{m-1}, n_m] - # all except the last dims need to be 1. 
+ # Check output shape + if not len(self.ofm.shape) == 2: return False return True @@ -1638,7 +1633,7 @@ def qnn_fc_pattern(): dense = is_op("qnn.dense")( wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() ) - optional_bias_add = is_op("nn.bias_add")(dense, is_constant()) | dense + optional_bias_add = is_op("nn.bias_add")(dense, is_constant()) req = is_op("qnn.requantize")( dense | optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant() ) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 60ddf2dd1993..d57885d70d29 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -18,8 +18,6 @@ import pytest -from tvm.relay.analysis.analysis import extract_fused_functions - pytest.importorskip("ethosu.vela") import math @@ -2477,35 +2475,22 @@ def verify(ext_func): # check IFM ifm = op.args[0].checked_type - assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape) + assert ifm.shape == [1, 1] + list(ifm_shape) assert str(ifm.dtype) == dtype # check OFM ofm = op.checked_type - assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels] - # assert list(ofm.shape) == list(expected_ofm_shape) + assert ofm.shape == [1, 1] + list(ofm_channels) assert str(ofm.dtype) == dtype - # assert ofm.shape[3] == ofm_channels # check weights weights_ohwi = op.args[1].data.asnumpy() assert str(weights_ohwi.dtype) == dtype - assert weights_ohwi.shape[0] == ofm_channels - assert weights_ohwi.shape[1] == 1 - assert weights_ohwi.shape[2] == 1 - assert weights_ohwi.shape[3] == ifm_shape[1] + assert list(weights_ohwi) == [ofm_channels, 1, 1, ifm_shape[1]] # Check that scale_bias matches weight tensor assert list(op.args[2].checked_type.shape)[0] == ofm_channels - # expected_padding = infra.compute_padding_shape( - # ifm_shape, - # expected_ofm_shape, - # (0, 0, 0, 0), - # (1, 1), - # (1, 1), - # (1, 1), - # ) assert list(op.attrs.padding) == [0, 0, 0, 0] assert list(op.attrs.strides) == [1, 1] assert list(op.attrs.dilation) == [1, 1] From 82dc516da97663359e9d89801ed747786226fa2b Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 9 Mar 2022 11:21:56 +0000 Subject: [PATCH 8/9] [microNPU] Add support for TFLite FULLY_CONNECTED Fix assertion in legalization test. --- tests/python/contrib/test_ethosu/test_legalize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index d57885d70d29..a90f8f919ac2 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2475,12 +2475,12 @@ def verify(ext_func): # check IFM ifm = op.args[0].checked_type - assert ifm.shape == [1, 1] + list(ifm_shape) + assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape) assert str(ifm.dtype) == dtype # check OFM ofm = op.checked_type - assert ofm.shape == [1, 1] + list(ofm_channels) + assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels] assert str(ofm.dtype) == dtype # check weights From f9240218568d6cd257ea493d39f022cc5f0fc6c5 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 9 Mar 2022 12:50:33 +0000 Subject: [PATCH 9/9] [microNPU] Add support for TFLite FULLY_CONNECTED Address comments, fixing assertion on ifm and ofm shape. 
--- tests/python/contrib/test_ethosu/test_legalize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index a90f8f919ac2..710c3e8c8812 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2475,18 +2475,18 @@ def verify(ext_func): # check IFM ifm = op.args[0].checked_type - assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape) + assert list(ifm.shape) == [1, 1] + list(ifm_shape) assert str(ifm.dtype) == dtype # check OFM ofm = op.checked_type - assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels] + assert list(ofm.shape) == [1, 1, 1, ofm_channels] assert str(ofm.dtype) == dtype # check weights weights_ohwi = op.args[1].data.asnumpy() assert str(weights_ohwi.dtype) == dtype - assert list(weights_ohwi) == [ofm_channels, 1, 1, ifm_shape[1]] + assert list(weights_ohwi.shape) == [ofm_channels, 1, 1, ifm_shape[1]] # Check that scale_bias matches weight tensor assert list(op.args[2].checked_type.shape)[0] == ofm_channels
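
To make the transformation performed by FullyConnectedRewriter easier to follow, below is a minimal standalone NumPy sketch (illustrative only, not part of the patches) of the equivalence the legalization relies on: a fully connected layer over an (N, I) input is the same computation as a conv2d with a 1x1 kernel over a (1, 1, 1, I) NHWC input with OHWI weights of shape (O, 1, 1, I), which is exactly the reshape / expand_dims done in the rewriter. The channel counts are arbitrary example values, and quantization parameters, bias packing via vela_api.pack_biases and the clip activation are omitted.

import numpy as np

# Example channel counts only, chosen to mirror the test parametrization above.
ifm_channels, ofm_channels = 14, 32

x = np.random.randint(-128, 128, size=(1, ifm_channels), dtype=np.int32)
w_oi = np.random.randint(-128, 128, size=(ofm_channels, ifm_channels), dtype=np.int32)

# qnn.dense view: y[n, o] = sum_i x[n, i] * w[o, i]
fc_out = x @ w_oi.T                          # shape (1, ofm_channels)

# Legalized view: IFM reshaped to NHWC (1, 1, 1, I), weights expanded to
# OHWI (O, 1, 1, I) via the same np.expand_dims(..., axis=(1, 2)) used above.
ifm_nhwc = x.reshape(1, 1, 1, ifm_channels)
w_ohwi = np.expand_dims(w_oi, axis=(1, 2))   # shape (ofm_channels, 1, 1, ifm_channels)

# A 1x1 convolution over a 1x1 spatial extent reduces to a dot product over
# the input channels for every output channel.
conv_out = np.einsum("nhwi,ohwi->no", ifm_nhwc, w_ohwi)

assert np.array_equal(conv_out, fc_out)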