diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 3e69b409a3a9..b5fa9b4e961b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1447,6 +1447,82 @@ def callback(
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert an ethos-u.channel-pad composite function to the Relay concatenate operation."""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ChannelPadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        concat_args = list()
+        lut = relay.const([], dtype="int8")
+        # pad channels before
+        if params.ch_padding[0] > 0:
+            shape1 = list(params.ifm.shape)
+            shape1[3] = params.ch_padding[0].value
+            pad_channels = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape1,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity1 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity1)
+
+        identity2 = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+        )
+        concat_args.append(identity2)
+
+        # pad channels after
+        if params.ch_padding[1] > 0:
+            shape3 = list(params.ifm.shape)
+            shape3[3] = params.ch_padding[1].value
+            pad_channels3 = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape3,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity3 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels3,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity3)
+
+        return relay.op.concatenate(relay.Tuple(concat_args), axis=3)
+
+
 @util.create_npu_function_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1461,6 +1537,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
         rewriters = [
             PartitionedSplitRewriter(),
             SplitRewriter(),
+            ChannelPadRewriter(),
             Conv2DRewriter(),
             Conv2DTransposeRewriter(),
             DepthwiseConv2DRewriter(),
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 70ec1c12eb3d..e9c38d3d7775 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -143,7 +143,7 @@ class QDenseArgs(Enum):
     WEIGHTS_SCALE = 5
 
 
-class QPad2DArgs(Enum):
+class QPadArgs(Enum):
     """
     This is a helper enum to obtain the correct index
    of nn.pad arguments.
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 744c15987bbe..fb7e8398c393 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1940,15 +1940,15 @@ class PadParams:
     padding_bounds = [31, 31, 32, 32]
 
     def __init__(self, func_body: Call):
-        from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
 
         # there is no 'layout' attribute in nn.pad
         layout = "NHWC"
         self.ifm = TensorParams(
-            tensor=func_body.args[QPad2DArgs.IFM.value],
+            tensor=func_body.args[QPadArgs.IFM.value],
             layout=layout,
             scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
-            zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
         )
 
         self.padding = self.extract_padding(func_body)
@@ -1956,7 +1956,7 @@ def __init__(self, func_body: Call):
             tensor=func_body,
             layout=layout,
             scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
-            zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
         )
 
     @staticmethod
@@ -1964,8 +1964,8 @@ def extract_padding(
         padding: relay.Call,
     ) -> Optional[Tuple[int, int, int, int]]:
         """
-        Here we check whether a separate padding operation can be rewritten
-        as NPU depthwise convolution. If the padding specified by the
+        Here we check whether a separate spatial-dimension padding operation can be
+        rewritten as NPU depthwise convolution. If the padding specified by the
         separate nn.pad operation is not supported by NPU depthwise convolution,
         None will be returned. This will cause the nn.pad not to be offloaded to NPU.
         """
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to an ethos-u.channel-pad composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.channel-pad"
+    # The ethos-u.channel-pad composite function will be transformed
+    # to the Relay concatenate operation.
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
+
+        # there is no 'layout' attribute in nn.pad
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            tensor=func_body.args[QPadArgs.IFM.value],
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+        self.ch_padding = self.extract_ch_padding(func_body)
+        self.ofm = TensorParams(
+            tensor=func_body,
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def extract_ch_padding(
+        padding: relay.Call,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Here we check whether a separate channel-dimension padding operation can be
+        rewritten as a Relay concatenate operation. If the padding specified by the
+        separate nn.pad operation is not supported by the NPU, None will be returned.
+        This will cause the nn.pad not to be offloaded to the NPU.
+ """ + pad_width = padding.attrs["pad_width"] + if len(pad_width) != 4: + return None + if ( + list(pad_width[0]) != [0, 0] + or list(pad_width[1]) != [0, 0] + or list(pad_width[2]) != [0, 0] + ): + return None + return [ + pad_width[3][0], + pad_width[3][1], + ] + + def is_valid(self): + """ + This function checks whether pad has compatible attributes + with the Relay concatenate operation + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_batch_size(self.ifm): + return False + if not self.ch_padding: + return False + if not check_dimensions(self.ifm) or not check_dimensions(self.ofm): + return False + return True + + def pad_pattern(): """Create pattern for pad""" pattern = is_op("nn.pad")(wildcard(), is_constant()) @@ -2066,6 +2139,11 @@ def softmax_pattern() -> tvm.relay.dataflow_pattern.DFPattern: @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ + ( + ChannelPadParams.composite_name, + pad_pattern(), + lambda pat: ChannelPadParams(pat).is_valid(), + ), ( QnnConv2DParams.composite_name, qnn_conv2d_pattern(), diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index c621155827a9..71e7e029c148 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -475,7 +475,9 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False): return conv_args -def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]): +def compute_ofm_shape( + ifm_shape, padding, kernel_shape, strides, dilation=[1, 1], channel_padding=[0, 0] +): assert len(strides) == 2 assert len(dilation) == 2 assert len(kernel_shape) == 2 @@ -492,7 +494,7 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1] elif padding.lower() == "same": h = math.ceil(ifm_shape[1] / strides[0]) w = math.ceil(ifm_shape[2] / strides[1]) - ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]] + ofm_shape = [ifm_shape[0], h, w, ifm_shape[3] + channel_padding[0] + channel_padding[1]] return ofm_shape diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 86f64d648320..c496ae21b3a0 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -281,6 +281,29 @@ def pad2d(x): infra.compare_tvm_with_tflite(pad2d, [ifm_shape], "ethos-u55-256") +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) +@pytest.mark.parametrize("channel_padding", [(0, 1), (1, 1), (5, 2)]) +@pytest.mark.parametrize("const_value", [0, 5, 125, -5]) +def test_tflite_separate_channel_pad( + ifm_shape, + channel_padding, + const_value, +): + np.random.seed(0) + + @tf.function + def concat_func(x): + x = tf.pad( + x, + [[0, 0], [0, 0], [0, 0], [channel_padding[0], channel_padding[1]]], + "CONSTANT", + const_value, + ) + return x + + infra.compare_tvm_with_tflite(concat_func, [ifm_shape], "ethos-u55-256", enable_cascader=False) + + @pytest.mark.parametrize( "accel_type", ACCEL_TYPES, diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index d1d0befcee70..f87b2da98312 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ 
@@ -34,6 +34,7 @@
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.frontend.tflite import get_pad_value
+from tvm.relay.expr_functor import ExprVisitor
 
 from . import infra
 
@@ -462,6 +463,118 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the graph recursively and checks whether it contains the 'nn.pad' op.
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            Recursively visits the graph and reports whether the 'nn.pad' op is present.
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [0, 0], [0, 0], [padding_ch[0], padding_ch[1]]],
+                    "CONSTANT",
+                )
+                # HWIO
+                weight_shape = [
+                    kernel_shape[0],
+                    kernel_shape[1],
+                    ifm_shape[3] + padding_ch[0] + padding_ch[1],
+                    3,
+                ]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        assert ArePadOnGraph().are_pad_on_graph(ext_func.body)
+
+    conv2d_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
 @pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -760,7 +873,7 @@ def verify(ext_func):
             ethosu.PadParams.composite_name,
             ethosu.pad_pattern(),
             lambda pat: ethosu.PadParams(pat).is_valid(),
-        )
+        ),
     ]
 
     tflite_graph = create_tflite_graph()
@@ -781,6 +894,132 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize("channel_padding", [(0, 1), (1, 1), (5, 2)])
+@pytest.mark.parametrize("const_value", [0, 5, 125, -5])
+def test_tflite_separate_channel_padding_legalize(ifm_shape, channel_padding, const_value):
+    dtype = "int8"
+    padding = (0, 0, 0, 0)
+
+    class AreConcatenateOnGraph(ExprVisitor):
+        """
+        Visits the graph recursively and checks whether it contains the 'concatenate' op.
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "concatenate":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_concatenate_on_graph(self, subgraph) -> bool:
+            """
+            Recursively visits the graph and reports whether the 'concatenate' op is present.
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                return tf.pad(
+                    x,
+                    [
+                        [0, 0],
+                        [padding[0], padding[2]],
+                        [padding[1], padding[3]],
+                        [channel_padding[0], channel_padding[1]],
+                    ],
+                    "CONSTANT",
+                    const_value,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func, channel_padding):
+        op = ext_func.body
+
+        pad_before = 0
+        pad_after = 0
+        if channel_padding[0] == 0 and channel_padding[1] > 0:
+            pad_after = ext_func.body.args[0][1].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][0].args[0].checked_type
+        if channel_padding[0] > 0 and channel_padding[1] == 0:
+            pad_before = ext_func.body.args[0][0].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][1].args[0].checked_type
+        if channel_padding[0] > 0 and channel_padding[1] > 0:
+            pad_before = ext_func.body.args[0][0].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][1].args[0].checked_type
+            pad_after = ext_func.body.args[0][2].args[0].checked_type.shape[3]
+
+        # check IFM
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ifm_shape[3]
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = list(ifm_shape)
+        expected_ofm_shape[3] = channel_padding[0] + ifm_shape[3] + channel_padding[1]
+        assert list(ofm.shape) == expected_ofm_shape
+        assert str(ofm.dtype) == dtype
+
+        # check padding
+        assert [pad_before, pad_after] == list(channel_padding)
+
+        # check that the legalized Relay contains the 'concatenate' op
+        assert AreConcatenateOnGraph().are_concatenate_on_graph(ext_func.body)
+
+    pad_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, pad_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.ChannelPadRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"], channel_padding)
+
+
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(
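For reference, the rewrite performed by `ChannelPadRewriter` relies on the fact that padding only the channel dimension of a quantized NHWC tensor with its zero point is equivalent to concatenating zero-point-filled constant blocks before and after the tensor along axis 3. The snippet below is a minimal NumPy sketch of that equivalence, not part of the patch; the shape, zero point, and padding amounts are illustrative values only.

```python
import numpy as np

# Illustrative values only: an NHWC int8 tensor, a quantization zero point,
# and (pad_before, pad_after) channels of padding.
ifm = np.random.randint(-128, 127, size=(1, 4, 4, 3), dtype=np.int8)
zero_point = -3
pad_before, pad_after = 2, 1

# What a channel-only nn.pad / tf.pad computes.
padded = np.pad(
    ifm,
    [(0, 0), (0, 0), (0, 0), (pad_before, pad_after)],
    mode="constant",
    constant_values=zero_point,
)

# What the rewriter emits: zero-point-filled constant blocks concatenated
# with the input along the channel axis.
before = np.full(ifm.shape[:3] + (pad_before,), zero_point, dtype=ifm.dtype)
after = np.full(ifm.shape[:3] + (pad_after,), zero_point, dtype=ifm.dtype)
concatenated = np.concatenate([before, ifm, after], axis=3)

assert np.array_equal(padded, concatenated)
```

In the patch itself each concatenated block is wrapped in an `ethosu_identity` call so that the input and output quantization parameters travel with it through the NPU lowering.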