diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index b16374028b1a..eb029dac2cbc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1269,6 +1269,101 @@ def __call__(self, *args, **kwargs): pass +class Resize2dRewriter(DFPatternCallback): + """ + Convert ethos-u.resize2d composite function to an equivalent operation that + performs the relevant upsampling operation. + + Case 1: No upsampling (upscale factor of 1): + Identity. + Case 2: Nearest neighbor upsampling: + 1x1 pooling with 2x2 nearest neighbor upsampling. + Case 3: Bilinear upsampling: + 2x2 average pool with 2x2 nearest neighbor upsampling. + """ + + def __init__(self): + super().__init__(require_type=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.Resize2dParams.composite_name}) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = ethosu_patterns.Resize2dParams(post.op.body) + params.ifm.tensor = post.args[0] + + lut = relay.const([], "int8") + ifm_shape = params.ifm.shape + in_channels = ifm_shape[-1] + reduced_op = params.ifm.tensor + current_size = np.array(ifm_shape[1:3]) + output_size = np.array(params.size) + + if (current_size == output_size).all(): + return ethosu_ops.ethosu_identity( + reduced_op, + lut, + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + ) + + padding = [0, 0, 0, 0] + rounding_mode = "TFL" + pool_shape = [1, 1] + if params.method == "linear": + pool_shape = [2, 2] + rounding_mode = "NATURAL" + if params.coordinate_transformation_mode == "asymmetric": + # Use SAME padding. 
+ ypad = Resize2dRewriter.get_required_padding(ifm_shape[1]) + xpad = Resize2dRewriter.get_required_padding(ifm_shape[2]) + padding = [ypad // 2, xpad // 2, (ypad + 1) // 2, (xpad + 1) // 2] + + return ethosu_ops.ethosu_pooling( + ifm=reduced_op, + lut=lut, + pooling_type="AVG", + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + pool_shape=pool_shape, + ofm_channels=in_channels, + strides=[1, 1], + padding=padding, + upscale="NEAREST", + rounding_mode=rounding_mode, + ) + + @staticmethod + def get_required_padding(input_size: int, pool_size: int = 2) -> int: + """Gets the amount of padding needed to achieve + 'SAME' padding for a given axis.""" + needed_input = (input_size - 1) + pool_size + total_padding = max(0, needed_input - input_size) + return total_padding + + +@ir.transform.module_pass(opt_level=1) +class LegalizeResize2d: + """This is the pass that wraps Resize2dRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(Resize2dRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -1299,6 +1394,7 @@ def transform_module( mod = LegalizeConcat()(mod) mod = LegalizeSigmoid()(mod) mod = LegalizeRequantize()(mod) + mod = LegalizeResize2d()(mod) mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py index 8446b0c2e4ad..958125630324 100644 --- 
a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -288,7 +288,10 @@ def match_ethosu_binary_elementwise(output_tensor, device_config): pad = binary_elementwise.op.input_tensors[0] if pad.op.name != "ethosu_pad": return None - convert_to_nhwc = pad.op.input_tensors[0] + upscale = pad.op.input_tensors[0] + if upscale.op.name != "ethosu_upscale": + return None + convert_to_nhwc = upscale.op.input_tensors[0] if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": return None read = convert_to_nhwc.op.input_tensors[0] @@ -297,7 +300,10 @@ def match_ethosu_binary_elementwise(output_tensor, device_config): pad2 = binary_elementwise.op.input_tensors[1] if pad2.op.name != "ethosu_pad": return None - convert_to_nhwc2 = pad2.op.input_tensors[0] + upscale2 = pad2.op.input_tensors[0] + if upscale2.op.name != "ethosu_upscale": + return None + convert_to_nhwc2 = upscale2.op.input_tensors[0] if convert_to_nhwc2.op.name != "ethosu_convert_to_nhwc": return None read2 = convert_to_nhwc2.op.input_tensors[0] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index ea2290ef1e5f..040d1e26fba9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -267,7 +267,10 @@ def match_ethosu_conv2d(output_tensor, device_config): pad = conv2d.op.input_tensors[0] if pad.op.name != "ethosu_pad": return None - convert_to_nhwc = pad.op.input_tensors[0] + upscale = pad.op.input_tensors[0] + if upscale.op.name != "ethosu_upscale": + return None + convert_to_nhwc = upscale.op.input_tensors[0] if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": return None read = convert_to_nhwc.op.input_tensors[0] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index 
ff09662cc14a..79d4f05f9cf2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -267,7 +267,10 @@ def match_ethosu_depthwise_conv2d(output_tensor, device_config): pad = depthwise2d.op.input_tensors[0] if pad.op.name != "ethosu_pad": return None - convert_to_nhwc = pad.op.input_tensors[0] + upscale = pad.op.input_tensors[0] + if upscale.op.name != "ethosu_upscale": + return None + convert_to_nhwc = upscale.op.input_tensors[0] if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": return None read = convert_to_nhwc.op.input_tensors[0] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index 14aa67bb37d3..9d9eaf0ed444 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -277,6 +277,38 @@ def pad_compute(tensor: te.Tensor, padding: tuple) -> te.Tensor: ) +def upscale_compute(tensor: te.Tensor, upscale_factor: int) -> te.Tensor: + """Apply upscaling to an NHWC tensor. + + Parameters + ---------- + tensor : te.Tensor + The tensor to upscale. + upscale_factor : int + The factor by which to apply upscaling. + + Returns + ------- + te.Tensor + The upscaled tensor. + + """ + shape = tensor.shape + + reason = f"The compiler only supports 2x2 upscaling, but factor was {upscale_factor}." 
+ assert upscale_factor in (1, 2), reason + new_shape = (shape[0], shape[1] * upscale_factor, shape[2] * upscale_factor, shape[3]) + + upscale_attrs = {"op": "ethosu_upscale"} + + return te.compute( + new_shape, + lambda nn, hh, ww, cc: tensor(nn, hh // upscale_factor, ww // upscale_factor, cc), + name="ethosu_upscale", + attrs=upscale_attrs, + ) + + def dma_ifm_compute( ifm: te.Tensor, layout: str, @@ -284,6 +316,7 @@ def dma_ifm_compute( scale: float, channels: int, padding: Tuple[int, int, int, int], + upscale_factor: Optional[int] = 1, ) -> te.Tensor: """A sequence of compute operators representing the DMA capabilities for an IFM. @@ -301,6 +334,8 @@ def dma_ifm_compute( The number of valid channels for the data. padding : tuple The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + upscale_factor : Optional[int] + The factor by which to apply upscaling. By default there will be no upscaling. Returns ------- @@ -310,7 +345,8 @@ def dma_ifm_compute( """ read_ifm = read_compute(ifm, zero_point, scale, layout=layout) convert_to_nhwc_ifm = convert_to_nhwc_compute(read_ifm, layout, channels) - return pad_compute(convert_to_nhwc_ifm, padding) + upscale_ifm = upscale_compute(convert_to_nhwc_ifm, upscale_factor) + return pad_compute(upscale_ifm, padding) def dma_ofm_compute( diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index aaf79e8a8c8d..f1b065cbcf17 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -109,9 +109,12 @@ def pooling_compute( padding = [int(v) for v in padding] stride_h, stride_w = [int(v) for v in strides] pool_shape_h, pool_shape_w = [int(v) for v in pool_shape] + upscale_factor = 2 if upscale != "NONE" else 1 # Compute operation for the IFM DMA pipeline - dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) + dmaed_ifm = dma_ifm_compute( 
+ ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding, upscale_factor + ) # Pooling compute operation ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1 @@ -228,7 +231,10 @@ def match_ethosu_pooling(output_tensor, device_config): pad = pool2d.op.input_tensors[0] if pad.op.name != "ethosu_pad": return None - convert_to_nhwc = pad.op.input_tensors[0] + upscale = pad.op.input_tensors[0] + if upscale.op.name != "ethosu_upscale": + return None + convert_to_nhwc = upscale.op.input_tensors[0] if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": return None read = convert_to_nhwc.op.input_tensors[0] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py index 68d1c603ad98..69f06be955cb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -205,7 +205,10 @@ def match_ethosu_unary_elementwise(output_tensor, device_config): pad = unary_elementwise.op.input_tensors[0] if pad.op.name != "ethosu_pad": return None - convert_to_nhwc = pad.op.input_tensors[0] + upscale = pad.op.input_tensors[0] + if upscale.op.name != "ethosu_upscale": + return None + convert_to_nhwc = upscale.op.input_tensors[0] if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": return None read = convert_to_nhwc.op.input_tensors[0] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 7670c5d2f7b6..9f82d7478265 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -76,6 +76,31 @@ def _visit(expr): ) +def get_upscale_params(stmt): + """Get the upscale parameters from a loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of an upscale loop nest. 
+ + Returns + ------- + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + """ + _, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + if isinstance(inner.value, tvm.tir.Call): + input_pointer = inner.value.args[1].buffer_var + else: + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var + return (input_pointer, output_pointer) + + def get_convert_to_nhwc_params(stmt): """Get the true number of channels from a convert_to_nhwc loop nest. @@ -264,6 +289,8 @@ def get_ifm_params(pointer, producers): """ pad = producers[pointer] serial_padding, input_pointer, _ = get_pad_params(pad) + upscale = producers[input_pointer] + input_pointer, _ = get_upscale_params(upscale) convert_to_nhwc = producers[input_pointer] in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) read = producers[input_pointer] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 010ae1b86448..e929caa2409b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -89,7 +89,7 @@ def get_pooling_params( padding=serial_padding, activation=serial_activation, rounding_mode=attrs["rounding_mode"], - upscale="NONE", + upscale=attrs["upscale"], ), output_pointer, replace_pointer, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 20a8ff85ee2f..dc458484ec16 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -304,7 +304,8 @@ def from_output(cls, out): convert_to_nhcwb16 = write.op.input_tensors[0] conv2d = convert_to_nhcwb16.op.input_tensors[0] pad = conv2d.op.input_tensors[0] - convert_to_nhwc = pad.op.input_tensors[0] + upscale = 
pad.op.input_tensors[0] + convert_to_nhwc = upscale.op.input_tensors[0] read = convert_to_nhwc.op.input_tensors[0] return cls(read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 77fbc3e8628d..c68210429285 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -443,7 +443,7 @@ def _create_npu_op_conv2d( _convert_clip_bounds(npu_conv2d_op) npu_conv2d_op.rounding_mode = _create_npu_rounding_mode(serial_2d_convolution.rounding_mode) - npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) + npu_conv2d_op.ifm_upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) accel_config = vela_api.get_accelerator_config() weights_shape_ohwi = [ npu_conv2d_op.ofm.shape.depth, @@ -506,7 +506,7 @@ def _create_npu_op_depthwise_conv2d(serial_2d_depthwise): npu_depthwise_conv2d_op.rounding_mode = _create_npu_rounding_mode( serial_2d_depthwise.rounding_mode ) - npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) + npu_depthwise_conv2d_op.ifm_upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) target_accel_config = vela_api.get_accelerator_config() block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_config) npu_depthwise_conv2d_op.block_config = block_config @@ -656,7 +656,7 @@ def _create_npu_resampling_mode( mode_map = { "NONE": vapi.NpuResamplingMode.NONE, "NEAREST": vapi.NpuResamplingMode.NEAREST, - "TRANSPOSE": vapi.NpuResamplingMode.TRANSPOSE, + "ZEROS": vapi.NpuResamplingMode.TRANSPOSE, } mode = str(mode.value) assert mode in mode_map.keys() @@ -737,7 +737,7 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling): _convert_clip_bounds(npu_pooling_op) npu_pooling_op.rounding_mode = 
_create_npu_rounding_mode(serial_pooling.rounding_mode) - npu_pooling_op.upscale = _create_npu_resampling_mode(serial_pooling.upscale) + npu_pooling_op.ifm_upscale = _create_npu_resampling_mode(serial_pooling.upscale) target_accel_config = vela_api.get_accelerator_config() block_config = vela_api.get_optimal_block_config(npu_pooling_op, target_accel_config) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 9ea1e2bb1fc3..72c83605ff04 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -17,8 +17,8 @@ # pylint: disable=ungrouped-imports, import-outside-toplevel """Arm(R) Ethos(TM)-U NPU supported operators.""" import functools - from typing import Dict, List, Tuple, Callable, Optional + import numpy as np # type: ignore import tvm # type: ignore @@ -83,9 +83,10 @@ def __init__(self, tensor, layout=None, scale=None, zero_point=None): self.q_params = vapi.NpuQuantization(1.0, 0) -def check_strides(strides: List[int]) -> bool: +def check_strides(strides: List[int], stride_range=None) -> bool: """This function checks whether strides are within the limits supported by the NPU""" - stride_range = (1, 3) + if stride_range is None: + stride_range = (1, 3) smin, smax = stride_range if not smax >= strides[0] >= smin: return False @@ -146,9 +147,10 @@ def check_batch_size(ifm: TensorParams): return ifm.shape[0] == 1 -def check_dilation(dilation: List[int]): +def check_dilation(dilation: List[int], dilation_range=None): """This function checks whether dilation is within the limits supported by the NPU""" - dilation_range = (1, 2) + if dilation_range is None: + dilation_range = (1, 2) dmin, dmax = dilation_range if not dmin <= dilation[0] <= dmax: return False @@ -1199,6 +1201,91 @@ def requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: ) +class Resize2dParams: + """ + This class will parse a call to ethos-u.resize2d composite function + and extract the parameter information. 
+ """ + + composite_name = "ethos-u.resize2d" + + def __init__(self, func_body: Call): + layout = "NHWC" + + resize_2d = func_body + in_var = func_body.args[0] + if ( + isinstance(resize_2d, tvm.relay.expr.Call) + and isinstance(resize_2d.op, tvm.ir.Op) + and resize_2d.op.name == "qnn.quantize" + ): + resize_2d = resize_2d.args[0] + in_var = in_var.args[0].args[0] + out_var = func_body + + self.ifm = TensorParams(in_var, layout=layout) + self.ofm = TensorParams(out_var, layout=layout) + + attrs = resize_2d.attrs + self.size = attrs.size + self.method = attrs.method + self.roi = attrs.roi + self.coordinate_transformation_mode = attrs.coordinate_transformation_mode + self.rounding_method = attrs.rounding_method + self.out_dtype = attrs.out_dtype + + def is_valid(self) -> bool: + """ + Checks whether image.resize2d has compatible attributes with HW. + """ + + def check_compatible_size(mode, method, upscale_size, ifm_size): + """Checking the provided upscale_size is compatible with the NPU. The NPU only + supports upsampling when the upsampling size is 2 * input_size, or when there is + no upsampling to be done, so check that this is the case. 
In the special case of + resize_bilinear with align_corners=True, the NPU only supports an upsampling + size of 2 * input_size - 1.""" + delta = 1 if mode == "align_corners" and method == "linear" else 0 + upscale_size = np.array(upscale_size) + ifm_size = np.array(ifm_size) + ifm_upscaled = ifm_size * 2 - delta + return (ifm_upscaled == upscale_size).all() or (ifm_size == upscale_size).all() + + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]): + return False + if len(self.ifm.shape) != 4 or len(self.ofm.shape) != 4: + return False + if list(float(x) for x in self.roi) != [0.0] * 4: + return False + if self.method not in ("nearest_neighbor", "linear"): + return False + if self.coordinate_transformation_mode not in ("asymmetric", "align_corners"): + return False + if not check_compatible_size( + self.coordinate_transformation_mode, + self.method, + self.size, + self.ifm.shape[1:3], + ): + return False + if self.rounding_method != "": + return False + if self.out_dtype and self.out_dtype != "int8": + return False + return True + + +def resize2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for image.resize2d. 
+ """ + dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + resize_2d = is_op("image.resize2d")(dequant).has_attr({"method": "linear"}) + quant = is_op("qnn.quantize")(resize_2d, is_constant(), is_constant()) + return quant | is_op("image.resize2d")(wildcard()).has_attr({"method": "nearest_neighbor"}) + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1289,6 +1376,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal requantize_pattern(), lambda pat: RequantizeParams(pat).is_valid(), ), + ( + Resize2dParams.composite_name, + resize2d_pattern(), + lambda pat: Resize2dParams(pat).is_valid(), + ), ] diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index 817575cc8d0d..eac576257721 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -75,6 +75,23 @@ Array EthosuInferKernelOutput(Array ifm_shape, String ifm_ return output_shape; } +Array EthosuInferUpscaledInput(Array ifm_shape, String ifm_layout) { + if (ifm_layout == "NHCWB16") { + ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3], ifm_shape[2] * 16}; + } + + const int scale_factor = 2; + Array new_ifm_shape = {ifm_shape[0], ifm_shape[1] * scale_factor, + ifm_shape[2] * scale_factor, ifm_shape[3]}; + + if (ifm_layout == "NHCWB16") { + int channel_bricks = 1 + (new_ifm_shape[3].as()->value - 1) / 16; + new_ifm_shape = {new_ifm_shape[0], new_ifm_shape[1], channel_bricks, new_ifm_shape[2], 16}; + } + + return new_ifm_shape; +} + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index cc489de6a49a..001b596c0949 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -59,6 +59,12 @@ Array EthosuInferKernelOutput(Array ifm_shape, String ifm_ 
IndexExpr ofm_channels, Array dilation, Array strides, Array padding); +/*! \brief Infer the upscaled Input Feature Map shape for operations that use upscaling. + * \param ifm_shape The shape of the Input Feature Map. + * \param ifm_layout The layout of the Input Feature Map. + */ +Array EthosuInferUpscaledInput(Array ifm_shape, String ifm_layout); + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index ca765a1581c4..dc16c072ebe2 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -139,9 +139,23 @@ bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& att return false; } - // Assign ofm type + const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; + if (upscale_methods.find(param->upscale) == upscale_methods.end()) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: Expected upsample method to be 'NONE', " + "'ZEROS' or 'NEAREST' but got " + << param->upscale); + return false; + } + + Array ifm_shape = ifm->shape; + if (param->upscale != "NONE") { + ifm_shape = EthosuInferUpscaledInput(ifm_shape, param->ifm_layout); + } + + // Assign ofm shape auto ofm_shape = EthosuInferKernelOutput( - ifm->shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels, + ifm_shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels, Array({1, 1}), param->strides, param->padding); reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype)); return true; diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 0b058a94fb60..52bc8ef69435 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -532,6 +532,7 @@ def make_ethosu_pooling( ifm_layout="NHWC", ofm_layout="NHWC", rounding_mode="TFL", + 
upscale="NONE", ): pooling = ethosu_ops.ethosu_pooling( ifm, @@ -549,7 +550,7 @@ def make_ethosu_pooling( clip_min=10 if activation == "CLIP" else 0, clip_max=100 if activation == "CLIP" else 0, rounding_mode=rounding_mode, - upscale="NONE", + upscale=upscale, ifm_layout=ifm_layout, ofm_layout=ofm_layout, ) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 1af8a60158fb..455be799a822 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1020,5 +1020,46 @@ def create_model(): _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape,size", + [[(1, 2, 2, 1), (4, 4)], [(1, 4, 7, 3), (8, 14)], [(1, 3, 5, 3), (3, 5)]], +) +def test_tflite_resize2d_nearest_neighbor(accel_type, ifm_shape, size): + align_corners = False + + @tf.function + def resize_model(x): + return tf.compat.v1.image.resize_nearest_neighbor( + x, size, align_corners=align_corners, half_pixel_centers=False + ) + + _compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape,size,align_corners", + [ + [(1, 2, 2, 1), (4, 4), False], + [(1, 4, 7, 3), (8, 14), False], + [(1, 2, 2, 1), (3, 3), True], + [(1, 4, 7, 3), (7, 13), True], + [(1, 3, 5, 3), (3, 5), False], + ], +) +def test_tflite_resize2d_bilinear(accel_type, ifm_shape, size, align_corners): + @tf.function + def resize_model(x): + return tf.compat.v1.image.resize_bilinear( + x, size, align_corners=align_corners, half_pixel_centers=False + ) + + # TODO(lhutton1) For now output is not bit exact with TFLite. + # This is because TFLite reference kernels are not being used. + # For this, TFLite will need upgrading to 2.6. 
+ _compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type, output_tolerance=1) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index f05fec9d124b..f77e15c9334e 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -1290,6 +1290,7 @@ def verify(ext_func): mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ "tvmgen_default_ethos_u_main_0" ] + verify(mod["tvmgen_default_ethos_u_main_0"]) def test_tflite_sigmoid_legalize(): @@ -1602,5 +1603,176 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize( + "ifm_shape,size", + [ + [(1, 2, 2, 1), (4, 4)], + [(1, 4, 7, 3), (8, 14)], + [(1, 3, 5, 3), (3, 5)], + ], +) +def test_tflite_resize2d_nearest_neighbor(ifm_shape, size): + align_corners = False + dtype = "int8" + + def create_tflite_graph(): + @tf.function + def resize_model(x): + return tf.compat.v1.image.resize_nearest_neighbor( + x, size, align_corners=align_corners, half_pixel_centers=False + ) + + concrete_func = resize_model.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + 
return mod + + def verify(ext_func): + op = ext_func.body + in_var = op.args[0] + + # check IFM + assert tuple(in_var.checked_type.shape) == ifm_shape + assert in_var.checked_type.dtype == dtype + + # check OFM + attrs = dict(op.attrs) + out_shape = (ifm_shape[0], size[0], size[1], ifm_shape[3]) + assert tuple(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + + # Check Op attributes + if size[0] == ifm_shape[1] and size[1] == ifm_shape[2]: + assert op.op.name == "contrib.ethosu.identity" + else: + assert attrs["pooling_type"] == "AVG" + assert attrs["upscale"] == "NEAREST" + + rewriter = legalize.Resize2dRewriter() + pattern_table = [ + ( + ethosu.Resize2dParams.composite_name, + ethosu.resize2d_pattern(), + lambda pat: ethosu.Resize2dParams(pat).is_valid(), + ), + ] + + mod = create_tflite_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape,size,align_corners", + [ + [(1, 2, 2, 1), (4, 4), False], + [(1, 4, 7, 3), (8, 14), False], + [(1, 2, 2, 1), (3, 3), True], + [(1, 4, 7, 3), (7, 13), True], + [(1, 3, 5, 3), (3, 5), False], + ], +) +def test_tflite_resize2d_bilinear(ifm_shape, size, align_corners): + dtype = "int8" + + def create_tflite_graph(): + @tf.function + def resize_model(x): + return tf.compat.v1.image.resize_bilinear( + x, size, align_corners=align_corners, half_pixel_centers=False + ) + + concrete_func = resize_model.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = 
representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + return mod + + def verify(ext_func): + op = ext_func.body + in_var = op.args[0] + + # check IFM + assert tuple(in_var.checked_type.shape) == ifm_shape + assert in_var.checked_type.dtype == dtype + + # check OFM + attrs = dict(op.attrs) + out_shape = (ifm_shape[0], size[0], size[1], ifm_shape[3]) + assert tuple(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + + # Check Op attributes + if size[0] == ifm_shape[1] and size[1] == ifm_shape[2]: + assert op.op.name == "contrib.ethosu.identity" + else: + assert attrs["pooling_type"] == "AVG" + assert attrs["upscale"] == "NEAREST" + + # Check padding + if align_corners: + assert list(attrs["padding"]) == [0, 0, 0, 0] + else: + assert list(attrs["padding"]) == [0, 0, 1, 1] + + rewriter = legalize.Resize2dRewriter() + pattern_table = [ + ( + ethosu.Resize2dParams.composite_name, + ethosu.resize2d_pattern(), + lambda pat: ethosu.Resize2dParams(pat).is_valid(), + ), + ] + + mod = create_tflite_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py index ee72ffa4cb99..c535498ee04d 100644 --- a/tests/python/contrib/test_ethosu/test_replace_pooling.py +++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py @@ -37,19 +37,29 @@ def 
_create_serial_pooling( padding, activation="NONE", rounding_mode="TFL", + upscale="NONE", ): + upscale_factor = 2 if upscale != "NONE" else 1 if ifm_layout == "NHWC": ifm_stride_c = 1 ifm_stride_w = ifm_shape[3] ifm_stride_h = ifm_shape[2] * ifm_shape[3] - ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1 - ofm_width = (ifm_shape[2] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 + ofm_height = ( + ifm_shape[1] * upscale_factor - pool_shape[0] + padding[0] + padding[2] + ) // strides[0] + 1 + ofm_width = ( + ifm_shape[2] * upscale_factor - pool_shape[1] + padding[1] + padding[3] + ) // strides[1] + 1 else: ifm_stride_w = 16 ifm_stride_c = 16 * ifm_shape[3] if ofm_channels >= 16 else 1 ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] - ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1 - ofm_width = (ifm_shape[3] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 + ofm_height = ( + ifm_shape[1] * upscale_factor - pool_shape[0] + padding[0] + padding[2] + ) // strides[0] + 1 + ofm_width = ( + ifm_shape[3] * upscale_factor - pool_shape[1] + padding[1] + padding[3] + ) // strides[1] + 1 if ofm_layout == "NHWC": ofm_stride_c = 1 @@ -117,18 +127,19 @@ def _create_serial_pooling( clip_max=100 if activation == "CLIP" else 0, ), rounding_mode=rounding_mode, - upscale="NONE", + upscale=upscale, ) @pytest.mark.parametrize( - "ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode", + "ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode, upscale", [ - ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"), - ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"), - ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"), - ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"), - ((1, 8, 9, 8), 8, "NHWC", "NHCWB16", "TFL"), + ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL", "NONE"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL", "NONE"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", 
"TRUNCATE", "ZEROS"), + ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL", "ZEROS"), + ((1, 8, 9, 8), 8, "NHWC", "NHCWB16", "TFL", "NEAREST"), + ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL", "NEAREST"), ], ) @pytest.mark.parametrize("pooling_type", ["AVG", "MAX"]) @@ -141,10 +152,19 @@ def test_pooling_single( pooling_type, activation, rounding_mode, + upscale, ): pool_shape = (3, 2) strides = (1, 2) - padding = (1, 1, 1, 0) + + # When strides are not (1, 1) it is possible to create invalid + # padding configurations. It is possible to construct a pooling + # operation with invalid padding, but the compiler will account + # for this and adjust the padding accordingly, leading to a + # mismatch between the expected and actual result. Therefore, + # hardcoded padding values are used for each case. + padding = (1, 1, 1, 0) if upscale == "NONE" else (0, 0, 0, 0) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") pooling = make_ethosu_pooling( ifm, @@ -157,6 +177,7 @@ def test_pooling_single( ifm_layout, ofm_layout, rounding_mode, + upscale, ) func = relay.Function(relay.analysis.free_vars(pooling), pooling) func = run_opt_pass(func, relay.transform.InferType()) @@ -180,6 +201,7 @@ def _visit(stmt): padding, activation, rounding_mode, + upscale, ) assert data[0] == ["ethosu_pooling"] + list(serial_pooling) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 20595465e32e..d81aff4c706d 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -131,11 +131,11 @@ def test_copy_constants(): sch = te.create_schedule([cached_func.outputs[0].op]) planner = copy_constants() planner(cached_func, const_dict, sch) - assert len(sch.stages) == 21 - assert ".global" in sch.stages[5].op.name - assert ".global" in sch.stages[7].op.name - assert ".global" in sch.stages[15].op.name + assert len(sch.stages) == 23 + assert ".global" in sch.stages[6].op.name + 
assert ".global" in sch.stages[8].op.name assert ".global" in sch.stages[17].op.name + assert ".global" in sch.stages[19].op.name # This test makes sure that constants and LUTs have a correct storage scope @@ -156,10 +156,10 @@ def test_copy_luts(): sch = te.create_schedule([te_graph.outputs[0].op]) copy_constants()(te_graph, const_dict, sch) copy_luts()(te_graph, const_dict, sch) - assert len(sch.stages) == 16 - assert ".global" in sch.stages[5].op.name - assert ".global" in sch.stages[7].op.name - assert ".local" in sch.stages[9].op.name + assert len(sch.stages) == 17 + assert ".global" in sch.stages[6].op.name + assert ".global" in sch.stages[8].op.name + assert ".local" in sch.stages[10].op.name def test_schedule_cache_reads(): diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py index 9b606562c5c0..b25f92edc274 100644 --- a/tests/python/contrib/test_ethosu/test_type_inference.py +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -228,6 +228,23 @@ def test_ethosu_pooling_invalid_dtype(): run_opt_pass(func, relay.transform.InferType()) +def test_ethosu_pooling_invalid_upscale_method(): + invalid_upscale_method = "FOO" + ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype="int8") + pooling = make_ethosu_pooling( + ifm, + "MAX", + (3, 2), + 55, + (1, 2), + (0, 1, 2, 3), + upscale=invalid_upscale_method, + ) + func = relay.Function([ifm], pooling) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + @pytest.mark.parametrize( "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] )