From 5440cc71bd3bca821d5ad854b3d1978bf94c1138 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan
Date: Mon, 15 Nov 2021 10:53:45 +0000
Subject: [PATCH 1/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases

* The current conv2d tests compare the conv2d operator against TVM's
  execution of the default schedule of conv2d as defined in TOPI, which is
  not bit-exact with the TFLite runtime's implementation. Therefore a
  tolerance of "1" in the quantized 8-bit domain is used.
* Converts the current conv2d tests to use TensorFlow APIs to create test
  cases for conv2d and compare against the TFLite runtime.

Place pytest import skip above imports to satisfy dependencies

Rename activation function to be consistent with other tests

Fix linting failure on attempted import of the removed file
---
 .../relay/backend/contrib/ethosu/__init__.py  |   1 -
 .../relay/backend/contrib/ethosu/errors.py    |  35 ---
 .../relay/backend/contrib/ethosu/legalize.py  |   2 -
 .../contrib/test_ethosu/relay_ir_builder.py   | 295 ------------------
 .../contrib/test_ethosu/test_codegen.py       | 284 ++++++++++-------
 .../contrib/test_ethosu/test_legalize.py      | 235 +++++++-------
 6 files changed, 294 insertions(+), 558 deletions(-)
 delete mode 100644 python/tvm/relay/backend/contrib/ethosu/errors.py
 delete mode 100644 tests/python/contrib/test_ethosu/relay_ir_builder.py

diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py
index ed04c202d8af..c4948d54dc26 100644
--- a/python/tvm/relay/backend/contrib/ethosu/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py
@@ -18,7 +18,6 @@
 from . import util
 from . import legalize
 from . import preprocess
-from . import errors
 from . import codegen
 from . import vela_api
 from . import tir_to_cs_translator
diff --git a/python/tvm/relay/backend/contrib/ethosu/errors.py b/python/tvm/relay/backend/contrib/ethosu/errors.py
deleted file mode 100644
index 65f3711838be..000000000000
--- a/python/tvm/relay/backend/contrib/ethosu/errors.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=super-init-not-called -"""This module defines all error types associated with the Arm(R) Ethos(TM)-U NPU code generator.""" - - -class EthosUCodegenError(Exception): - """Base class for all exceptions related to code generation""" - - def __init__(self, data): - self.message = "EthosUCodegenError:" + data - - def __str__(self): - return self.message - - -class UnsupportedLayout(EthosUCodegenError): - """Raised when unsupported layout is encountered during code generation.""" - - def __init__(self, layout): - super().__init__(f"Unsupported Layout {layout}") diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 5613d613f984..46dbef1ef98e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -209,8 +209,6 @@ def callback( channels_map = { "NHWC": 3, } - if str(params.ofm.layout) not in channels_map.keys(): - raise UnsupportedLayout(str(params.ofm.layout)) kernel_size_map = { "HWIO": params.weights.shape[0:2], "OHWI": params.weights.shape[1:3], diff --git a/tests/python/contrib/test_ethosu/relay_ir_builder.py b/tests/python/contrib/test_ethosu/relay_ir_builder.py deleted file mode 100644 index 6169a3e46520..000000000000 --- a/tests/python/contrib/test_ethosu/relay_ir_builder.py +++ /dev/null @@ -1,295 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Helper module to build relay operations for testing""" - -from pathlib import Path -import numpy as np -import math - -import tvm -from tvm import relay -from tvm.relay.op.contrib import get_pattern_table -from tvm.relay import qnn -from tvm.relay.backend.contrib.ethosu.util import get_range_for_dtype_str - - -class TensorType: - """A data structure to capture tensor parameters""" - - def __init__(self): - self.shape = None - self.dtype = None - self.zp = None - self.sc = None - self.layout = None - - def get_dim_size(self, dim): - for idx, char in enumerate(self.layout): - if dim == char: - return self.shape[idx] - return None - - def get_dim_index(self, dim): - for idx, char in enumerate(self.layout): - if dim == char: - return idx - return None - - -class QnnConv2DParams: - """A data structure to capture relay.qnn.op.conv2D parameters""" - - def __init__(self, dtype): - self.ifm = TensorType() - self.ofm = TensorType() - self.kernel = TensorType() - - # default values - self.ifm.dtype = dtype - self.ifm.layout = "NHWC" - ifm_min, ifm_max = get_range_for_dtype_str(self.ifm.dtype) - self.ifm.zp = relay.const(np.random.randint(ifm_min, ifm_max), "int32") - self.ifm.sc = relay.const(np.random.random() * 2, "float32") - self.kernel.dtype = dtype - self.kernel.layout = "HWIO" - kernel_min, kernel_max = get_range_for_dtype_str(self.kernel.dtype) - self.kernel.zp = relay.const(np.random.randint(kernel_min, kernel_max), "int32") - self.kernel.sc = relay.const(np.random.random() * 2, "float32") - self.ofm.layout = "NHWC" - self.ofm.dtype = dtype - ofm_min, ofm_max = get_range_for_dtype_str(self.ofm.dtype) - self.ofm.zp = relay.const(np.random.randint(ofm_min, ofm_max), "int32") - self.ofm.sc = relay.const(np.random.random() * 2, "float32") - self.dilation = (1, 1) - - self.strides = None - self.pad = None - self.activation = "NONE" - self.clip_min = 0 - self.clip_max = 0 - - def update_output_qnn_params( - self, input_dtype="uint8", kernel_dtype="uint8", output_dtype="uint8" - ): - _, dtype_max = get_range_for_dtype_str(input_dtype) - input_max = self.ifm.sc.data.asnumpy() * (dtype_max - self.ifm.zp.data.asnumpy()) - input_min = -self.ifm.sc.data.asnumpy() * self.ifm.zp.data.asnumpy() - _, dtype_max = get_range_for_dtype_str(kernel_dtype) - kernel_max = np.max( - self.kernel.sc.data.asnumpy() * (dtype_max - self.kernel.zp.data.asnumpy()) - ) - kernel_min = np.min(-self.kernel.sc.data.asnumpy() * self.kernel.zp.data.asnumpy()) - kernel_h = self.kernel.get_dim_size("H") - kernel_w = self.kernel.get_dim_size("W") - channels = self.kernel.get_dim_size("I") - output_limits = [ - kernel_max * kernel_h * kernel_w * channels * input_max, - kernel_min * kernel_h * kernel_w * channels * input_max, - kernel_min * kernel_h * kernel_w * channels * input_min, - kernel_max * kernel_h * kernel_w * channels * input_min, - ] - output_max = max(output_limits) - output_min = min(output_limits) - dtype_min, dtype_max = get_range_for_dtype_str(input_dtype) - self.ofm.sc = relay.const((output_max - output_min) / (dtype_max - dtype_min), "float32") - self.ofm.zp = relay.const(-int(output_min / self.ofm.sc.data.asnumpy()), "int32") - - -class PoolingParams: - """A data structure to capture relay.op.max_pool2d / - relay.op.avg_pool2d parameters - """ - - def __init__(self, dtype): - self.type = None - self.size = None - self.strides = None - self.pad = None - self.layout = None - self.ifm = TensorType() - self.ofm = TensorType() - - # default values - self.ifm.dtype = dtype - self.ifm.layout = "NHWC" - self.ifm.zp = 
relay.const(np.random.randint(0, 255), "int32") - self.ifm.sc = relay.const(np.random.random() * 2, "float32") - self.ofm.zp = relay.const(np.random.randint(0, 255), "int32") - self.ofm.sc = relay.const(np.random.random() * 2, "float32") - self.ofm.dtype = dtype - self.dilation = (1, 1) - - -class AddParams: - """A data structure to capture relay.qnn.op.add parameters""" - - def __init__(self, dtype): - self.ifm0 = TensorType() - self.ifm1 = TensorType() - self.ofm = TensorType() - - # default values - self.ifm0.dtype = dtype - self.ifm0.zp = relay.const(np.random.randint(0, 255), "int32") - self.ifm0.sc = relay.const(np.random.random() * 2, "float32") - self.ifm1.dtype = dtype - self.ifm1.zp = relay.const(np.random.randint(0, 255), "int32") - self.ifm1.sc = relay.const(np.random.random() * 2, "float32") - self.update_output_qnn_params() - self.ofm.dtype = dtype - - def update_output_qnn_params(self): - ti = np.iinfo(self.ifm0.dtype) - dtype_min, dtype_max = int(ti.min), int(ti.max) - input1_max = self.ifm0.sc.data.asnumpy() * (dtype_max - self.ifm0.zp.data.asnumpy()) - input1_min = (dtype_min - self.ifm0.sc.data.asnumpy()) * self.ifm0.zp.data.asnumpy() - input2_max = self.ifm1.sc.data.asnumpy() * (dtype_max - self.ifm1.zp.data.asnumpy()) - input2_min = (dtype_min - self.ifm1.sc.data.asnumpy()) * self.ifm1.zp.data.asnumpy() - output_max = input1_max + input2_max - output_min = input1_min + input2_min - self.ofm.sc = relay.const((output_max - output_min) / dtype_max, "float32") - self.ofm.zp = relay.const( - (dtype_min - int(output_min / self.ofm.sc.data.asnumpy())), "int32" - ) - - -def get_pad_value(data, kernel, stride): - """Get the pad tuple of value for SAME padding""" - - out = int(math.ceil(float(data) / float(stride))) - pad = max(0, (out - 1) * stride + kernel - data) - pad_before = pad // 2 - pad_after = pad - pad_before - return pad_before, pad_after - - -def create_qnn_conv2d(qnn_conv2d_params, ifm_expr): - """Create a relay.Expr of relay.qnn.conv2D given the parameters""" - v_params = list() - params = { - "kernel_size": [ - qnn_conv2d_params.kernel.get_dim_size("H"), - qnn_conv2d_params.kernel.get_dim_size("W"), - ], - "strides": [qnn_conv2d_params.strides[0], qnn_conv2d_params.strides[1]], - "dilation": [qnn_conv2d_params.dilation[0], qnn_conv2d_params.dilation[1]], - "padding": [0, 0, 0, 0], - "data_layout": qnn_conv2d_params.ifm.layout, - } - dilated_kernel_h = ( - qnn_conv2d_params.dilation[0] * (qnn_conv2d_params.kernel.get_dim_size("H") - 1) + 1 - ) - dilated_kernel_w = ( - qnn_conv2d_params.dilation[1] * (qnn_conv2d_params.kernel.get_dim_size("W") - 1) + 1 - ) - if qnn_conv2d_params.pad == "SAME": - pad_top, pad_bottom = get_pad_value( - qnn_conv2d_params.ifm.get_dim_size("H"), dilated_kernel_h, qnn_conv2d_params.strides[0] - ) - pad_left, pad_right = get_pad_value( - qnn_conv2d_params.ifm.get_dim_size("W"), dilated_kernel_w, qnn_conv2d_params.strides[1] - ) - do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0) - if do_pad: - params["padding"] = [pad_top, pad_left, pad_bottom, pad_right] - qnn_conv2d_params.pad = params["padding"] - params["input_zero_point"] = qnn_conv2d_params.ifm.zp - params["kernel_zero_point"] = qnn_conv2d_params.kernel.zp - params["out_dtype"] = "int32" - params["input_scale"] = qnn_conv2d_params.ifm.sc - params["kernel_scale"] = qnn_conv2d_params.kernel.sc - params["channels"] = int(qnn_conv2d_params.kernel.get_dim_size("O")) - params["kernel_layout"] = qnn_conv2d_params.kernel.layout - k_shape = 
qnn_conv2d_params.kernel.shape - k_dtype = qnn_conv2d_params.kernel.dtype - w = tvm.nd.array( - np.random.randint( - np.iinfo(k_dtype).min, high=np.iinfo(k_dtype).max, size=k_shape, dtype=k_dtype - ) - ) - weight_expr = relay.const(w, k_dtype) - v_params.append(w) - qnn_conv2d_expr = qnn.op.conv2d(ifm_expr, weight_expr, **params) - b = tvm.nd.array( - np.random.randint( - 0, high=10, size=(qnn_conv2d_params.kernel.get_dim_size("O")), dtype="int32" - ) - ) - v_params.append(b) - bias_expr = relay.const(b, "int32") - bias = relay.nn.bias_add( - qnn_conv2d_expr, bias_expr, axis=qnn_conv2d_params.ifm.get_dim_index("C") - ) - bias_scale = relay.const( - qnn_conv2d_params.ifm.sc.data.asnumpy() * qnn_conv2d_params.kernel.sc.data.asnumpy(), - "float32", - ) - req_expr = relay.qnn.op.requantize( - bias, - bias_scale, # input zero scale - relay.const(0, "int32"), # input zero point - qnn_conv2d_params.ofm.sc, # output zero scale - qnn_conv2d_params.ofm.zp, # output zero point - out_dtype=qnn_conv2d_params.ofm.dtype, - ) - if qnn_conv2d_params.activation != "NONE": - assert qnn_conv2d_params.activation == "CLIP" - clip_expr = relay.clip(req_expr, qnn_conv2d_params.clip_min, qnn_conv2d_params.clip_max) - return clip_expr, v_params - - return req_expr, v_params - - -def create_pool2d(pooling_params, ifm_expr): - """Create a relay pooling operation""" - assert pooling_params.ifm.layout == "NHWC" - params = { - "pool_size": (pooling_params.size[0], pooling_params.size[1]), - "strides": (pooling_params.strides[0], pooling_params.strides[1]), - "padding": [0, 0], - "layout": "NHWC", - } - if pooling_params.pad == "SAME": - pad_top, pad_bottom = get_pad_value( - pooling_params.ifm.shape[1], pooling_params.size[0], pooling_params.strides[0] - ) - pad_left, pad_right = get_pad_value( - pooling_params.ifm.shape[2], pooling_params.size[1], pooling_params.strides[1] - ) - params["padding"] = [pad_top, pad_left, pad_bottom, pad_right] - if pooling_params.type == "MAX": - out = relay.op.nn.max_pool2d(ifm_expr, **params) - else: - assert pooling_params.type == "AVG" - out = relay.op.cast(ifm_expr, dtype="int32") - out = relay.op.nn.avg_pool2d(out, **params) - out = relay.op.cast(out, dtype=pooling_params.ofm.dtype) - return out - - -def create_qnn_add(ifm0_expr, ifm1_expr, add_params): - add = relay.qnn.op.add( - lhs=ifm0_expr, - rhs=ifm1_expr, - lhs_scale=add_params.ifm0.sc, - lhs_zero_point=add_params.ifm0.zp, - rhs_scale=add_params.ifm1.sc, - rhs_zero_point=add_params.ifm1.zp, - output_scale=add_params.ofm.sc, - output_zero_point=add_params.ofm.zp, - ) - return add diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index e20ab41cb576..b3b3a205f791 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -18,19 +18,18 @@ import pytest pytest.importorskip("ethosu.vela") + import numpy as np import tflite.Model import tvm import tensorflow as tf from tvm import relay -from tvm.relay.backend.contrib.ethosu import util from tvm.relay.op.contrib.ethosu import partition_for_ethosu -from tests.python.relay.aot.aot_test_utils import generate_ref_data -from . import relay_ir_builder from . 
import infra + ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] @@ -48,122 +47,189 @@ def get_shape_expr(in_expr, out_expr): return shape -@pytest.mark.parametrize( - "accel_type", - ACCEL_TYPES, -) -def test_ethosu_conv2d(accel_type): - def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32) - c1_params.kernel.sc = relay.const(np.random.rand(32) * 2, "float32") - c1_params.strides = (1, 1) - c1_params.pad = "VALID" - c1_params.update_output_qnn_params( - input_tensor_dtype, input_tensor_dtype, input_tensor_dtype - ) - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) +@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)]) +@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("activation", None, "RELU") +def test_ethosu_conv2d_single( + ifm_shape, + kernel_shape, + strides, + dilation, + padding, + accel_type, + activation, +): + dtype = "int8" - f = relay.Function([input0], c1) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c1_params] + def create_tflite_graph_single(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + # Use tf.nn API to create the model + op = tf.nn.conv2d( + x, + filters=tf.constant( + np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + dtype=tf.float32, + ), + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilation, + ) + if activation: + op = tf.nn.relu(op) + return op - def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) - c1_params.strides = (2, 2) - c1_params.pad = "VALID" - c1_params.update_output_qnn_params( - input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) ) - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) - c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c2_params.ifm.shape = c1_params.ofm.shape - c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) - c2_params.strides = (1, 1) - c2_params.pad = "SAME" - c2_params.update_output_qnn_params() - c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) - c2_params.ofm.shape = get_shape_expr(input0, c2) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] - f = relay.Function([input0], c2) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c2_params, c1_params] + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = 
[tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model - def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_dtype): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) - c1_params.strides = (2, 2) - c1_params.pad = "VALID" - c1_params.activation = "CLIP" - c1_params.clip_min = 90 - c1_params.clip_max = 110 - c1_params.update_output_qnn_params( - input_tensor_dtype, input_tensor_dtype, input_tensor_dtype - ) - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) + tflite_graph_single = create_tflite_graph_single() + tflite_model_single = tflite.Model.Model.GetRootAsModel(tflite_graph_single, 0) + + relay_module_single, params_single = relay.frontend.from_tflite( + tflite_model_single, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module_single, params_single) - c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c2_params.ifm.shape = c1_params.ofm.shape - c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) - c2_params.strides = (1, 1) - c2_params.pad = "SAME" - c2_params.update_output_qnn_params() - c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) - c2_params.ofm.shape = get_shape_expr(input0, c2) + # Generate reference data + input_data_single, output_data_single = infra.generate_ref_data_tflite(tflite_graph_single) - f = relay.Function([input0], c2) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c2_params, c1_params] - - test_cases = [ - (create_graph_single, ["input", (1, 300, 300, 3), "int8"]), - (create_graph_double, ["input", (1, 128, 256, 4), "int8"]), - (create_graph_activation, ["input", (1, 64, 100, 4), "int8"]), - ] - np.random.seed(42) - for test_case in test_cases: - relay_module, conv_params = test_case[0](*test_case[1]) - input_tensor, input_shape, input_dtype = test_case[1] - mod = partition_for_ethosu(relay_module) - - # Generate reference data - in_min, in_max = util.get_range_for_dtype_str(input_dtype) - input_data = { - input_tensor: np.random.randint( - in_min, high=in_max, size=input_shape, dtype=input_dtype - ) - } - output_data = generate_ref_data(relay_module, input_data) - - compiled_models = infra.build_source( - mod, input_data, output_data, accel_type, output_tolerance=1 + compiled_models_single = infra.build_source( + mod, input_data_single, output_data_single, accel_type + ) + + # Single offload module + imported_modules_single = compiled_models_single[0].executor_factory.lib.imported_modules + assert len(imported_modules_single) == 2 + ethosu_module_single = imported_modules_single[0] + + # Verify C source generated + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = bytes.fromhex(get_cs(ethosu_module_single)) + + infra.print_payload(cmms) + infra.verify_source(compiled_models_single, accel_type) + + +@pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 3), (1, 27, 42, 3)]) +@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) +@pytest.mark.parametrize("strides, 
dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("activation", None, "RELU") +def test_ethosu_conv2d_double( + ifm_shape, + kernel_shape, + strides, + dilation, + padding, + accel_type, + activation, +): + dtype = "int8" + + def create_tflite_graph_double(): + class Model(tf.Module): + @tf.function + def tf_function_double(self, x): + # Use tf.nn API to create the model with two convolutions + op = tf.nn.conv2d( + x, + filters=tf.constant( + np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + dtype=tf.float32, + ), + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilation, + ) + # Second convolution + op2 = tf.nn.conv2d( + op, + filters=tf.constant( + np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + dtype=tf.float32, + ), + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilation, + ) + if activation: + op2 = tf.nn.relu(op2) + return op2 + + model = Model() + concrete_func = model.tf_function_double.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) ) - # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] - - # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - infra.print_payload(cmms) - infra.verify_source(compiled_models, accel_type) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph_double = create_tflite_graph_double() + tflite_model_double = tflite.Model.Model.GetRootAsModel(tflite_graph_double, 0) + + relay_module_double, params_double = relay.frontend.from_tflite( + tflite_model_double, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod_double = partition_for_ethosu(relay_module_double, params_double) + + # Generate reference data + input_data_double, output_data_double = infra.generate_ref_data_tflite(tflite_graph_double) + compiled_models_double = infra.build_source( + mod_double, input_data_double, output_data_double, accel_type + ) + + # Double offload module + imported_modules_double = compiled_models_double[0].executor_factory.lib.imported_modules + assert len(imported_modules_double) == 2 + ethosu_module_double = imported_modules_double[0] + + # Verify C source generated + get_cs_double = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms_double = bytes.fromhex(get_cs_double(ethosu_module_double)) + + infra.print_payload(cmms_double) + infra.verify_source(compiled_models_double, accel_type) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 64bdae5c1b8b..8c353d1f7890 100644 --- 
a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -16,9 +16,11 @@ # under the License. # pylint: disable=invalid-name, unused-argument +import math import pytest pytest.importorskip("ethosu.vela") + import numpy as np import tensorflow as tf import tflite.Model @@ -30,7 +32,6 @@ from tvm.relay.op.contrib import ethosu from tvm.relay.build_module import bind_params_by_name -from . import relay_ir_builder from . import infra @@ -221,6 +222,17 @@ def get_shape_expr(in_expr, out_expr): return shape +def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation): + if padding.lower() == "valid": + h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0]) + w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1]) + if padding.lower() == "same": + h = math.ceil(ifm_shape[1] / strides[0]) + w = math.ceil(ifm_shape[2] / strides[1]) + ofm_shape = [ifm_shape[0], h, w, kernel_shape[3]] + return ofm_shape + + INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = { "HWIO": [1, 2, 3, 0], "HWOI": [1, 2, 0, 3], @@ -228,128 +240,119 @@ def get_shape_expr(in_expr, out_expr): } -def test_ethosu_conv2d_legalize(): - def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32) - c1_params.strides = (1, 1) - c1_params.pad = "VALID" - c1_params.activation = "CLIP" - c1_params.clip_min = 23 - c1_params.clip_max = 180 - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) - - f = relay.Function([input0], c1) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c1_params] - - def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) - c1_params.strides = (2, 2) - c1_params.pad = "VALID" - c1_params.activation = "CLIP" - c1_params.clip_min = 10 - c1_params.clip_max = 240 - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) - - c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c2_params.ifm.shape = c1_params.ofm.shape - c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) - c2_params.strides = (1, 1) - c2_params.pad = "SAME" - c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) - c2_params.ofm.shape = get_shape_expr(input0, c2) - - f = relay.Function([input0], c2) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c2_params, c1_params] +@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)]) +@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("activation", [None, "RELU"]) +def test_tflite_conv_2d_legalize(ifm_shape, kernel_shape, padding, strides, dilation, activation): + dtype = "int8" - def verify_tensor(tensor_type, expr): - assert list(tensor_type.shape) == 
list(expr.checked_type.shape) - assert str(tensor_type.dtype) == str(expr.checked_type.dtype) + def create_tflite_graph_single(): + class Model(tf.Module): + @tf.function + def tf_function(self, input_shape): + op = tf.nn.conv2d( + input_shape, + filters=tf.constant( + np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + dtype=tf.float32, + ), + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilation, + ) + if activation: + op = tf.nn.relu(op) + return op - def verify_linear(ext_func, conv2d_params): + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): op = ext_func.body - for param in conv2d_params: - verify_tensor(param.ifm, op.args[0]) - verify_tensor(param.ofm, op) - - # This will be in OHWI layout - weights_ohwi = op.args[1].data.asnumpy() - weights_layout = str(param.kernel.layout) - weights = np.transpose(weights_ohwi, INVERSE_LAYOUT_TRANSFORM_OHWI_MAP[weights_layout]) - assert weights.shape == param.kernel.shape - assert weights.dtype == param.kernel.dtype - - assert list(op.args[2].checked_type.shape)[0] == weights_ohwi.shape[0] - - assert float(op.attrs.ifm_scale) == float(param.ifm.sc.data.asnumpy()) - assert int(op.attrs.ifm_zero_point) == int(param.ifm.zp.data.asnumpy()) - assert int(op.attrs.weight_zero_point) == int(param.kernel.zp.data.asnumpy()) - assert float(op.attrs.ofm_scale) == float(param.ofm.sc.data.asnumpy()) - assert int(op.attrs.ofm_zero_point) == int(param.ofm.zp.data.asnumpy()) - assert int(op.attrs.ofm_channels) == int(weights_ohwi.shape[0]) - assert list(op.attrs.padding) == list(param.pad) - assert list(op.attrs.strides) == list(param.strides) - assert list(op.attrs.dilation) == list(param.dilation) - assert str(op.attrs.activation) == str(param.activation) - assert int(op.attrs.clip_min) == int(param.clip_min) - assert int(op.attrs.clip_max) == int(param.clip_max) - op = op.args[0] + ofm_channels = op.attrs.ofm_channels - test_cases = [ - (create_graph_single, ["input", (1, 299, 299, 3), "uint8"]), - (create_graph_double, ["input", (1, 128, 256, 4), "uint8"]), - ] - for test_case in test_cases: - mod, conv_params = test_case[0](*test_case[1]) - mod = ethosu.partition_for_ethosu(mod) - mod = legalize.LegalizeConv2D()(mod) - verify_linear(mod["tvmgen_default_ethos_u_main_0"], conv_params) - - -def test_ethosu_conv2d_legalize_errors(): - def create_graph_single_unsupported_ifm_layout( - input_tensor_name, input_tensor_shape, input_tensor_dtype - ): - c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) - c1_params.ifm.shape = input_tensor_shape - c1_params.ifm.layout = "NCHW" - c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[1], 32) - c1_params.strides = (1, 1) - c1_params.pad = "VALID" - c1_params.activation = "CLIP" - c1_params.clip_min = 23 - c1_params.clip_max = 180 - input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, 
dtype=c1_params.ifm.dtype) - c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) - c1_params.ofm.shape = get_shape_expr(input0, c1) - - f = relay.Function([input0], c1) - mod = tvm.IRModule() - mod["main"] = f - return mod, [c1_params] + # check IFM + ifm = op.args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # check OFM + ofm = op.checked_type + expected_ofm_shape = compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation) + assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == kernel_shape[0] + assert weights_ohwi.shape[2] == kernel_shape[1] + assert weights_ohwi.shape[3] == 3 + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels - test_cases = [ - (create_graph_single_unsupported_ifm_layout, ["input", (1, 3, 299, 299), "uint8"]), + expected_padding = infra.compute_padding_shape( + ifm_shape, + expected_ofm_shape, + padding, + (kernel_shape[0], kernel_shape[1]), + strides, + dilation, + ) + assert list(op.attrs.padding) == list(expected_padding) + assert list(op.attrs.strides) == list(strides) + assert list(op.attrs.dilation) == list(dilation) + if activation == "RELU": + assert str(op.attrs.activation) == "CLIP" + + conv2d_pattern_table = [ + ( + ethosu.QnnConv2DParams.composite_name, + ethosu.qnn_conv2d_pattern(), + lambda pat: ethosu.QnnConv2DParams(pat).is_valid(), + ) ] - for test_case in test_cases: - mod, conv_params = test_case[0](*test_case[1]) - mod = ethosu.partition_for_ethosu(mod) - with pytest.raises( - tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" - ): - mod = legalize.LegalizeConv2D()(mod) + tflite_graph = create_tflite_graph_single() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, conv_params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], conv_params) + mod = partition_ethosu_by_table(mod, conv2d_pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + + verify(mod["tvmgen_default_ethosu_main_0"]) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) From 9e69c19e235244f4a9d948475e2ae003961bf46b Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 1 Dec 2021 12:59:26 +0000 Subject: [PATCH 2/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases Update ordering of imports --- tests/python/contrib/test_ethosu/test_legalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 8c353d1f7890..42fcc3b651d3 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -16,11 +16,11 @@ # under the License. 
 # pylint: disable=invalid-name, unused-argument
-import math
 import pytest
 
 pytest.importorskip("ethosu.vela")
 
+import math
 import numpy as np
 import tensorflow as tf
 import tflite.Model

From 8d1894e2ca79f8ad4e9a319e87d282f106fa03d8 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan
Date: Wed, 1 Dec 2021 13:53:57 +0000
Subject: [PATCH 3/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases

Remove unused import
---
 python/tvm/relay/backend/contrib/ethosu/legalize.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 46dbef1ef98e..bb34a4ba0f13 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -30,7 +30,6 @@
 from tvm.relay.dataflow_pattern import rewrite
 from tvm.relay.dataflow_pattern import CallPattern
 from tvm.relay.backend.contrib.ethosu import op as ethosu_ops  # type: ignore
-from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout  # type: ignore
 from tvm.relay.backend.contrib.ethosu import vela_api
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.op.contrib import ethosu as ethosu_patterns  # type: ignore

From 9bd23d52c4e62b11175b3d06771cf2362668de62 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan
Date: Thu, 2 Dec 2021 14:55:02 +0000
Subject: [PATCH 4/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases

Address comments, use infra.py function to compute ofm_shape
---
 .../python/contrib/test_ethosu/test_legalize.py | 16 +++-------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 42fcc3b651d3..0c535a26990d 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -222,17 +222,6 @@ def get_shape_expr(in_expr, out_expr):
     return shape
 
 
-def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation):
-    if padding.lower() == "valid":
-        h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0])
-        w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1])
-    if padding.lower() == "same":
-        h = math.ceil(ifm_shape[1] / strides[0])
-        w = math.ceil(ifm_shape[2] / strides[1])
-    ofm_shape = [ifm_shape[0], h, w, kernel_shape[3]]
-    return ofm_shape
-
-
 INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = {
     "HWIO": [1, 2, 3, 0],
     "HWOI": [1, 2, 0, 3],
@@ -245,7 +234,7 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation):
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("activation", [None, "RELU"])
-def test_tflite_conv_2d_legalize(ifm_shape, kernel_shape, padding, strides, dilation, activation):
+def test_tflite_conv2d_legalize(ifm_shape, kernel_shape, padding, strides, dilation, activation):
     dtype = "int8"
 
     def create_tflite_graph_single():
@@ -298,7 +287,8 @@ def verify(ext_func):
 
         # check OFM
         ofm = op.checked_type
-        expected_ofm_shape = compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation)
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation)
         assert list(ofm.shape) == list(expected_ofm_shape)
         assert str(ofm.dtype) == dtype
         assert ofm.shape[3] == ofm_channels

From 5e23089ae1480b5ea6e005936fccb24d72f18572 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan
Date: Thu, 2 Dec 2021 14:58:04 +0000 Subject: [PATCH 5/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases Linting --- tests/python/contrib/test_ethosu/test_legalize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 0c535a26990d..292676f2b996 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -288,7 +288,8 @@ def verify(ext_func): # check OFM ofm = op.checked_type expected_ofm_shape = infra.compute_ofm_shape( - ifm_shape, padding, kernel_shape, strides, dilation) + ifm_shape, padding, kernel_shape, strides, dilation + ) assert list(ofm.shape) == list(expected_ofm_shape) assert str(ofm.dtype) == dtype assert ofm.shape[3] == ofm_channels From 67062482c4b8dc44b0ec85941049c8b969f65212 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 8 Dec 2021 14:29:50 +0000 Subject: [PATCH 6/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases Missing square brackets in parametrization, missing underscores. --- .../contrib/test_ethosu/test_codegen.py | 85 ++++++++++--------- .../contrib/test_ethosu/test_legalize.py | 6 +- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 420320def5f3..2e3e53631c7c 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -52,7 +52,7 @@ def get_shape_expr(in_expr, out_expr): @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) -@pytest.mark.parametrize("activation", None, "RELU") +@pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_single( ifm_shape, kernel_shape, @@ -69,15 +69,15 @@ class Model(tf.Module): @tf.function def tf_function(self, x): # Use tf.nn API to create the model + tf_strides = [1, strides[0], strides[1], 1] op = tf.nn.conv2d( x, filters=tf.constant( - np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), dtype=tf.float32, ), - strides=strides, + strides=tf_strides, padding=padding, - data_format="NHWC", dilations=dilation, ) if activation: @@ -104,34 +104,35 @@ def representative_dataset(): tflite_model = converter.convert() return tflite_model - tflite_graph_single = create_tflite_graph_single() - tflite_model_single = tflite.Model.Model.GetRootAsModel(tflite_graph_single, 0) + tflite_graph = create_tflite_graph_single() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - relay_module_single, params_single = relay.frontend.from_tflite( - tflite_model_single, + relay_module, params = relay.frontend.from_tflite( + tflite_model, shape_dict={"input": ifm_shape}, dtype_dict={"input": dtype}, ) - mod = partition_for_ethosu(relay_module_single, params_single) + mod = partition_for_ethosu(relay_module, params) # Generate reference data - input_data_single, output_data_single = infra.generate_ref_data_tflite(tflite_graph_single) + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) - compiled_models_single = infra.build_source( - mod, input_data_single, output_data_single, accel_type + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, ) - # 
Single offload module - imported_modules_single = compiled_models_single[0].executor_factory.lib.imported_modules - assert len(imported_modules_single) == 2 - ethosu_module_single = imported_modules_single[0] - - # Verify C source generated - get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") - cmms = bytes.fromhex(get_cs(ethosu_module_single)) + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] + # Verify generated C source + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) - infra.verify_source(compiled_models_single, accel_type) + infra.verify_source(compiled_models, accel_type) @pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 3), (1, 27, 42, 3)]) @@ -139,7 +140,7 @@ def representative_dataset(): @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) -@pytest.mark.parametrize("activation", None, "RELU") +@pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_double( ifm_shape, kernel_shape, @@ -159,7 +160,7 @@ def tf_function_double(self, x): op = tf.nn.conv2d( x, filters=tf.constant( - np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), dtype=tf.float32, ), strides=strides, @@ -203,33 +204,35 @@ def representative_dataset(): tflite_model = converter.convert() return tflite_model - tflite_graph_double = create_tflite_graph_double() - tflite_model_double = tflite.Model.Model.GetRootAsModel(tflite_graph_double, 0) + tflite_graph = create_tflite_graph_double() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - relay_module_double, params_double = relay.frontend.from_tflite( - tflite_model_double, + relay_module, params = relay.frontend.from_tflite( + tflite_model, shape_dict={"input": ifm_shape}, dtype_dict={"input": dtype}, ) - mod_double = partition_for_ethosu(relay_module_double, params_double) + mod = partition_for_ethosu(relay_module, params) # Generate reference data - input_data_double, output_data_double = infra.generate_ref_data_tflite(tflite_graph_double) - compiled_models_double = infra.build_source( - mod_double, input_data_double, output_data_double, accel_type - ) + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) - # Double offload module - imported_modules_double = compiled_models_double[0].executor_factory.lib.imported_modules - assert len(imported_modules_double) == 2 - ethosu_module_double = imported_modules_double[0] + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) - # Verify C source generated - get_cs_double = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") - cmms_double = bytes.fromhex(get_cs_double(ethosu_module_double)) + # Assumes only two runtime.Modules are created -- i.e. 
single offload module + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] - infra.print_payload(cmms_double) - infra.verify_source(compiled_models_double, accel_type) + # Verify generated C source + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index d9bc1b56dfda..e8a38718aaf2 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -340,11 +340,11 @@ def verify(ext_func): mod["main"] = bind_params_by_name(mod["main"], conv_params) mod = partition_ethosu_by_table(mod, conv2d_pattern_table) - mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( - legalize.EthosUConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"] ) - verify(mod["tvmgen_default_ethosu_main_0"]) + verify(mod["tvmgen_default_ethos_u_main_0"]) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) From a950e356be7e6eda20a53286e47690986d1c4337 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Wed, 8 Dec 2021 19:14:42 +0000 Subject: [PATCH 7/7] [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases Fix imports --- tests/python/contrib/test_ethosu/test_codegen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 4e05856be91e..0707ec27ca27 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -32,6 +32,7 @@ from tvm.relay.backend.contrib.ethosu import preprocess from tvm.relay.op.contrib.ethosu import partition_for_ethosu +from tests.python.relay.aot.aot_test_utils import generate_ref_data from . import infra
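
A note on the reference-data helper the new tests depend on: both test_ethosu_conv2d_single and test_ethosu_conv2d_double call infra.generate_ref_data_tflite(tflite_graph), which lives in tests/python/contrib/test_ethosu/infra.py and is not shown in this series. A minimal sketch of what such a helper can look like, assuming it simply feeds random integer input to the TFLite interpreter and records the result (the actual infra.py implementation may differ in naming and details):

import numpy as np
import tensorflow as tf


def generate_ref_data_tflite(tflite_model_buffer):
    """Run a serialized TFLite model on random input; return (inputs, outputs)."""
    interpreter = tf.lite.Interpreter(model_content=tflite_model_buffer)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Draw random values covering the full range of each (int8) input tensor.
    input_data = {}
    for detail in input_details:
        info = np.iinfo(detail["dtype"])
        input_data[detail["name"]] = np.random.randint(
            info.min, high=info.max, size=detail["shape"], dtype=detail["dtype"]
        )
        interpreter.set_tensor(detail["index"], input_data[detail["name"]])

    interpreter.invoke()

    output_data = {
        detail["name"]: interpreter.get_tensor(detail["index"]) for detail in output_details
    }
    return input_data, output_data

Comparing against the interpreter's own output is what lets the new tests drop the output_tolerance=1 workaround: the expected values now come from the TFLite runtime itself, the very implementation the commit message notes the TOPI-based reference was not bit-exact with.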
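The legalization test likewise checks op.attrs.padding against infra.compute_padding_shape, another helper defined in infra.py rather than in these patches. A sketch of the expected computation, under the assumption that it mirrors the get_pad_value arithmetic from the removed relay_ir_builder.py and its [top, left, bottom, right] attribute ordering (the real infra.py helper may differ):

def compute_padding_shape(ifm_shape, ofm_shape, padding, kernel_shape, strides, dilation):
    """Return [top, left, bottom, right] padding for a SAME/VALID NHWC conv2d."""
    if padding.lower() == "valid":
        return [0, 0, 0, 0]

    def pad_amount(in_size, out_size, kernel, stride, dil):
        # SAME padding: pad the shortfall, with the extra pixel on the far side.
        dilated_kernel = (kernel - 1) * dil + 1
        total = max((out_size - 1) * stride + dilated_kernel - in_size, 0)
        return total // 2, total - total // 2

    pad_top, pad_bottom = pad_amount(
        ifm_shape[1], ofm_shape[1], kernel_shape[0], strides[0], dilation[0]
    )
    pad_left, pad_right = pad_amount(
        ifm_shape[2], ofm_shape[2], kernel_shape[1], strides[1], dilation[1]
    )
    return [pad_top, pad_left, pad_bottom, pad_right]

This reproduces, per spatial dimension, the pad_before = pad // 2, pad_after = pad - pad_before split used by the deleted relay_ir_builder.get_pad_value.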