diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 43a0f89d95c1..5f591f1d89ad 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -467,6 +467,13 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); */ TVM_DLL Pass SimplifyExpr(); +/*! + * \brief Stripped down version of SimplifyExpr which is run after AlterOpLayout. + * + * \return The pass. + */ +TVM_DLL Pass SimplifyExprPostAlterOp(); + /*! * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds. * diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e956c82828c1..53aec11e5816 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -877,23 +877,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts): return relay.nn.deformable_conv2d(data, offset, weight, **new_attrs) -# QNN ops -@reg.register_alter_op_layout("add") -def alter_op_layout_add(attrs, inputs, tinfos, out_type): - """Alter the layout of a add op. - - Useful for fusing the bias constant with an input zero point constant in a previous quantized - op. Only used when previous op is a quantized op, which is why it lives in topi.nn.qnn. - """ - return topi.nn.qnn.qnn_add_alter_layout(attrs, inputs, tinfos, out_type) - - -@reg.register_alter_op_layout("qnn.requantize") -def alter_op_layout_qnn_requantize(attrs, inputs, tinfos, out_type): - """Alter the layout of a requantization op.""" - return topi.nn.qnn.qnn_requantize_alter_layout(attrs, inputs, tinfos, out_type) - - # bitpack @reg.register_compute("nn.bitpack") def compute_bitpack(attrs, inputs, out_dtype): diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index e2157a051abb..278ce7ee23c8 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -17,12 +17,20 @@ # pylint: disable=invalid-name, unused-argument, len-as-condition """QNN operator feature registration""" +import numpy as np + from tvm import topi from .. import strategy from ...op.op import register_compute from ...op.op import register_injective_schedule -from ...op.op import register_strategy, register_pattern, register_alter_op_layout, OpPattern +from ...op.op import ( + OpPattern, + register_alter_op_layout, + register_legalize, + register_pattern, + register_strategy, +) @register_compute("qnn.simulated_quantize") @@ -85,12 +93,60 @@ def simulated_dequantize_compute(attrs, inputs, output_type): register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy) +@register_legalize("clip") +def legalize_clip(attrs, inputs, tinfos): + """Removes clip operators with bounds matching the defaults for their dtype. + + This is already done after alter_op by TVM's simplification passes, but certain QNN operator + implementations (like Cortex-M) need it to be done earlier in legalization. + """ + + if hasattr(inputs[0], "op") and inputs[0].op.name == "qnn.requantize": + dtype_info = np.iinfo(tinfos[0].dtype) + if dtype_info.min == attrs.a_min and dtype_info.max == attrs.a_max: + return inputs[0] + + return None + + +@register_legalize("nn.bias_add") +def legalize_bias_add(attrs, inputs, tinfos): + """Legalize a bias add operator. + + May be used to "fold in" unused channels from quantized convolution operators. This should + be done before layout rewrites occur to minimize the amount of "extra" overhead operators + like "cast" and "layout_transform". + """ + return topi.nn.bias_add_legalize(attrs, inputs, tinfos) + + @register_alter_op_layout("qnn.conv2d") def alter_op_layout_qnn_conv2d(attrs, inputs, tinfos, out_type): - """Alternate the layout of qnn.conv2d""" + """Alter the layout of a qnn conv2d op. + + May be used to alter the current QNN Conv2D op, but can also be used to alter previous ops to + better match the current op. For example, Arm Cortex-M uses this to set the out_layout of + previous ops to the input layout preferred by future layouts. + """ return topi.nn.qnn_conv2d_alter_layout(attrs, inputs, tinfos, out_type) +@register_alter_op_layout("add") +def alter_op_layout_add(attrs, inputs, tinfos, out_type): + """Alter the layout of a add op. + + Useful for fusing the bias constant with an input zero point constant in a previous quantized + op. Only used when previous op is a quantized op, which is why it lives in topi.nn.qnn. + """ + return topi.nn.add_alter_layout(attrs, inputs, tinfos, out_type) + + +@register_alter_op_layout("qnn.requantize") +def alter_op_layout_qnn_requantize(attrs, inputs, tinfos, out_type): + """Alter the layout of a requantization op.""" + return topi.nn.qnn_requantize_alter_layout(attrs, inputs, tinfos, out_type) + + # qnn.dense register_strategy("qnn.dense", strategy.qnn_dense_strategy) diff --git a/python/tvm/relay/qnn/strategy/arm_cpu.py b/python/tvm/relay/qnn/strategy/arm_cpu.py index f8653817835e..bddfd7de3a56 100644 --- a/python/tvm/relay/qnn/strategy/arm_cpu.py +++ b/python/tvm/relay/qnn/strategy/arm_cpu.py @@ -21,9 +21,55 @@ regular/depthwise conv2d is supported, but qnn_dense will be added eventually.""" from tvm import topi, TVMError -from .generic import qnn_conv2d_strategy +from tvm.topi.utils import get_const_tuple from ... import op as _op from ...op.strategy.generic import is_depthwise_conv2d +from .generic import ( + qnn_conv2d_strategy, + qnn_dense_strategy, + qnn_dequantize_strategy, + qnn_quantize_strategy, + wrap_compute_dequantize, + wrap_compute_quantize, + wrap_topi_qnn_dense, + wrap_topi_schedule, +) + + +@qnn_quantize_strategy.register("arm_cpu") +def qnn_quantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target): + """qnn.quantize strategy for arm_cpu""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_quantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_quantize), + name="qnn_quantize.arm_cpu", + ) + return strategy + + +@qnn_dequantize_strategy.register("arm_cpu") +def qnn_dequantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target): + """qnn.dequantize strategy for arm_cpu""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dequantize(topi.hexagon.qnn_dequantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize), + name="qnn_dequantize.arm_cpu", + ) + return strategy + + +@qnn_dense_strategy.register("arm_cpu") +def qnn_dense_strategy_arm_cpu(_attrs, _inputs, _out_type, _target): + """qnn.dense strategy for arm_cpu""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_qnn_dense(topi.hexagon.qnn_dense), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dense), + name="qnn_dense.arm_cpu", + ) + return strategy @qnn_conv2d_strategy.register("arm_cpu") @@ -59,13 +105,28 @@ def qnn_conv2d_strategy_arm_cpu(attrs, inputs, _out_type, target): topi.arm_cpu.schedule_qnn_conv2d, name="qnn_conv2d.arm_cpu", ) + else: + raise TVMError("QNN regular Conv2D for Arm Cortex-M DSP got incorrect input layout!") elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): if data_layout == "NCHW" and kernel_layout == "IOHW": - strategy.add_implementation( - topi.arm_cpu.qnn_depthwise_conv2d, - topi.arm_cpu.schedule_qnn_depthwise_conv2d, - name="qnn_depthwise_conv2d.arm_cpu", - ) + height, width = data.shape[2:] + y_stride, x_stride = get_const_tuple(attrs.strides) + if height * width * y_stride % 2 == 0: + strategy.add_implementation( + topi.arm_cpu.qnn_depthwise_conv2d, + topi.arm_cpu.schedule_qnn_depthwise_conv2d, + name="qnn_depthwise_conv2d.arm_cpu", + ) + elif y_stride == x_stride == 1: + strategy.add_implementation( + topi.arm_cpu.qnn_unrolled_depthwise_conv2d, + topi.arm_cpu.schedule_qnn_unrolled_depthwise_conv2d, + name="qnn_unrolled_depthwise_conv2d.arm_cpu", + ) + else: + raise TVMError("No QNN depthwise Conv2D Cortex-M schedule supports these params!") + else: + raise TVMError("QNN depthwise Conv2D for Arm Cortex-M DSP got incorrect input layout!") else: raise TVMError("No Arm Cortex-M DSP strategy exists for generic group qnn.conv2d") diff --git a/python/tvm/topi/arm_cpu/__init__.py b/python/tvm/topi/arm_cpu/__init__.py index eba102662bc4..054103f43bef 100644 --- a/python/tvm/topi/arm_cpu/__init__.py +++ b/python/tvm/topi/arm_cpu/__init__.py @@ -23,7 +23,6 @@ from .conv2d_transpose import * from .conv2d_int8 import * from . import conv2d_alter_op -from . import qnn_alter_op from .bitserial_conv2d import * from .bitserial_dense import * from .injective import * @@ -31,3 +30,5 @@ from .pooling import * from .dense import * from .qnn import * +from . import qnn_alter_op +from . import qnn_legalize diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index 1d36e1dd1e9c..d2a8f1ef6905 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -286,7 +286,14 @@ def _write_sums_to_memory(num_outputs, offset, stride) -> Iterator[str]: num_packed = (num_outputs - offset) // 2 for i in range(num_packed): index = 2 * i + offset - yield f"int32_t packed_res_{i} = requant_{index} + (requant_{index + 1} << 16);" + # We must explicitly call asm inline to use the PKHBT instruction. It is not part of + # ACLE and has no __builtin. Writing it using masks and bitshifts does not work either: + # Arm GCC 12 with -O3 does not compile these efficiently. + yield f"int packed_res_{i};" + yield ( + f'__asm__ ("pkhbt %0, %1, %2, lsl #16" : "=r" (packed_res_{i}) : ' + f'"r" (requant_{index}), "r" (requant_{index + 1}));' + ) if offset == 1: yield "((int16_t*) output)[1] = (int16_t) requant_0;" diff --git a/python/tvm/topi/arm_cpu/qnn.py b/python/tvm/topi/arm_cpu/qnn.py index fad64cc09bb8..bfd37847f3e0 100644 --- a/python/tvm/topi/arm_cpu/qnn.py +++ b/python/tvm/topi/arm_cpu/qnn.py @@ -17,25 +17,40 @@ """Contains TVMScript implementations of some QNN operators for Arm. Currently, the only ops with compute functions are fused regular and depthwise convolutions for -Arm Cortex-M with DSP. +Arm Cortex-M with DSP. Additionally, these functions explicitly do not support padding - it +must be done in a separate Relay op for memory reasons. """ -from typing import Tuple +from typing import Callable, Dict, Tuple import tvm -from tvm import te -from tvm.tir import const +from tvm import te, tir, TVMError from tvm.script import tir as T +from tvm.tir import const + from ..utils import get_const_tuple from .mprofile.dsp.micro_kernel import tensordot -def int_ceil_division(x, y): +def _int_ceil_division(x, y): return -(x // -y) def _compute_output_dim(data_length, kernel_length, stride): - return int_ceil_division(data_length + 1 - kernel_length, stride) + return _int_ceil_division(data_length + 1 - kernel_length, stride) + + +def _pick_num_outputs(out_width): + """Guess a good value for num_outputs.""" + + assert out_width > 1 + + # num_outputs is capped at 8 + for i in range(2, min(out_width + 1, 8)): + if out_width % i == 0: + return i + + raise TVMError(f"Cannot pick a good num_outputs value for out_width = {out_width}!") def _pick_tensordot_impl(attrs, inputs, num_outputs=2, is_depthwise=False): @@ -118,38 +133,89 @@ def _make_tscript_ptr(buffer, offset, length, dtype="int16"): ) +def _bias_ptr(bias, c): + return _make_tscript_ptr(bias, c, 1, dtype="int32") + + +def _scale_ptr(scale, c): + return _make_tscript_ptr(scale, c, 1, dtype="int32") + + def _make_tscript_call(func_name, *args): return T.evaluate(T.call_extern(func_name, *args, dtype="int32")) def _make_conv2d_primfunc( - call_dimensions: Tuple, - buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple], + output_dimensions: Tuple[int, int, int, int], + buffer_shapes: Tuple, aligned_func: Tuple[str, str], offset_func: Tuple[str, str], - ptr_gens: Tuple, -): - height, width, out_channels = call_dimensions + ptr_gens: Tuple[Callable, Callable], + output_layout: str = "NHWC", +) -> tir.function.PrimFunc: + """Makes a TIR PrimFunc computing Conv2D using a call to tensordot. + + Can be used to generate regular, depthwise, and grouped Conv2D operators by passing different + arguments and ptr_gen functions. However, it only works for Conv2D operators where the height + stride of the tensor is divisible by two. + + Parameters + ---------- + output_dimensions : Tuple[int, int, int, int] + A tuple containing the out_height, out_width, out_channels, and desired num_outputs values + in that order. + + buffer_shapes: Tuple[tvm.ir.container.Array] + The shapes of the data, kernel, bias, scale, and output tensors, in that order. Each shape + should be a TVM Array. + + aligned_func: Tuple[str, str] + A tuple containing the (name, C implementation) of a word-aligned tensordot operator. + + offset_func: Tuple[str, str] + A tuple containing the (name, C implementation) of a word-unaligned tensordot operator. Can + be a tuple of empty strings if the Conv2D in question does not need an unaligned operator. + + ptr_gens: Tuple[Callable, Callable] + A tuple of two functions to generate data and kernel access pointers. They should take as + inputs the buffer, (y, x, c) indices, and an alignment offset. They should return a + T.tvm_access_ptr object which can be used in T.call_extern. + + output_layout: str + The tensor layout that will be prosued by the generated PrimFunc. Should be NHWC or NCHW. + """ + + out_height, out_width, out_channels, num_outputs = output_dimensions data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes aligned_func_name, aligned_func_code = aligned_func offset_func_name, offset_func_code = offset_func - output_ptr, data_ptr, kernel_ptr = ptr_gens + data_ptr, kernel_ptr = ptr_gens # If the functions are identical, we can skip the second loop if aligned_func_name == offset_func_name: aligned_channels = out_channels - offset_channels = tvm.tir.const(0) - c_step = tvm.tir.const(1) + offset_channels = 0 + c_step = const(1) else: aligned_channels = out_channels // 2 offset_channels = out_channels // 2 - c_step = tvm.tir.const(2) - - def bias_ptr(bias, c): - return _make_tscript_ptr(bias, c, 1, dtype="int32") - - def scale_ptr(scale, c): - return _make_tscript_ptr(scale, c, 1, dtype="int32") + c_step = const(2) + + def output_ptr(output, y, x, c): + if output_layout == "NHWC": + return _make_tscript_ptr( + output, + y * const(out_width * out_channels) + x * const(out_channels * num_outputs) + c, + 1, + ) + elif output_layout == "NCHW": + return _make_tscript_ptr( + output, + c * const(out_height * out_width) + y * const(out_width) + x * const(num_outputs), + 1, + ) + else: + raise TVMError(f"Unsupported out_layout '{output_layout}'!") @T.prim_func def biased_quantized_conv2d( @@ -181,30 +247,36 @@ def biased_quantized_conv2d( __4 = scale[0] # pylint: enable=unused-variable - for c_ax, y_ax, x_ax in T.grid(aligned_channels, height, width): + for c_ax, y_ax, x_ax in T.grid( + const(aligned_channels), const(out_height), const(out_width // num_outputs) + ): with T.block("conv2d_aligned"): T.block_attr({"pragma_import_c": aligned_func_code}) - y, x, c = T.axis.remap("SSS", [y_ax, x_ax, c_ax]) + y, x, c_interval = T.axis.remap("SSS", [y_ax, x_ax, c_ax]) + c = c_interval * c_step _make_tscript_call( aligned_func_name, - output_ptr(output, y, x, c * c_step), - data_ptr(data, y, x, c * c_step), - kernel_ptr(kernel, c * c_step), - bias_ptr(bias, c * c_step), - scale_ptr(scale, c * c_step), + output_ptr(output, y, x, c), + data_ptr(data, y, x, c), + kernel_ptr(kernel, c), + _bias_ptr(bias, c), + _scale_ptr(scale, c), ) - for c_ax, y_ax, x_ax in T.grid(offset_channels, height, width): + for c_ax, y_ax, x_ax in T.grid( + const(offset_channels), const(out_height), const(out_width // num_outputs) + ): with T.block("conv2d_offset"): T.block_attr({"pragma_import_c": offset_func_code}) - y, x, c = T.axis.remap("SSS", [y_ax, x_ax, c_ax]) + y, x, c_interval = T.axis.remap("SSS", [y_ax, x_ax, c_ax]) + c = c_interval * c_step + 1 _make_tscript_call( offset_func_name, - output_ptr(output, y, x, c * c_step + 1), - data_ptr(data, y, x, c * c_step + 1, offset=1), - kernel_ptr(kernel, c * c_step + 1, offset=1), - bias_ptr(bias, c * c_step + 1), - scale_ptr(scale, c * c_step + 1), + output_ptr(output, y, x, c), + data_ptr(data, y, x, c, offset=1), + kernel_ptr(kernel, c, offset=1), + _bias_ptr(bias, c), + _scale_ptr(scale, c), ) return biased_quantized_conv2d @@ -221,23 +293,21 @@ def qnn_conv2d(attrs, inputs, out_type): # Make a few checks to unpack the function arguments and ensure it was called with the right # arguments. Note that unlike most schedules, qnn_conv2d does not use a wrapper. assert len(inputs) == 11 - data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8] - output_layout = attrs.out_layout - assert output_layout == "NHWC" + assert not any(get_const_tuple(attrs.padding)) + data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8] _, height, width, in_channels = get_const_tuple(data.shape) out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape) - y_stride, x_stride = get_const_tuple(attrs.strides) + y_stride, x_stride = get_const_tuple(attrs.strides) out_height = _compute_output_dim(height, kernel_h, y_stride) out_width = _compute_output_dim(width, kernel_w, x_stride) # Decide how many sums our function should have running at the same time. Doing # this lets us do "more work" for each memory load, but doing too many of them causes us to run - # out of registers. Currently this is set to either 1 or 2, but autotuning this value would - # improve performance a lot. Tracked by https://github.com/apache/tvm/issues/13528. - - num_outputs = 2 + # out of registers. Currently this is set to the smallest value greater than one that divides + # the output width, but autotuning this value would improve performance a lot. + num_outputs = _pick_num_outputs(out_width) # Next, decide whether whether we need "parity alternation". For example, if we have an # 8x3x3x3 kernel (8 output channels, height 3, width 3, input channels 3) in the OHWI layout, @@ -253,14 +323,6 @@ def qnn_conv2d(attrs, inputs, out_type): aligned_func, offset_func = _pick_tensordot_impl(attrs, inputs, num_outputs, False) - # Helper functions to make pointers - def output_ptr(buffer, y, x, c): - return _make_tscript_ptr( - buffer, - y * const(out_width * out_channels) + x * const(out_channels * num_outputs) + c, - 1, - ) - # We need to disable pylint's unused argument checker, as the kwarg offset is unused but must # be present for compatibility. We cannot add an underscore as we normally would, as this makes # the keyword not match. @@ -284,11 +346,12 @@ def kernel_ptr(buffer, c, offset=0): ) prim_func = _make_conv2d_primfunc( - (const(out_height), const(out_width // num_outputs), const(out_channels)), + (out_height, out_width, out_channels, num_outputs), (data.shape, kernel.shape, bias.shape, scale.shape, out_type.shape), aligned_func, offset_func, - (output_ptr, data_ptr, kernel_ptr), + (data_ptr, kernel_ptr), + output_layout=attrs.out_layout, ) output = te.extern_primfunc([data, kernel, bias, scale], prim_func, name="tir", dtype="int16") @@ -307,30 +370,19 @@ def qnn_depthwise_conv2d(attrs, inputs, out_type): """ assert len(inputs) == 11 + assert not any(get_const_tuple(attrs.padding)) data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8] - output_layout = attrs.out_layout - assert output_layout == "NHWC" - _, _, height, width = get_const_tuple(data.shape) _, out_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape) - _, out_height, out_width, _ = get_const_tuple(out_type.shape) - y_stride, x_stride = get_const_tuple(attrs.strides) + y_stride, x_stride = get_const_tuple(attrs.strides) out_height = _compute_output_dim(height, kernel_h, y_stride) out_width = _compute_output_dim(width, kernel_w, x_stride) - num_outputs = 2 + num_outputs = _pick_num_outputs(out_width) aligned_func, offset_func = _pick_tensordot_impl(attrs, inputs, num_outputs, True) - # Helper functions for making pointers. - def output_ptr(buffer, y, x, c): - return _make_tscript_ptr( - buffer, - y * const(out_width * out_channels) + x * const(out_channels * num_outputs) + c, - 1, - ) - def data_ptr(buffer, y, x, c, offset=0): if height * width % 2 == 1: x_ptr_offset = tvm.tir.const(-1) @@ -354,11 +406,12 @@ def kernel_ptr(buffer, c, offset=0): ) prim_func = _make_conv2d_primfunc( - (const(out_height), const(out_width // num_outputs), const(out_channels)), + (out_height, out_width, out_channels, num_outputs), (data.shape, kernel.shape, bias.shape, scale.shape, out_type.shape), aligned_func, offset_func, - (output_ptr, data_ptr, kernel_ptr), + (data_ptr, kernel_ptr), + output_layout=attrs.out_layout, ) output = te.extern_primfunc([data, kernel, bias, scale], prim_func, name="tir", dtype="int16") @@ -368,3 +421,170 @@ def kernel_ptr(buffer, c, offset=0): def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target): """Schedule function for qnn.depthwise_conv2d.""" return None + + +def _make_unrolled_conv2d_primfunc( + output_dimensions: Tuple[int, int, int], + buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple], + function_names: Dict[Tuple, str], + function_code: str, + ptr_gens: Tuple[Callable, Callable], + output_layout: str = "NHWC", +) -> tir.function.PrimFunc: + """Makes a TIR PrimFunc computing Conv2D using a call to tensordot. + + Can be used to generate regular, depthwise, and grouped Conv2D operators by passing different + arguments and ptr_gen functions. Takes some of the same arguments as _make_conv2d_primfunc, but + requires the tensordot function variations to be passed differently. The generated PrimFunc is + simlar to the one produced by _make_conv2d_primfunc, but unrolls the height and width loops + over the input tensor. This results in longer code, but unlike _make_conv2d_primfunc this + function does not require the height stride be an even number of words. + + This is required to compute layer 25 in MobileNetV1 models, among other things. + + Parameters + ---------- + output_dimensions : Tuple[int, int, int, int] + A tuple containing the out_height, out_width, out_channels, and desired num_outputs values + in that order. + + buffer_shapes: Tuple[tvm.ir.container.Array] + The shapes of the data, kernel, bias, scale, and output tensors, in that order. Each shape + should be a TVM Array. + + function_names: Dict[Tuple, str] + A dictionary mapping a tuple of (data, kernel, output) alignments to the name of the + appropriate tensordot function. + + function_code: str + A string containing all verions of tensordot function our PrimFunc needs. This will usually + be a string of 4+ function variations concatenated together. + + ptr_gens: Tuple[Callable, Callable] + A tuple of two functions to generate data and kernel access pointers. They should take as + inputs the buffer, (y, x, c) indices, and an alignment offset. They should return a + T.tvm_access_ptr object which can be used in T.call_extern. + + output_layout: str + The tensor layout that will be prosued by the generated PrimFunc. Should be NHWC or NCHW. + """ + + out_height, out_width, out_channels = output_dimensions + data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes + data_ptr, kernel_ptr = ptr_gens + + def output_ptr(output, y, c): + if output_layout == "NHWC": + return _make_tscript_ptr(output, y * const(out_width * out_channels) + c, 1) + elif output_layout == "NCHW": + return _make_tscript_ptr( + output, c * const(out_height * out_width) + y * const(out_width), 1 + ) + else: + raise TVMError(f"Unsupported out_layout '{output_layout}'!") + + def make_row_calls(buffers, c_var, out_height): + output, data, kernel, bias, scale = buffers + for y in range(out_height): + for c in range(2): + _make_tscript_call( + function_names[(y + c) % 2, c % 2, 0], + output_ptr(output, y, c_var + c), + data_ptr(data, y, c_var + c, offset=(y + c) % 2), + kernel_ptr(kernel, c_var + c, offset=c), + _bias_ptr(bias, c_var + c), + _scale_ptr(scale, c_var + c), + ) + + @T.prim_func + def biased_quantized_conv2d( + data_handle: T.handle, + kernel_handle: T.handle, + bias_handle: T.handle, + scale_handle: T.handle, + output_handle: T.handle, + ) -> None: + # Same setup is used as in _make_conv2d_primfunc + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + data = T.match_buffer(data_handle, data_shape, dtype="int16") + kernel = T.match_buffer(kernel_handle, kernel_shape, dtype="int16") + bias = T.match_buffer(bias_handle, bias_shape, dtype="int32") + scale = T.match_buffer(scale_handle, scale_shape) + output = T.match_buffer(output_handle, output_shape, dtype="int16") + + # pylint: disable=unused-variable + output[0, 0, 0, 0] = 0 + __1 = data[0, 0, 0, 0] + __2 = kernel[0, 0, 0, 0] + __3 = bias[0, 0, 0, 0] + __4 = scale[0] + # pylint: enable=unused-variable + + for c_ax in T.grid(out_channels // 2): + with T.block("conv2ds"): + T.block_attr({"pragma_import_c": function_code}) + c = T.axis.remap("S", [c_ax]) * 2 + make_row_calls((output, data, kernel, bias, scale), c, out_height) + + return biased_quantized_conv2d + + +def qnn_unrolled_depthwise_conv2d(attrs, inputs, out_type): + """Compute for qnn.depthwise_conv2d with NCHW layout for convolutions with small width, height. + + Behaves similarly to qnn_depthwise_conv2d, but does not iterate over the output width and height + and instead calls these functions explicitly. This gives a tiny performance boost in exchange + for larger code size, but more importantly does not require out_width * out_height + * y_stride % 2 == 0. This does, however, require y_stride == x_stride == 1. + """ + + assert len(inputs) == 11 + assert not any(get_const_tuple(attrs.padding)) + y_stride, x_stride = get_const_tuple(attrs.strides) + assert y_stride == x_stride == 1 + + data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8] + _, _, height, width = get_const_tuple(data.shape) + _, out_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape) + + y_stride, x_stride = get_const_tuple(attrs.strides) + out_height = _compute_output_dim(height, kernel_h, y_stride) + out_width = _compute_output_dim(width, kernel_w, x_stride) + + rq_output_zero_point_const = inputs[10] + assert len(rq_output_zero_point_const.op.body) == 1 + output_zero_point = rq_output_zero_point_const.op.body[0] + + dimensions = (width, kernel_h, kernel_w) + x_strides = (1, out_channels) + + func_names = {} + impls = [] + for alignment in ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)): + func_name, impl = tensordot.tensordot_int16_impl( + out_width, dimensions, alignment, x_strides, output_zero_point=output_zero_point + ) + func_names[alignment] = func_name + impls.append(impl) + + def data_ptr(buffer, y, c, offset=0): + return _make_tscript_ptr(buffer, c * const(width * height) + y * const(width) - offset, 1) + + def kernel_ptr(buffer, c, offset=0): + return _make_tscript_ptr(buffer, c * const(kernel_h * kernel_w) - offset, 1) + + prim_func = _make_unrolled_conv2d_primfunc( + (out_height, out_width, out_channels), + (data.shape, kernel.shape, bias.shape, scale.shape, out_type.shape), + func_names, + "\n".join(impls), + (data_ptr, kernel_ptr), + output_layout=attrs.out_layout, + ) + output = te.extern_primfunc([data, kernel, bias, scale], prim_func, name="tir", dtype="int16") + return [output] + + +def schedule_qnn_unrolled_depthwise_conv2d(_attrs, _outs, _target): + """Schedule function for qnn.depthwise_conv2d.""" + return None diff --git a/python/tvm/topi/arm_cpu/qnn_alter_op.py b/python/tvm/topi/arm_cpu/qnn_alter_op.py index 00225493db96..31782d69d032 100644 --- a/python/tvm/topi/arm_cpu/qnn_alter_op.py +++ b/python/tvm/topi/arm_cpu/qnn_alter_op.py @@ -16,73 +16,129 @@ # under the License. """Arm Cortex-M specific optimizations for quantized operators.""" +from typing import Iterable + import numpy as np from tvm import nd, relay, target -from ..nn import qnn_requantize_alter_layout, qnn_add_alter_layout - +from ..utils import get_const_tuple +from ..nn import qnn_conv2d_alter_layout, add_alter_layout, qnn_requantize_alter_layout -@qnn_requantize_alter_layout.register(["arm_cpu"]) -def alter_requantize_layout(attrs, inputs, _tinfos, _out_type): - """Changes a floating point requantize op to use int64 multiply + shift for microTVM. - Usually, this is done by QNN legalization. However, microTVM wants to manually choose the - integer rounding constants in order to: - (a) Have int32, not int64 constants - (b) Use a constant rounding shift to skip a memory load. +def prev_ops_match(curr_op: relay.expr.Call, pattern: Iterable[str]): + """Checks if the names of nested Relay operators match a pattern. - Ideally, we would pick these constants in the requantize (or fused) schedule. Unfortunately that - is not currently possible, so we pick them with `alter_layout` as a hack. This will only work if - the requantize schedule "plays along" with this hack. + Note this function considers `curr_op` as a linear stack of operators, only considering args[0] + when traversing backwards. `pattern` should be an Iterable of operator names, written backwards + from last to first. """ + prev_op = curr_op + for op_name in pattern: + if (not hasattr(prev_op, "op")) or prev_op.op.name != op_name: + return False + prev_op = prev_op.args[0] + return True - # Only microTVM Cortex-M boards with DSP use the relevant schedules - current_target = target.Target.current(allow_none=False) - if not (current_target.features.has_dsp and "cortex-m" in current_target.mcpu): - return None - _, in_scale, _, out_scale, _ = inputs - in_scale_numpy = in_scale.data.numpy().astype("float64") - out_scale_scalar = out_scale.data.numpy().item() +def edit_attrs(attrs, **kwargs): + return {**attrs, **kwargs} - # Shifting by 33 and rounding means shifting by 32, adding 1, and shifting by 1 again. This is - # useful, because shifting a multiplication product by 32 can be done for "free" with SMMUL - scales = ((in_scale_numpy / out_scale_scalar) * 2**33).astype("int32") - # Requantize ops in Relay do not support int32 scales - if we try to use one, requantize.cc will - # raise an error. As a hacky work-around, we change the scale dtype to float32, without changing - # underlying data. This works, as our compute function knows to interpret the scale as an int32. +def change_numpy_layout(arr, src_layout, dst_layout): + assert src_layout.isalpha() and dst_layout.isalpha() + axis_order = [src_layout.index(c) for c in dst_layout] + return np.transpose(arr, axis_order) - # This is only a work-around - a better long-term solution would be adding a new integer - # requantize op, which takes integer scales, shifts, and rounding behavior. - fake_float_scales = scales.view("float32") - scale_constant = relay.Constant(nd.array(fake_float_scales)) - return relay.qnn.op.requantize(inputs[0], scale_constant, *inputs[2:], **attrs) +def _squash_transformations(expr): + if isinstance(expr, relay.expr.Constant): + return expr.data.numpy() + assert isinstance(expr, relay.expr.Call) + assert len(expr.args) == 1 + + prev_kernel = _squash_transformations(expr.args[0]) + attrs = expr.attrs + + if expr.op.name == "layout_transform": + return change_numpy_layout(prev_kernel, attrs.src_layout, attrs.dst_layout) + elif expr.op.name == "cast": + return prev_kernel.astype(attrs.dtype) + elif kernel.op.name == "expand_dims": + new_axes = range(attrs.axis, attrs.axis + attrs.num_newaxis) + return np.expand_dims(prev_kernel, tuple(new_axes)) + else: + raise RuntimeError(f"Invalid kernel transformation '{expr}'!") + + +def _alter_depthwise_conv2d_layout(depthwise_conv2d): + cast_op = depthwise_conv2d.args[0] + requantize_op = cast_op.args[0] + add_op = requantize_op.args[0] + prev_conv2d_op = add_op.args[0] + + return relay.qnn.op.conv2d( + relay.layout_transform( + relay.cast( + relay.qnn.op.requantize( + relay.op.add( + relay.qnn.op.conv2d( + *prev_conv2d_op.args, + **edit_attrs(prev_conv2d_op.attrs, out_layout="NCHW"), + ), + relay.layout_transform( + add_op.args[1], + src_layout="NHWC", + dst_layout="NCHW", + ), + ), + *requantize_op.args[1:], + **edit_attrs(requantize_op.attrs, axis=1), + ), + dtype="int16", + ), + src_layout="NCHW", + dst_layout="NHWC", + ), + *depthwise_conv2d.args[1:], + **edit_attrs(depthwise_conv2d.attrs, data_layout="NCHW"), + ) + +@qnn_conv2d_alter_layout.register(["arm_cpu"]) +def alter_conv2d_layout(attrs, inputs, _tinfos, _out_type): + """Adjust a qnn.conv2d and preceeding ops to better fit on Cortex-M.""" + current_target = target.Target.current(allow_none=False) + if not "cortex-m" in current_target.mcpu: + return None -def _is_qnn_op_depthwise_conv2d(qnn_conv2d_op): - return relay.op.strategy.generic.is_depthwise_conv2d( - qnn_conv2d_op.args[0].type_annotation.shape, - qnn_conv2d_op.attrs.data_layout, - qnn_conv2d_op.args[1].data.shape, - qnn_conv2d_op.attrs.kernel_layout, - qnn_conv2d_op.attrs.groups, + # Always cast to int16 and pick a our desired kernel layout - this won't affect anything + data_expr, kernel_expr = inputs[:2] + is_depthwise = attrs.groups > 1 + new_kernel_layout = "IOHW" if is_depthwise else "OHWI" + + op = relay.qnn.op.conv2d( + relay.cast(data_expr, dtype="int16"), + relay.cast(kernel_expr, dtype="int16"), + *inputs[2:], + **edit_attrs(attrs, kernel_layout=new_kernel_layout, out_layout="NHWC"), ) + # If possible, modify depthwise ops to take as input NCHW instead. + if is_depthwise and prev_ops_match(op.args[0], ("cast", "qnn.requantize", "add", "qnn.conv2d")): + op = _alter_depthwise_conv2d_layout(op) + + return op + -@qnn_add_alter_layout.register(["arm_cpu"]) +@add_alter_layout.register(["arm_cpu"]) def alter_add_layout(_attrs, inputs, _tinfos, _out_type): """Fuses the zero point for a previous quantized operator with this add operation. Currently only supports qnn.conv2d, but qnn.dense support should be added. Note that this optimization means we must pad tensors with the input zero point, and NOT with zero. """ - - prev_op, biases = inputs - if not hasattr(prev_op, "op"): - return None - if prev_op.op.name != "qnn.conv2d": + prev_op, biases_data_op = inputs + if not prev_ops_match(inputs[0], ("qnn.conv2d",)): return None # We should not perform this alteration if the target has a uint * int SIMD MAC operation (since @@ -93,9 +149,9 @@ def alter_add_layout(_attrs, inputs, _tinfos, _out_type): return None conv_input_zp = prev_op.args[2].data.numpy().item() - kernel = prev_op.args[1].data.numpy() + kernel = _squash_transformations(prev_op.args[1]) - if _is_qnn_op_depthwise_conv2d(prev_op): + if prev_op.attrs.groups == prev_op.attrs.channels: axes_to_sum = "HW" elif prev_op.attrs.groups == 1: axes_to_sum = "HWI" @@ -108,6 +164,13 @@ def alter_add_layout(_attrs, inputs, _tinfos, _out_type): # The zero point is subtracted from the input elements, so we need a "-" sign here zp_shifted_sums = element_sums * (-conv_input_zp) + # The bias values may or may not be wrapped in an expand_dims op + if isinstance(biases_data_op, relay.expr.Call): + biases = biases_data_op.args[0] + else: + biases = biases_data_op + assert isinstance(biases, relay.expr.Constant) + # We want to make sure new_biases is representable as an int32. It's tempting to just check # whether arr.dtype == "int32" (since Numpy will automatically increase dtype in some cases) # but this leads to weird wrapping behavior and doesn't work. We must do it manually. @@ -115,8 +178,77 @@ def alter_add_layout(_attrs, inputs, _tinfos, _out_type): if new_biases.min() < -(2**31) or new_biases.max() > 2**31 - 1: return None + current_target = target.Target.current(allow_none=False) new_input_zp = relay.Constant(nd.array(np.int32(0))) - new_conv_args = (*prev_op.args[:2], new_input_zp, *prev_op.args[3:]) - new_conv_op = relay.qnn.op.conv2d(*new_conv_args, **prev_op.attrs) + new_conv_args = [*prev_op.args[:2], new_input_zp, *prev_op.args[3:]] bias_constant = relay.Constant(nd.array(new_biases.astype("int32"))) - return relay.add(new_conv_op, bias_constant) + + # We should handle padding separately from convolution, so the original tensor can be + # de-allocated immediately. This may also help with fusing padding onto a previous + # operator. However, only do this if we're working with Cortex-M devices. + padding = get_const_tuple(prev_op.attrs.padding) + if "cortex-m" in current_target.mcpu and any(padding): + data_layout = prev_op.attrs.data_layout + assert data_layout.isupper() + + pad_up, pad_left, pad_down, pad_right = padding + pad_op_arg = [(0, 0)] * len(data_layout) + pad_op_arg[data_layout.index("H")] = (pad_up, pad_down) + pad_op_arg[data_layout.index("W")] = (pad_left, pad_right) + new_conv_args[0] = relay.nn.pad(new_conv_args[0], tuple(pad_op_arg), conv_input_zp) + + new_conv_op = relay.qnn.op.conv2d( + *new_conv_args, + **edit_attrs(prev_op.attrs, padding=(0, 0, 0, 0)), + ) + # If biases was wrapped in an expand_dims op, we must re-wrap it + if isinstance(biases_data_op, relay.expr.Call): + new_biases_op = relay.expand_dims(bias_constant, **biases_data_op.attrs) + else: + new_biases_op = bias_constant + + return relay.add(new_conv_op, new_biases_op) + + +@qnn_requantize_alter_layout.register(["arm_cpu"]) +def alter_requantize_layout(attrs, inputs, _tinfos, _out_type): + """Changes a floating point requantize op to use int64 multiply + shift for microTVM. + + Usually, this is done by QNN legalization. However, microTVM wants to manually choose the + integer rounding constants in order to: + (a) Have int32, not int64 constants + (b) Use a constant rounding shift to skip a memory load. + + Ideally, we would pick these constants in the requantize (or fused) schedule. Unfortunately that + is not currently possible, so we pick them with `alter_layout` as a hack. This will only work if + the requantize schedule "plays along" with this hack. + """ + + # Only microTVM Cortex-M boards with DSP use the relevant schedules + current_target = target.Target.current(allow_none=False) + if not (current_target.features.has_dsp and "cortex-m" in current_target.mcpu): + return None + + if not prev_ops_match(inputs[0], ("add", "qnn.conv2d")): + return None + + _, in_scale, _, out_scale, _ = inputs + in_scale_numpy = in_scale.data.numpy().astype("float64") + out_scale_scalar = out_scale.data.numpy().item() + + # Shifting by 33 and rounding means shifting by 32, adding 1, and shifting by 1 again. This is + # useful, because shifting a multiplication product by 32 can be done for "free" with SMMUL + scales = ((in_scale_numpy / out_scale_scalar) * 2**33).astype("int32") + + # Requantize ops in Relay do not support int32 scales - if we try to use one, requantize.cc will + # raise an error. As a hacky work-around, we change the scale dtype to float32, without changing + # underlying data. This works, as our compute function knows to interpret the scale as an int32. + + # This is only a work-around - a better long-term solution would be adding a new integer + # requantize op, which takes integer scales, shifts, and rounding behavior. + fake_float_scales = scales.view("float32") + + scale_constant = relay.Constant(nd.array(fake_float_scales)) + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs["out_dtype"] = "int16" + return relay.qnn.op.requantize(inputs[0], scale_constant, *inputs[2:], **new_attrs) diff --git a/python/tvm/topi/arm_cpu/qnn_legalize.py b/python/tvm/topi/arm_cpu/qnn_legalize.py new file mode 100644 index 000000000000..2833fbce26f1 --- /dev/null +++ b/python/tvm/topi/arm_cpu/qnn_legalize.py @@ -0,0 +1,382 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""QNN legalization transforms that help eliminate sparse channels. + +Some models (like MobileNetV1 when fine-tuned) have output channels in their kernels which are +completely full of zeros. Sometimes these can be optimized away by the C compiler, but this does not +happen when complex schedules (like the ACLE tensordot convolutions) are used. + +Instead, we will remove these channels by replacing blocks of operators with equivalent "denser" +ones during legalization. This is harder than it looks - while the outputs of channels with all-zero +kernels do not depend on the input data, they are usually not zero. We work around this by computing +how these constant values affect subsequent operators, and "folding" these effects into a bias_add. + +It would eventually be nice to have a generalized, cross-target solution for removing zero channels, +as there is no downside. This may be possible with Relax, but I'm unsure. +""" + +import numpy as np +from scipy.signal import convolve2d +from tvm.topi.utils import get_const_tuple +from tvm import nd, relay +from .qnn_alter_op import prev_ops_match, edit_attrs +from ..nn import bias_add_legalize + + +def _compute_fixed_conv2d_outputs(requantize_op): + """Compute all conv2d output values that do not depend on the layer input. + + Parameters + ---------- + requantize_op : relay.expr.Call + A qnn.requantize Relay operator, which must be preceeded by a nn.bias_add op and a + qnn.conv2d operator. The qnn.conv2d operator must have groups==1. All arguments to all three + operators, besides the main tensor, must be constants. + + Returns + ------- + fixed_outputs : Dict[int, int] + A dictionary showing which of the conv2d -> bias_add -> requantize output channels are + "fixed" - i.e. those that do not depend on the input tensor. Each key in the dictionary is + an output channel index, and each value is the value that all entries in that output channel + will have. If the block has no fixed output channels, this dictionary will be empty. + """ + + bias_add_op = requantize_op.args[0] + conv2d_op = bias_add_op.args[0] + + assert conv2d_op.attrs.kernel_layout.isalpha() + assert conv2d_op.attrs.groups == 1 + kernel = conv2d_op.args[1].data.numpy() + oc_axis = conv2d_op.attrs.kernel_layout.index("O") + + num_channels = kernel.shape[oc_axis] + rq_input_scale = requantize_op.args[1].data.numpy() + rq_output_scale = requantize_op.args[3].data.numpy().item() + rq_output_zero_point = requantize_op.args[4].data.numpy().item() + bias_data = bias_add_op.args[1].data.numpy() + + fixed_outputs = {} + + for i in range(num_channels): + if np.any(np.take(kernel, i, axis=oc_axis)): + continue + scale = rq_input_scale[i] / rq_output_scale + channel_constant = round(bias_data[i] * scale + rq_output_zero_point) + clipped = min(127, max(-128, channel_constant)) + fixed_outputs[i] = clipped + + return fixed_outputs + + +def _compute_fixed_depthwise_outputs(requantize_op, fixed_channel_inputs): + """Compute all depthwise conv2d output values that do not depend on the PREVIOUS layer input. + + We take as input a requantize operator, and a dictionary of which inputs to our depthwise + operator are fixed and what values they are fixed to. However, a fixed input to one channel + of our depthwise operator does NOT guarantee we can remove the output, because of padding. + This function checks if the padding makes a difference in the outputs, and if not, removes + the channels from the depthwise_conv2d. + + Parameters + ---------- + requantize_op : relay.expr.Call + A qnn.requantize Relay operator, which must be preceeded by a nn.bias_add op and a + qnn.conv2d operator. The qnn.conv2d operator must be depthwise. All arguments to all three + operators, besides the main tensor, must be constants. + + fixed_channel_inputs : Dict[int, int] + A dictionary showing which input channels to the qnn.conv2d operator have fixed values, and + what those values are fixed to. Can be empty. Usually, this will be generated by + _compute_fixed_conv2d_outputs. + + Returns + ------- + fixed_outputs : Dict[int, int] + A dictionary showing which of the conv2d -> bias_add -> requantize output channels are + "fixed" - i.e. those that do not depend on the input tensor. Each key in the dictionary is + an output channel index, and each value is the value that all entries in that output channel + will have. If the block has no fixed output channels, this dictionary will be empty. + """ + + bias_add_op = requantize_op.args[0] + depthwise_op = bias_add_op.args[0] + + assert depthwise_op.attrs.kernel_layout.isalpha() + assert depthwise_op.attrs.groups > 1 + kernel = depthwise_op.args[1].data.numpy() + oc_axis = depthwise_op.attrs.kernel_layout.index("O") + + conv_input_zero_point = depthwise_op.args[2].data.numpy().item() + rq_input_scale = requantize_op.args[1].data.numpy() + rq_output_scale = requantize_op.args[3].data.numpy().item() + rq_output_zero_point = requantize_op.args[4].data.numpy().item() + bias_data = bias_add_op.args[1].data.numpy() + + kernel_size = get_const_tuple(depthwise_op.attrs.kernel_size) + fixed_outputs = {} + + for i, fixed_input in fixed_channel_inputs.items(): + input_array = np.full(kernel_size, fixed_input, dtype="int32") - conv_input_zero_point + kernel_channel = np.take(kernel, i, axis=oc_axis).reshape(kernel_size) + scale = rq_input_scale[i] / rq_output_scale + + convolved = convolve2d(input_array, kernel_channel, mode="same") + rounded = np.around((convolved + bias_data[i]) * scale).astype("int32") + clipped = np.clip(rounded + rq_output_zero_point, -128, 127) + + # We require the ENTIRE padded convolution to all have the same clipped value before we do + # a replacement. This is excessive - we only have to check for the padding that will + # actually be performed on the depthwise convolution, which is often less. If we felt even + # more ambitious, we could do the replacement for "close enough" looking convolution + # outputs, which in theory could reduce accuracy but in practice does not. Doing this would + # yield a ~0.5% speed gain on MobileNetV1, and nothing on other models. + + if np.all(clipped == clipped[0, 0]): + fixed_outputs[i] = clipped[0, 0] + + # TODO @guberti look for all-zero entries in the depthwise kernel. I don't think these really + # occur in practice, but it would be nice for theoretical completeness. + + return fixed_outputs + + +def _excise_conv2d_channels(empty_channels, input_op, requantize_op, is_depthwise=False): + bias_add_op = requantize_op.args[0] + conv2d_op = bias_add_op.args[0] + axis = conv2d_op.attrs.kernel_layout.index("O") + + kernel_data = np.delete(conv2d_op.args[1].data.numpy(), empty_channels, axis=axis) + bias_data = np.delete(bias_add_op.args[1].data.numpy(), empty_channels) + in_scale_data = np.delete(conv2d_op.args[5].data.numpy(), empty_channels) + out_scale_data = np.delete(requantize_op.args[1].data.numpy(), empty_channels) + num_channels = kernel_data.shape[axis] + if is_depthwise: + num_groups = num_channels + else: + num_groups = 1 + + return relay.qnn.op.requantize( + relay.nn.bias_add( + relay.qnn.op.conv2d( + input_op, + relay.Constant(nd.array(kernel_data)), + *conv2d_op.args[2:5], + relay.Constant(nd.array(in_scale_data)), + **edit_attrs(conv2d_op.attrs, channels=num_channels, groups=num_groups), + ), + relay.Constant(nd.array(bias_data)), + **bias_add_op.attrs, + ), + relay.Constant(nd.array(out_scale_data)), + *requantize_op.args[2:], + **requantize_op.attrs, + ) + + +def _excise_avg_pool_channels(empty_channels, input_op, first_reshape_op, axis=1): + outer_cast = first_reshape_op.args[0].args[0] + avg_pool = outer_cast.args[0] + inner_cast = avg_pool.args[0] + + new_shape = list(get_const_tuple(first_reshape_op.attrs.newshape)) + new_shape[axis] -= len(empty_channels) + + return relay.reshape( + relay.cast( + relay.nn.avg_pool2d(relay.cast(input_op, **inner_cast.attrs), **avg_pool.attrs), + **outer_cast.attrs, + ), + **edit_attrs(first_reshape_op.attrs, newshape=new_shape), + ) + + +def _fold_into_conv_bias(fixed_inputs, conv2d_op, input_op): + assert not any(get_const_tuple(conv2d_op.attrs.padding)) + in_axis = conv2d_op.attrs.kernel_layout.index("I") + out_axis = conv2d_op.attrs.kernel_layout.index("O") + + kernel = conv2d_op.args[1].data.numpy() + zero_point = conv2d_op.args[2].data.numpy().item() + + extra_bias = np.zeros((kernel.shape[out_axis],), dtype="int32") + + # For every output channel + for i in range(kernel.shape[out_axis]): + out_kernel_slice = np.expand_dims(np.take(kernel, i, axis=out_axis), axis=out_axis) + + # For every input channel that is being removed: + for j, val in fixed_inputs.items(): + kernel_slice = np.take(out_kernel_slice, j, axis=in_axis) + accumulator = np.sum(kernel_slice * (val - zero_point)) + extra_bias[i] += accumulator + + stripped_kernel = np.delete(kernel, tuple(fixed_inputs.keys()), axis=in_axis) + new_conv = relay.qnn.op.conv2d( + input_op, + relay.Constant(nd.array(stripped_kernel)), + *conv2d_op.args[2:], + **conv2d_op.attrs, + ) + + return new_conv, extra_bias + + +def _fold_into_dense_bias(fixed_inputs, dense_op, input_op, channel_axis=1): + weights = dense_op.args[1].data.numpy() + assert channel_axis < 2 + assert len(weights.shape) == 2 + zero_point = dense_op.args[2].data.numpy().item() + + extra_bias = np.zeros((weights.shape[1 - channel_axis],), dtype="int32") + + # For every output channel + for i in range(weights.shape[1 - channel_axis]): + out_weights_slice = np.take(weights, i, axis=1 - channel_axis) + + # For every input channel that is being removed: + for j, val in fixed_inputs.items(): + weight = out_weights_slice[j] + extra_bias[i] += (val - zero_point) * weight + + stripped_weights = np.delete(weights, tuple(fixed_inputs.keys()), axis=channel_axis) + new_dense = relay.qnn.op.dense( + input_op, + relay.Constant(nd.array(stripped_weights)), + *dense_op.args[2:], + **dense_op.attrs, + ) + + return new_dense, extra_bias + + +def _densify_conv_depthwise_conv_pattern(attrs, inputs): + """Rewrites a regular -> depthwise -> regular convolution pattern to excise empty out channels. + + Should be called as part of legalization (before dtypes and layouts are rewritten) and with the + BIAS ADD OPERATOR'S (the one we'll use to "fold in" our constants) `attrs` and `inputs`. The + last regular conv2d operator must be unpadded. + """ + current_conv = inputs[0] + depthwise_requantize = current_conv.args[0] + top_requantize = depthwise_requantize.args[0].args[0].args[0] + top_conv2d = top_requantize.args[0].args[0] + + fixed_conv2d_outputs = _compute_fixed_conv2d_outputs(top_requantize) + fixed_dw_outputs = _compute_fixed_depthwise_outputs(depthwise_requantize, fixed_conv2d_outputs) + + # Ensure number of channels is divisible by two + if len(fixed_dw_outputs) % 2 > 0: + fixed_dw_outputs.popitem() + + if not fixed_dw_outputs: + return None + + unneeded_channels = tuple(fixed_dw_outputs.keys()) + new_top_conv2d = _excise_conv2d_channels(unneeded_channels, top_conv2d.args[0], top_requantize) + new_dw_conv2d = _excise_conv2d_channels( + unneeded_channels, new_top_conv2d, depthwise_requantize, is_depthwise=True + ) + new_conv, extra_bias = _fold_into_conv_bias(fixed_dw_outputs, current_conv, new_dw_conv2d) + + new_bias = inputs[1].data.numpy() + extra_bias + new_op = relay.nn.bias_add(new_conv, relay.Constant(nd.array(new_bias)), **attrs) + return new_op + + +def _densify_conv_pool_dense_pattern(attrs, inputs): + """Rewrites a regular conv -> pool -> dense pattern to excise empty out channels from the conv. + + Should be called as part of legalization (before dtypes and layouts are rewritten) and with the + BIAS ADD operator's `attrs` and `inputs` (the one we'll use to "fold in" our constants). The + average pool operator must reduce the height and width dimensions to 1x1. + """ + first_reshape = inputs[0].args[0] + top_requantize = first_reshape.args[0].args[0].args[0].args[0].args[0] + top_conv2d = top_requantize.args[0].args[0] + + fixed_conv2d_outputs = _compute_fixed_conv2d_outputs(top_requantize) + + # Ensure number of channels is divisible by two + if len(fixed_conv2d_outputs) % 2 > 0: + fixed_dw_outputs.popitem() + + if not fixed_conv2d_outputs: + return None + + unneeded_channels = tuple(fixed_conv2d_outputs.keys()) + new_top_conv2d = _excise_conv2d_channels(unneeded_channels, top_conv2d.args[0], top_requantize) + new_avg_pool = _excise_avg_pool_channels(unneeded_channels, new_top_conv2d, first_reshape) + new_conv, extra_bias = _fold_into_dense_bias(fixed_conv2d_outputs, inputs[0], new_avg_pool) + + new_bias = inputs[1].data.numpy() + extra_bias + new_op = relay.nn.bias_add(new_conv, relay.Constant(nd.array(new_bias)), **attrs) + return new_op + + +@bias_add_legalize.register(["arm_cpu"]) +def legalize_bias_add(attrs, inputs, _tinfos): + """Remove empty convolution channels when possible, and "fold" them into the bias add. + + TODO @guberti: these rewrites are always beneficial and will improve performance cross-platform, + should we enable them for all platforms, not just arm_cpu? + """ + + if prev_ops_match( + inputs[0], + ( + "qnn.conv2d", + "qnn.requantize", + "nn.bias_add", + "qnn.conv2d", + "qnn.requantize", + "nn.bias_add", + "qnn.conv2d", + ), + ): + current_conv = inputs[0] + depthwise_conv2d = current_conv.args[0].args[0].args[0] + top_conv2d = depthwise_conv2d.args[0].args[0].args[0] + if ( + not any(get_const_tuple(current_conv.attrs.padding)) + and current_conv.attrs.groups == 1 + and depthwise_conv2d.attrs.groups > 1 + and top_conv2d.attrs.groups == 1 + ): + return _densify_conv_depthwise_conv_pattern(attrs, inputs) + + if prev_ops_match( + inputs[0], + ( + "qnn.dense", + "reshape", + "reshape", + "cast", + "nn.avg_pool2d", + "cast", + "qnn.requantize", + "nn.bias_add", + "qnn.conv2d", + ), + ): + avg_pool = inputs[0].args[0].args[0].args[0].args[0] + top_requantize = avg_pool.args[0].args[0] + top_conv2d = top_requantize.args[0].args[0] + if top_conv2d.attrs.groups == 1: + return _densify_conv_pool_dense_pattern(attrs, inputs) + + return None diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py index 5702be2e1a33..e60314b82757 100644 --- a/python/tvm/topi/hexagon/qnn/nn.py +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -874,7 +874,7 @@ def qnn_dense( # Add bias if bias is not None: - out = te.compute(out.shape, lambda n, c: out[n, c] + bias[c]) + out = te.compute(out.shape, lambda n, c: out[n, c] + bias[0, c]) # Requantize output of dense # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) diff --git a/python/tvm/topi/nn/qnn.py b/python/tvm/topi/nn/qnn.py index 9aaa452a7392..98bbb7ebe50f 100644 --- a/python/tvm/topi/nn/qnn.py +++ b/python/tvm/topi/nn/qnn.py @@ -191,8 +191,8 @@ def _dispatch_sim_dequantize(value): @tvm.target.generic_func -def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type): - """Change requantize layout. +def qnn_conv2d_alter_layout(_attrs, _inputs, _tinfos, _out_type): + """Change qnn.conv2d layout. Parameters ---------- @@ -213,7 +213,27 @@ def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type): @tvm.target.generic_func -def qnn_add_alter_layout(_attrs, _inputs, _tinfos, _out_type): +def bias_add_legalize(_attrs, _inputs, _tinfos): + """Legalize bias_add layout. + + Bias add is not a QNN-specific function, but this generic exists so that empty channels can + be excised from quantized conv2d operators and folded into bias adds. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + + """ + return None + + +@tvm.target.generic_func +def add_alter_layout(_attrs, _inputs, _tinfos, _out_type): """Change add layout. Add is not a QNN-specific function, but this generic exists so that bias add operations can be @@ -239,9 +259,8 @@ def qnn_add_alter_layout(_attrs, _inputs, _tinfos, _out_type): @tvm.target.generic_func -def qnn_conv2d_alter_layout(_attrs, _inputs, _tinfos, _out_type): - """Change qnn.conv2D layout. - Not to change by default +def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type): + """Change requantize layout. Parameters ---------- @@ -253,6 +272,10 @@ def qnn_conv2d_alter_layout(_attrs, _inputs, _tinfos, _out_type): Input shape and dtype out_type: type The output type + + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level. """ return None diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 4ff8a59b349e..f009bda9cd98 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -274,6 +274,7 @@ Array GetPassPrefix(bool is_homogeneous, bool is_vm) { pass_seqs.push_back(transform::InferType()); } pass_seqs.push_back(transform::AlterOpLayout()); + pass_seqs.push_back(transform::SimplifyExprPostAlterOp()); } // Fast math optimizations. diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index c64957b5b62a..a9b7390c0374 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -969,6 +969,19 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { return RewritePatterns(composer.MakeCallbacks(), expr, mod); } +Expr SimplifyExprPostAlterOp(const Expr& expr, const IRModule& mod) { + // stripped-down version of AlterOp that cleans up some patterns + // often left by the AlterOpLayout pass. + DFPatternRewriteComposer composer; + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + return RewritePatterns(composer.MakeCallbacks(), expr, mod); +} + namespace transform { Pass SimplifyExpr() { @@ -979,7 +992,17 @@ Pass SimplifyExpr() { return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); } +Pass SimplifyExprPostAlterOp() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyExprPostAlterOp(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "SimplifyExprPostAlterOp", {"InferType"}); +} + TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr); +TVM_REGISTER_GLOBAL("relay._transform.SimplifyExprPostAlterOp") + .set_body_typed(SimplifyExprPostAlterOp); } // namespace transform diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 6bf14424bf38..514ec8395821 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -23,6 +23,7 @@ #include "codegen_c.h" #include +#include #include #include @@ -632,7 +633,8 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, } void CodeGenC::VisitStmt_(const AllocateConstNode* op) { - std::string symbol_name = op->buffer_var->name_hint; + std::string symbol_name = AllocVarID(op->buffer_var.get()); + int64_t num_elements = 1; const auto& data = op->data.value(); diff --git a/tests/python/relay/qnn/test_clip_legalization.py b/tests/python/relay/qnn/test_clip_legalization.py new file mode 100644 index 000000000000..d1a9c5901a2d --- /dev/null +++ b/tests/python/relay/qnn/test_clip_legalization.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test that do-nothing requantize -> clip operators are removed during legalization.""" + +import numpy as np +import pytest + +import tvm +from tvm import nd, relay +from tvm.relay import transform + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = tvm.IRModule.from_expr(expr) + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def tvm_const(obj): + return relay.Constant(nd.array(obj)) + + +@pytest.mark.parametrize( + "dtype,min_val,max_val,is_redundant", + [ + ("int8", -128, 127, True), + ("int8", -127, 127, False), + ("int16", -128, 127, False), + ("int32", -2147483648, 2147483647, True), + ], +) +def test_removes_redundant_requantize_clip_ops(dtype, min_val, max_val, is_redundant): + """Test that qnn.requantize -> clip sequences are removed during legalization if the bounds of + the clip operator match the min and max values of the data type.""" + + input_var = relay.var("input", shape=(1, 3, 3, 4), dtype="int32") + out = relay.qnn.op.requantize( + input_var, + tvm_const(np.float32(1.0)), + tvm_const(np.int32(0)), + tvm_const(np.float32(1.0)), + tvm_const(np.int32(-128)), + axis=3, + out_dtype=dtype, + ) + out = relay.clip(out, a_min=min_val, a_max=max_val) + func = relay.Function([input_var], out) + unmodified = run_opt_pass(func, transform.InferType()) + legalized = run_opt_pass(func, transform.Legalize()) + + # Check that the clip op was removed if and only if `is_redundant` is True. + if is_redundant: + assert legalized.body.op.name == "qnn.requantize" + assert not tvm.ir.structural_equal(unmodified, legalized) + else: + assert legalized.body.op.name == "clip" + tvm.ir.assert_structural_equal(unmodified, legalized) + + +def test_ignores_standalone_clip_ops(): + """The legalization pass should only affect qnn.requantize -> clip sequences, and should leave + standalone clip operators untouched.""" + + input_var = relay.var("x", shape=(1, 3, 3, 4), dtype="int8") + out = relay.clip(input_var, a_min=-128, a_max=127) + func = relay.Function([input_var], out) + unmodified = run_opt_pass(func, transform.InferType()) + legalized = run_opt_pass(func, transform.Legalize()) + tvm.ir.assert_structural_equal(unmodified, legalized) diff --git a/tests/python/relay/qnn/test_qnn_channel_stripping.py b/tests/python/relay/qnn/test_qnn_channel_stripping.py new file mode 100644 index 000000000000..25197ca84c54 --- /dev/null +++ b/tests/python/relay/qnn/test_qnn_channel_stripping.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test QNN channel stripping legalization pass.""" + +import numpy as np +import tvm +from tvm import nd, relay + +from tvm.relay import transform +from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.testing.aot import generate_ref_data + +from tvm.topi.arm_cpu.qnn_legalize import legalize_bias_add + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = tvm.IRModule.from_expr(expr) + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def execute_relay_func(relay_func, in_data): + ref_module = tvm.IRModule.from_expr(relay_func) + return generate_ref_data(ref_module, {"input": in_data})["output"] + + +def tvm_const(obj): + return relay.Constant(nd.array(obj)) + + +def make_test_conv_depthwise_conv(): + """Generates a convolution -> depthwise_convolution -> convolution pattern that can have + channels stripped. The structure here mirrors MobileNetV1's layers 8-10.""" + + input_var = relay.var("input", shape=(1, 12, 12, 4), dtype="int8") + + kernel_1 = np.array( + [[0, 1, 0, -2], [0, 3, 0, 5], [0, 5, 0, -9], [0, 2, 0, 21]], dtype="int8" + ).reshape((1, 1, 4, 4)) + input_scale_1 = np.float32(0.5) + output_scale_1 = np.array([0.5, 2.0, 0.25, 4.0], dtype="float32") + + out = relay.qnn.op.conv2d( + input_var, + tvm_const(kernel_1), + tvm_const(np.int32(-128)), + tvm_const(np.int32(0)), + tvm_const(input_scale_1), + tvm_const(output_scale_1), + channels=4, + kernel_size=(1, 1), + padding=(0, 0), + data_layout="NHWC", + kernel_layout="HWIO", + ) + + bias_1 = np.array([198, -2, 19, 10], dtype="int32") + out = relay.nn.bias_add( + out, + tvm_const(bias_1), + axis=3, + ) + + input_scale_2 = np.float32(0.25) + out = relay.qnn.op.requantize( + out, + tvm_const(input_scale_1 * output_scale_1), + tvm_const(np.int32(0)), + tvm_const(input_scale_2), + tvm_const(np.int32(-128)), + axis=3, + out_dtype="int8", + ) + # Outputs here will be fixed to {0: 70, 2: -118} + + kernel_2 = np.array( + [ + [0, 6, 4, 2], + [8, 6, -3, -1], + [-2, -5, 3, -8], + [-7, 5, 1, 9], + [-4, -9, -8, -2], + [-1, 4, -5, 3], + [-4, -9, 2, 6], + [9, -6, 0, 5], + [-3, 8, 1, -7], + ], + dtype="int8", + ).reshape((3, 3, 4, 1)) + output_scale_2 = np.array([0.25, 0.125, 2.0, 0.125], dtype="float32") + out = relay.qnn.op.conv2d( + out, + tvm_const(kernel_2), + tvm_const(np.int32(-128)), + tvm_const(np.int32(0)), + tvm_const(input_scale_2), + tvm_const(output_scale_2), + channels=4, + groups=4, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWOI", + ) + + bias_2 = np.array([4582, 4, -12, 15], dtype="int32") + out = relay.nn.bias_add( + out, + tvm_const(bias_2), + axis=3, + ) + + input_scale_3 = np.float32(0.125) + out = relay.qnn.op.requantize( + out, + tvm_const(input_scale_2 * output_scale_2), + tvm_const(np.int32(0)), + tvm_const(input_scale_3), + tvm_const(np.int32(-128)), + axis=3, + out_dtype="int8", + ) + # Outputs here will be fixed to {0: 127, 2: -128} + + kernel_3 = np.array( + [[4, -2, 9, 9], [0, 0, 0, 0], [0, 0, 0, 0], [-1, 1, -1, 1]], dtype="int8" + ).reshape((1, 1, 4, 4)) + output_scale_3 = np.array([0.25, 0.125, 1.0, 0.5], dtype="float32") + + out = relay.qnn.op.conv2d( + out, + tvm_const(kernel_3), + tvm_const(np.int32(-128)), + tvm_const(np.int32(0)), + tvm_const(input_scale_3), + tvm_const(output_scale_3), + channels=4, + kernel_size=(1, 1), + padding=(0, 0), + data_layout="NHWC", + kernel_layout="HWIO", + ) + + bias_3 = np.array([1, -1, 4, 6], dtype="int32") + out = relay.nn.bias_add( + out, + tvm_const(bias_3), + axis=3, + ) + + return relay.Function([input_var], out) + + +def make_test_conv_pool_dense(): + """Generates a convolution -> pool -> dense pattern that can have channels stripped. The + structure here mirrors MobileNetV1's final few layers.""" + + input_var = relay.var("input", shape=(1, 3, 3, 4), dtype="int8") + + kernel = np.array( + [[0, 1, 0, -2], [0, 3, 0, 5], [0, 5, 0, -9], [0, 2, 0, 21]], dtype="int8" + ).reshape((1, 1, 4, 4)) + input_scale = np.float32(0.029626124) + output_scale = np.array([0.5, 2.0, 0.25, 4.0], dtype="float32") + + out = relay.qnn.op.conv2d( + input_var, + tvm_const(kernel), + tvm_const(np.int32(-128)), + tvm_const(np.int32(0)), + tvm_const(input_scale), + tvm_const(output_scale), + channels=4, + kernel_size=(1, 1), + padding=(0, 0), + data_layout="NHWC", + kernel_layout="HWIO", + ) + + bias_1 = np.array([198, -2, 19, 10], dtype="int32") + out = relay.nn.bias_add( + out, + tvm_const(bias_1), + axis=3, + ) + + out = relay.qnn.op.requantize( + out, + tvm_const(input_scale * output_scale), + tvm_const(np.int32(0)), + tvm_const(np.float32(0.015656913)), + tvm_const(np.int32(-128)), + axis=3, + out_dtype="int8", + ) + + out = relay.cast(out, dtype="int32") + out = relay.nn.avg_pool2d( + out, + pool_size=[3, 3], + strides=[3, 3], + layout="NHWC", + ) + + out = relay.cast(out, dtype="int8") + # The channel stripping logic expects two reshape operators + out = relay.reshape(out, newshape=[-1, 4]) + out = relay.reshape(out, newshape=[-1, 4]) + + dense_weights = np.array([[15, -2, -3, 11], [12, -10, 13, -10]], dtype="int8") + out = relay.qnn.op.dense( + out, + tvm_const(dense_weights), + tvm_const(np.int32(-128)), + tvm_const(np.int32(0)), + tvm_const(np.float32(0.015656913)), + tvm_const(np.float32(0.0047202893)), + units=2, + out_dtype="int32", + ) + + dense_bias = np.array([1463, -1463], dtype="int32") + out = relay.nn.bias_add( + out, + tvm_const(dense_bias), + axis=1, + ) + + return relay.Function([input_var], out) + + +def test_conv_depthwise_conv(): + """Make sure that qnn_legalize.py is able to detect and remove empty output channels from a + convolution -> depthwise convolution -> convolution pattern by folding into a bias_add op.""" + + original = make_test_conv_depthwise_conv() + + with TempOpAttr("nn.bias_add", "FTVMLegalize", legalize_bias_add): + unoptimized = run_opt_pass(original, transform.InferType()) + optimized = run_opt_pass(original, transform.Legalize()) + + # Inputs and outputs should be unmodified by channel stripping + assert unoptimized.checked_type == optimized.checked_type + + # Make sure 2/4 channels were removed by channel stripping + assert tuple(unoptimized.body.args[0].args[0].checked_type.shape) == (1, 12, 12, 4) + assert tuple(optimized.body.args[0].args[0].checked_type.shape) == (1, 12, 12, 2) + + # Make sure optimized and unoptimized versions behave identically + np.random.seed(12402) # Fix seed for repeatability + input_data = np.random.randint(-128, 128, size=(1, 12, 12, 4), dtype="int8") + + unoptimized_output = execute_relay_func(unoptimized, np.copy(input_data)) + optimized_output = execute_relay_func(optimized, np.copy(input_data)) + np.testing.assert_array_equal(unoptimized_output, optimized_output) + + +def test_conv_pool_dense(): + """Make sure that qnn_legalize.py is able to detect and remove empty output channels from a + convolution -> avg_pool2d -> dense pattern by folding them into a bias_add op.""" + + original = make_test_conv_pool_dense() + + with TempOpAttr("nn.bias_add", "FTVMLegalize", legalize_bias_add): + unoptimized = run_opt_pass(original, transform.InferType()) + optimized = run_opt_pass(original, transform.Legalize()) + + # Inputs and outputs should be unmodified by channel stripping + assert unoptimized.checked_type == optimized.checked_type + + # Make sure 2/4 channels were removed by channel stripping + assert tuple(unoptimized.body.args[0].args[0].checked_type.shape) == (1, 4) + assert tuple(optimized.body.args[0].args[0].checked_type.shape) == (1, 2) + + # Make sure optimized and unoptimized versions behave identically + np.random.seed(12402) # Fix seed for repeatability + input_data = np.random.randint(-128, 128, size=(1, 3, 3, 4), dtype="int8") + + unoptimized_output = execute_relay_func(unoptimized, np.copy(input_data)) + optimized_output = execute_relay_func(optimized, np.copy(input_data)) + np.testing.assert_array_equal(unoptimized_output, optimized_output) diff --git a/tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py b/tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py index 573231f9632c..8af49ca08f7f 100644 --- a/tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py +++ b/tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py @@ -43,6 +43,7 @@ SAMPLE_URL = ( "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/elephant-299.jpg" ) +MODEL_NUM_CONVS = 27 @pytest.fixture(scope="module") @@ -95,6 +96,54 @@ def _get_mobilenet_v1_layer_attributes(layer_num): return ((1, 1, 1, 1), (1, 1), True) +@pytest.mark.parametrize("layer", range(2, 27, 2)) +@tvm.testing.requires_package("tensorflow") +def test_empty_channel_detection(interpreter, layer): + """Some models (mainly MobileNetV1) have kernels with many output channels full entirely of + zeroes. The VWW model is one of these. This test confirms that the outputs of these channels, + as computed by TensorFlow, are indeed not dependent upon the input values. + """ + + _, kernel, bias, output = _load_tflite_layer(interpreter, layer) + kernel_data, _ = kernel + bias_data, bias_quant = bias + output_data, output_quant = output + is_depthwise = _get_mobilenet_v1_layer_attributes(layer)[2] + assert not is_depthwise + assert kernel_data.shape[1] == kernel_data.shape[2] == 1 + + out_channels = kernel_data.shape[3] + fixed_channels = {} + + out_zero_point = output_quant["zero_points"][0] + assert out_zero_point == -128 + + for i in range(out_channels): + # Skip over output channels with data + if np.any(kernel_data[i, 0, 0, :]): + continue + + scale = bias_quant["scales"][i] / output_quant["scales"][0] + channel_constant = round(bias_data[i] * scale + out_zero_point) + clipped = min(127, max(-128, channel_constant)) + + out_channel_values = output_data[0, :, :, i].flatten() + assert all(x == clipped for x in out_channel_values) + fixed_channels[i] = clipped + + # Check if we are on the final convolution and skip the next test if so + if layer + 1 >= MODEL_NUM_CONVS: + return + + # We now need to compute values for the following depthwise layer + depthwise_output = _load_tflite_layer(interpreter, layer + 1)[3][0] + is_depthwise = _get_mobilenet_v1_layer_attributes(layer + 1)[2] + assert is_depthwise + + for i in fixed_channels: + assert np.all(depthwise_output[:, :, :, i] == depthwise_output[0, 0, 0, i]) + + def _get_relu_activation_prefix(layer_num): if layer_num == 0: return "model/activation/Relu;" @@ -242,14 +291,8 @@ def _make_aot_model(params, hyperparams, layouts, is_depthwise=False): data, kernel, bias, output = tensors data_quant, kernel_quant, bias_quant, output_quant = quantizations - dtype, padding, _strides = hyperparams + dtype, _padding, _strides = hyperparams data_layout, _, output_layout = layouts - - if any(padding): - pad_const = int(data_quant["zero_points"][0]) - pad_before = (0, padding[0], padding[1], 0) - pad_after = (0, padding[2], padding[3], 0) - data = np.pad(data, tuple(zip(pad_before, pad_after)), constant_values=pad_const) data_ndarr = _change_layout(data, "NHWC", data_layout, dtype) output_ndarr = _change_layout(output, "NHWC", output_layout, dtype) @@ -284,9 +327,10 @@ def _make_executor(): ) -@pytest.mark.parametrize("layer", range(23)) +@pytest.mark.parametrize("output_layout", ["NHWC", "NCHW"]) +@pytest.mark.parametrize("layer", range(27)) @tvm.testing.requires_corstone300 -def test_qnn_conv2d_mobilenetv1_layer(interpreter, layer): +def test_qnn_conv2d_mobilenetv1_layer(interpreter, layer, output_layout): """Checks microTVM output against TFLite for one MobileNetV1 layer. Loads the input, kernel, bias, expected output, and quantization parameters from the specified @@ -294,7 +338,7 @@ def test_qnn_conv2d_mobilenetv1_layer(interpreter, layer): same structure. The Function is run using microTVM and AOTTestModel, and we verify microTVM's output is the same as the TFLite ground truth. - This function only cross-checks the first 23 layers in MobileNetV1, which are regular and + This function only cross-checks the first 27 layers in MobileNetV1, which are regular and depthwise 2D convolutions (this function only works for 2D convolutions). We do not test the average pool, dense, or softmax layers at the end of the model. @@ -309,6 +353,9 @@ def test_qnn_conv2d_mobilenetv1_layer(interpreter, layer): layer: int The index of the layer to check against TensorFlow's ground truth values. + + output_layout: str + The output_layout for microTVM to use. Does not have to match the TensorFlow layout. """ dtype = "int16" @@ -316,9 +363,9 @@ def test_qnn_conv2d_mobilenetv1_layer(interpreter, layer): padding, strides, is_depthwise = _get_mobilenet_v1_layer_attributes(layer) if is_depthwise: - data_layout, kernel_layout, output_layout = "NCHW", "OIHW", "NHWC" + data_layout, kernel_layout = "NCHW", "OIHW" else: - data_layout, kernel_layout, output_layout = "NHWC", "OHWI", "NHWC" + data_layout, kernel_layout = "NHWC", "OHWI" test_model = _make_aot_model( (tensor, kernel, bias, output), diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py index 46d2797ba394..7bea7577b6bf 100644 --- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py +++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py @@ -308,8 +308,10 @@ def test_1x1x8_convolution_code(): requant_3 = (requant_3 + 1) >> 1; requant_3 = __ssat(requant_3 + -128, 8); - int32_t packed_res_0 = requant_0 + (requant_1 << 16); - int32_t packed_res_1 = requant_2 + (requant_3 << 16); + int packed_res_0; + __asm__ ("pkhbt %0, %1, %2, lsl #16" : "=r" (packed_res_0) : "r" (requant_0), "r" (requant_1)); + int packed_res_1; + __asm__ ("pkhbt %0, %1, %2, lsl #16" : "=r" (packed_res_1) : "r" (requant_2), "r" (requant_3)); output[0] = packed_res_0; output[1] = packed_res_1; return 0;