From a6cfffaab5856fb4f12a20fd6d1d4d48013c4867 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Thu, 26 May 2022 11:32:06 +0000
Subject: [PATCH] [microNPU] Optimize separate padding operation for conv2d

Optimizes the case where padding appears as a separate nn.pad operation
followed by a qnn.conv2d. Where possible, the nn.pad is partitioned and
offloaded together with the qnn.conv2d operation rather than on its own.
If the padding cannot be combined, the two operations fall back to being
considered separately.

Change-Id: I9125195386abdcc1d17ec612dc0a4cd6474d637a
---
 python/tvm/relay/op/contrib/ethosu.py        |  66 +++++-
 tests/python/contrib/test_ethosu/infra.py    |  11 +-
 .../contrib/test_ethosu/test_codegen.py      |  68 +++++-
 .../contrib/test_ethosu/test_legalize.py     | 216 ++++++++++++++++++
 4 files changed, 349 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index dfdc0c82fb1e..806bf6dce2e8 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -201,6 +201,8 @@ def __init__(self, func_body: tvm.relay.Function):
         from tvm.relay.backend.contrib.ethosu.util import RequantArgs
 
         activation = None
+        separate_padding = None
+
         if str(func_body.op) in self.activation_map.keys():
             activation = func_body
             requantize_op = activation.args[0]
@@ -208,8 +210,11 @@ def __init__(self, func_body: tvm.relay.Function):
             requantize_op = func_body
         bias_add = requantize_op.args[0]
         qnn_conv2d = bias_add.args[0]
+        if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad":
+            separate_padding = qnn_conv2d.args[0]
         data_layout = qnn_conv2d.attrs.data_layout
         self.kernel_layout = qnn_conv2d.attrs.kernel_layout
+
         # We consider the weights & biases as params as it should be a Constant
         self.weights = TensorParams(
             qnn_conv2d.args[QConv2DArgs.WEIGHTS.value],
@@ -224,8 +229,11 @@ def __init__(self, func_body: tvm.relay.Function):
             requantize_op.args[RequantArgs.IFM_SCALE.value],
             requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
         )
+        ifm_tensor = (
+            separate_padding.args[0] if separate_padding else qnn_conv2d.args[QConv2DArgs.IFM.value]
+        )
         self.ifm = TensorParams(
-            qnn_conv2d.args[QConv2DArgs.IFM.value],
+            ifm_tensor,
             data_layout,
             qnn_conv2d.args[QConv2DArgs.IFM_SCALE.value],
             qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value],
@@ -237,7 +245,10 @@ def __init__(self, func_body: tvm.relay.Function):
             requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
         )
         attrs = qnn_conv2d.attrs
-        self.padding = attrs.padding
+
+        pad_value = int(qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value].data.asnumpy())
+        self.padding = self.extract_padding(attrs.padding, separate_padding, pad_value)
+
         self.strides = attrs.strides
         self.dilation = attrs.dilation
         self.activation = activation
@@ -250,6 +261,37 @@ def __init__(self, func_body: tvm.relay.Function):
         if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]:
             self.is_depthwise = True
 
+    @staticmethod
+    def extract_padding(
+        operator_padding: Tuple[int, int, int, int],
+        separate_padding: relay.Call,
+        pad_value: int,
+    ) -> Optional[Tuple[int, int, int, int]]:
+        """
+        Convolution operations can sometimes have padding represented as a separate
+        padding operation before the convolution operation itself. Here we can check
+        whether these representations can be combined into a single padding attribute
+        as part of the NPU convolution itself. If the padding specified by the separate
+        nn.pad operation is not supported, None will be returned. This will cause the
+        nn.pad to be offloaded separately.
+        """
+        if separate_padding is None:
+            return operator_padding
+        if pad_value != int(separate_padding.args[1].data.asnumpy()):
+            return None
+        pad_width = separate_padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]:
+            return None
+        top, left, bottom, right = operator_padding
+        return [
+            top + pad_width[1][0],
+            left + pad_width[2][0],
+            bottom + pad_width[1][1],
+            right + pad_width[2][1],
+        ]
+
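
    # Illustrative aside, not part of the diff: a minimal worked example of the
    # combination rule implemented by extract_padding above, assuming an NHWC
    # layout and hypothetical values, where the qnn.conv2d contributes
    # (1, 1, 1, 1) of its own padding and the separate nn.pad adds one extra
    # row at the bottom and one extra column on the right.
    operator_padding = (1, 1, 1, 1)               # (top, left, bottom, right)
    pad_width = ((0, 0), (0, 1), (0, 1), (0, 0))  # (N, H, W, C) x (before, after)
    top, left, bottom, right = operator_padding
    combined = [
        top + pad_width[1][0],     # 1 + 0 -> 1
        left + pad_width[2][0],    # 1 + 0 -> 1
        bottom + pad_width[1][1],  # 1 + 1 -> 2
        right + pad_width[2][1],   # 1 + 1 -> 2
    ]
    assert combined == [1, 1, 2, 2]
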
     def is_valid(self) -> bool:
         """
         This function checks whether QnnConv2D has compatible attributes with the NPU
@@ -267,7 +309,7 @@ def is_valid(self) -> bool:
             return False
         if not check_dilation(self.dilation):
             return False
-        if not check_padding(self.padding, self.padding_bounds):
+        if not self.padding or not check_padding(self.padding, self.padding_bounds):
             return False
         legal_groups = [1, self.ofm.shape[3]]
         if self.groups not in legal_groups:
@@ -437,7 +479,7 @@ def is_valid(self):
             return False
         if not check_dilation(self.dilation):
             return False
-        if not check_padding(self.padding, self.padding_bounds):
+        if not self.padding or not check_padding(self.padding, self.padding_bounds):
             return False
         if self.weights.layout != "HWOI":
             return False
@@ -453,8 +495,14 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     This function creates the pattern for qnn.conv2D with optional fused RELU activation.
     """
+    optional_pad = is_op("nn.pad")(wildcard(), is_constant())
     qnn_conv2d = is_op("qnn.conv2d")(
-        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        optional_pad | wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
     ).has_attr({"kernel_layout": "HWIO"})
     bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
     req = is_op("qnn.requantize")(
@@ -468,8 +516,14 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation.
     """
+    optional_pad = is_op("nn.pad")(wildcard(), is_constant())
     qnn_conv2d = is_op("qnn.conv2d")(
-        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        optional_pad | wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
     ).has_attr({"kernel_layout": "HWOI"})
     bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
     req = is_op("qnn.requantize")(
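
Both patterns above rely on the same idiom: the alternative "optional_pad | wildcard()" lets the qnn.conv2d input be either a constant-value nn.pad or anything else, so padded and unpadded graphs match a single pattern. A minimal standalone sketch of that idiom (illustrative only, not part of the diff):

    from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

    optional_pad = is_op("nn.pad")(wildcard(), is_constant())
    pattern = is_op("qnn.conv2d")(
        optional_pad | wildcard(),  # IFM, with or without a preceding nn.pad
        is_constant(),  # weights
        is_constant(),  # input zero point
        is_constant(),  # kernel zero point
        is_constant(),  # input scale
        is_constant(),  # kernel scale
    )
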
""" + optional_pad = is_op("nn.pad")(wildcard(), is_constant()) qnn_conv2d = is_op("qnn.conv2d")( - wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + optional_pad | wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), ).has_attr({"kernel_layout": "HWOI"}) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 20bd12945f8f..580c14af07c8 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -419,10 +419,17 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1] assert len(strides) == 2 assert len(dilation) == 2 assert len(kernel_shape) == 2 - if padding.lower() == "valid": + if isinstance(padding, tuple): + h = ( + ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0] + padding[0] + padding[2] + ) // strides[0] + w = ( + ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1] + padding[1] + padding[3] + ) // strides[1] + elif padding.lower() == "valid": h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0]) w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1]) - if padding.lower() == "same": + elif padding.lower() == "same": h = math.ceil(ifm_shape[1] / strides[0]) w = math.ceil(ifm_shape[2] / strides[1]) ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 1e8d307b33ea..9256b32e5566 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -68,13 +68,43 @@ def conv2d(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type) +def test_tflite_conv2d_with_separate_pad(): + np.random.seed(0) + + ifm_shape = (1, 55, 34, 3) + kernel_shape = (3, 2) + strides = (1, 1) + dilation = (2, 1) + padding = (0, 0, 1, 1) + + @tf.function + def conv2d(x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + dilations=dilation, + ) + + infra.compare_tvm_with_tflite(conv2d, [ifm_shape], "ethos-u55-256") + + @pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 2), (1, 27, 42, 3)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @@ -116,7 +146,7 @@ def conv2d_double(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op2 = tf.nn.relu(op2) return op2 @@ -152,7 +182,7 @@ def conv_invalid_scale(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op = tf.nn.relu(op) return op @@ -187,13 +217,43 @@ def depthwise_conv2d(x): op = tf.nn.depthwise_conv2d( x, weight, strides=tf_strides, padding=padding, dilations=dilation ) - if activation_function: + if activation_function == "RELU": op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type) +def 
+def test_tflite_depthwise_conv2d_with_separate_pad():
+    np.random.seed(0)
+
+    ifm_shape = (1, 23, 32, 7)
+    kernel_shape = (1, 2)
+    strides = (3, 2)
+    dilation = (1, 1)
+    padding = (0, 0, 1, 1)
+
+    @tf.function
+    def depthwise_conv2d(x):
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.pad(
+            x,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        return tf.nn.depthwise_conv2d(
+            op,
+            weight,
+            strides=tf_strides,
+            padding="VALID",
+            dilations=dilation,
+        )
+
+    infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256")
+
+
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 2dd5eff91373..3f8b5f7d5b58 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -347,6 +347,114 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding = (0, 0, 1, 1)
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+                    "CONSTANT",
+                )
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 3
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        assert list(op.attrs.padding) == list(padding)
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
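
    # Illustrative aside, not part of the diff: the checks that decide whether
    # a separate nn.pad can be folded into the NPU convolution, mirrored as a
    # hypothetical helper (can_fuse_pad is not a real API). When any check
    # fails, extract_padding returns None and the nn.pad is offloaded on its
    # own instead.
    def can_fuse_pad(pad_width, pad_value, ifm_zero_point):
        # The value being padded must equal the IFM zero point, i.e. zero in
        # the quantized domain, since that is what the NPU pads with.
        if pad_value != ifm_zero_point:
            return False
        # Only 4D NHWC pad_width values that leave the batch and channel axes
        # untouched can be expressed as a conv2d padding attribute.
        if len(pad_width) != 4:
            return False
        if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]:
            return False
        return True

    assert can_fuse_pad(((0, 0), (1, 1), (1, 1), (0, 0)), 0, 0)      # spatial-only
    assert not can_fuse_pad(((0, 0), (1, 1), (1, 1), (0, 0)), 5, 0)  # wrong pad value
    assert not can_fuse_pad(((1, 0), (0, 0), (0, 0), (0, 0)), 0, 0)  # pads batch axis
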
 @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
 @pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -458,6 +566,114 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_depthwise_conv2d_with_separate_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 23, 32, 7)
+    kernel_shape = (1, 2)
+    strides = (3, 2)
+    dilation = (1, 1)
+    padding = (0, 0, 1, 1)
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+                    "CONSTANT",
+                )
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.depthwise_conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        assert list(op.attrs.padding) == list(padding)
+        assert op.attrs.ofm_channels == ofm_channels
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+
+    depthwise_pattern_table = [
+        (
+            ethosu.QnnDepthwiseConv2DParams.composite_name,
+            ethosu.qnn_depthwise_conv2d_pattern(),
+            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, depthwise_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
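
    # Illustrative aside, not part of the diff: the two tests above drive
    # partitioning by hand through a pattern table so each rewriter can be
    # checked in isolation. Outside the test suite, the same pad-folding is
    # reached through the standard entry point. A minimal sketch, assuming mod
    # and params come from relay.frontend.from_tflite as in the tests:
    from tvm.relay.op.contrib.ethosu import partition_for_ethosu

    # nn.pad + qnn.conv2d pairs that pass the padding checks are partitioned
    # into a single Ethos-U composite function here.
    mod = partition_for_ethosu(mod, params)
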
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(