From d95decabca5871972161a8eda586a496c92758c3 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 20 Nov 2023 16:48:20 +0000 Subject: [PATCH 1/6] [TOPI][Relay] Add conv2d NHWC fp32 hybrid schedule for `arm_cpu` Implemented an `arm_cpu` conv2d NHWC schedule for fp32 using a hybrid GeMM approach, effectively breaking down the matrix multiplication into a macro-kernel (partitioning into fixed-sized, tile-level subproblems) and a micro-kernel (independently dealing with each subproblem). After the im2col transformation, the input matrix is handled natively (not interleaved), while the weights matrix is tiled and interleaved at compile time. The micro-kernel uses 16 registers to accumulate the results of each 4x16 output tile, cycling through the operands needed to compute them (from the input and weight matrices) in the remaining registers. There are now two ways to transform the weights matrix for conv2d, which are detailed in `convolution.cc`: * for fp32: tile, interleave * for int8: tile, interleave, transpose To maintain naming consistency across both of these implementations (transposed vs not transposed), all mentions of `tile_rows_B` or `tile_cols_B` have been changed to `tile_N` and `tile_K` respectively to denote the tiling size along each axis of the flattened B matrix. As usual, `N = out_channels` and `K = kernel_width * kernel_height * in_channels`. I have also added a new conv2d NHWC fp32 test for both the `conv2d_nhwc_spatial_pack` and `conv2d_NHWC_fp32_hybrid` schedules. --- include/tvm/relay/attrs/nn.h | 10 +- python/tvm/relay/op/nn/_nn.py | 2 +- python/tvm/relay/op/nn/nn.py | 12 +- python/tvm/relay/op/strategy/arm_cpu.py | 17 +- python/tvm/topi/arm_cpu/arm_utils.py | 105 +++--- python/tvm/topi/arm_cpu/conv2d.py | 111 ++++++ python/tvm/topi/arm_cpu/conv2d_alter_op.py | 57 +-- python/tvm/topi/arm_cpu/conv2d_gemm.py | 338 +++++++++++------- python/tvm/topi/arm_cpu/conv2d_int8.py | 96 +---- python/tvm/topi/nn/conv2d.py | 30 +- src/relay/op/nn/convolution.cc | 70 +++- tests/python/integration/test_arm_aprofile.py | 1 + tests/python/topi/test_topi_conv2d_nhwc.py | 33 ++ 13 files changed, 542 insertions(+), 340 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index e58c73dc7354..58edb9df8b97 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -197,12 +197,14 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { - int tile_rows; - int tile_cols; + int tile_N; + int tile_K; TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") { - TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm."); - TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm."); + TVM_ATTR_FIELD(tile_N).describe( + "Tile size across N axis of the weight transformation for ConvGemm. (N = OC)"); + TVM_ATTR_FIELD(tile_K).describe( + "Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)"); } }; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6acaf43fe7d2..a03907f071fd 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -798,7 +798,7 @@ def mirror_pad_func(attrs, inputs, _): @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform") def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv2d_gemm_weight_transform""" - out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols) + out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_N, attrs.tile_K) return [out] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 89953eb1dfb3..8cb66ecaa9a2 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2741,7 +2741,7 @@ def contrib_conv2d_winograd_weight_transform(weight, tile_size): return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) -def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols): +def contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K): r"""Weight Transformation part for 2D convolution with gemm algorithm. We separate this as a single op to enable pre-compute for inference. @@ -2751,17 +2751,17 @@ def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols): ---------- weights : tvm.relay.Expr The weight expressions. - tile_rows: int - Tile rows of the weight transformation for ConvGemm. - tile_cols: int - Tile columns of the weight transformation for ConvGemm. + tile_N: int + Tile size across N axis of the weight transformation for ConvGemm. (N = OC) + tile_K: int + Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC) Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols) + return _make.contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K) def contrib_conv3d_winograd_weight_transform(weight, tile_size): diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index a23ccf8f6932..36afe6957ff0 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -242,7 +242,13 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) - if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]): + if is_aarch64 and data.dtype not in ["int8", "uint8"]: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid), + name="conv2d_NHWC_fp32_hybrid.arm_cpu", + ) + else: # TODO(@giuseros) # This strategy errors out for quantized data types when tuning. # Let's use this only for non-aarch64 or non-quantized cases @@ -517,9 +523,12 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) else: - raise RuntimeError( - f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}" - f"with datatype {data.dtype}" + strategy.add_implementation( + wrap_compute_conv2d_gemm( + topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid_without_transform + ), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid_without_transform), + name="conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu", ) return strategy diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 9c519cbb936c..50f570f17f47 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -20,9 +20,9 @@ from tvm.target import Target -def get_tiling_B_interleaved_t(interleave_A): +def get_tiling_B_transformed(interleave_A, in_dtype): """Compute the tiling information for matrix B', where B' - is the transposed and interleaved version of matrix B in C=A*B. + is the tiled, interleaved (and transposed) version of matrix B in C=A*B. The tiling information is chosen to maximize register usage during the tile computation. @@ -36,59 +36,68 @@ def get_tiling_B_interleaved_t(interleave_A): Parameters ---------- - interleave_A: bool - determines if A is expected to be interleaved + interleave_A : bool + determines if A is expected to be interleaved + in_dtype : str + input datatype + Returns ---------- - tile_rows_B: the output tile rows of B' - tile_cols_B: the output tile columns of B' + tile_N: the output tile size of B' on N axis (N = OC) + tile_K: the output tile size of B' on K axis (K = KW * KH * IC) """ target = Target.current(allow_none=False) - - if target.features.has_matmul_i8: - # If smmla/ummla is available, A must be interleaved. - # Each load from B' will contain 8 elements - # and we are loading 12 rows of B' (i.e., 12 columns of B) - tile_rows_B = 12 - tile_cols_B = 8 - elif target.features.has_dotprod: - # The number of tile rows of B' vary depending on the - # strategy: - # * If we are interleaving A, then we select 12 columns from B'(i.e., - # 12 rows from B). - # * If we are not interleaving A, then we select 16 columns from B'(i.e., - # 16 rows from B). - tile_rows_B = 12 if interleave_A else 16 - - # Dot product instruction groups 2 (u)int16x8 vectors in - # groups of 4 and compute the dot product among those groups - # This means that the number of columns in a tile of B' (i.e., the - # rows of the original matrix B) need to be 4. - tile_cols_B = 4 + if in_dtype in ["int8", "uint8"]: + if target.features.has_matmul_i8: + # If smmla/ummla is available, A must be interleaved. + # Each load from B' will contain 8 elements + # and we are loading 12 rows of B' (i.e., 12 columns of B) + tile_N = 12 + tile_K = 8 + elif target.features.has_dotprod: + # The number of tile rows of B' vary depending on the + # strategy: + # * If we are interleaving A, then we select 12 columns from B'(i.e., + # 12 rows from B). + # * If we are not interleaving A, then we select 16 columns from B'(i.e., + # 16 rows from B). + tile_N = 12 if interleave_A else 16 + + # Dot product instruction groups 2 (u)int16x8 vectors in + # groups of 4 and compute the dot product among those groups + # This means that the number of columns in a tile of B' (i.e., the + # rows of the original matrix B) need to be 4. + tile_K = 4 + else: + # If no acceleration is available, A must be interleaved. In this case + # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements + tile_N = 4 + tile_K = 16 else: - # If no acceleration is available, A must be interleaved. In this case - # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements - tile_rows_B = 4 - tile_cols_B = 16 + # In non-quantized cases, A is not interleaved. + # Each load from B' contains 16 elements (i.e. 16 columns from B) + # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B) + tile_N = 16 + tile_K = 4 - return tile_rows_B, tile_cols_B + return tile_N, tile_K -def get_conv2d_weights_padding(N, K, tile_rows, tile_cols): +def get_conv2d_weights_padding(N, K, tile_N, tile_K): """Compute the necessary padding for matrix B', where B' - is the transposed and interleaved version of matrix B in C=A*B. + is the transformed version of matrix B in C=A*B. Parameters ---------- N : int - Number of rows in B' = OC + Number of columns in B = OC K : int - Number of columns in B' = KW * KH * IC - tile_rows : int - tile rows of B' - tile_cols : int - tile columns of B' + Number of rows in B = KW * KH * IC + tile_N : int + tile size of B' on N axis + tile_K : int + tile size of B' on K axis Returns ---------- @@ -98,16 +107,16 @@ def get_conv2d_weights_padding(N, K, tile_rows, tile_cols): pad_N = 0 pad_K = 0 - if N % tile_rows != 0: - pad_N = tile_rows - (N % tile_rows) + if N % tile_N != 0: + pad_N = tile_N - (N % tile_N) - # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such - # that the columns is multiple of 4 - column_multiplier = 4 - tile_cols_multiplied = tile_cols * column_multiplier - K_misalignment = K % tile_cols_multiplied + # Tensorize will later make use of 4 tiles at once across the K axis so make sure we pad such + # that K is multiple of 4 + K_multiplier = 4 + tile_K_multiplied = tile_K * K_multiplier + K_misalignment = K % tile_K_multiplied if K_misalignment != 0: - pad_K = tile_cols_multiplied - K_misalignment + pad_K = tile_K_multiplied - K_misalignment return pad_N, pad_K diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index a478818084d5..67b2ed8d86a6 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -27,12 +27,18 @@ from .. import nn from ..nn.utils import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices +from .arm_utils import get_tiling_B_transformed from .conv2d_spatial_pack import ( conv2d_spatial_pack_nchw, conv2d_spatial_pack_nhwc, schedule_conv2d_spatial_pack_nchw, schedule_conv2d_spatial_pack_nhwc, ) +from .conv2d_gemm import ( + compute_conv2d_gemm_without_weight_transform, + schedule_conv2d_gemm_interleaved, + schedule_conv2d_gemm_native, +) from .mprofile.dsp.conv2d import conv2d_nhwc_dsp_compute, conv2d_nhwc_dsp_schedule @@ -509,3 +515,108 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): def schedule_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return conv2d_nhwc_dsp_schedule(cfg, outs) + + +def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A): + N, IH, IW, IC = get_const_tuple(data.shape) + KH, KW, _, OC = get_const_tuple(kernel.shape) + tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype) + + kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K) + return compute_conv2d_gemm_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A + ) + + +def compute_conv2d_NHWC_without_transform( + cfg, + data, + B, + strides, + padding, + dilation, + out_dtype, + kernel_size=None, + output_channels=None, + interleave_A=False, +): + """Compute conv2d NHWC without weight transform""" + return compute_conv2d_gemm_without_weight_transform( + cfg, + data, + B, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, + interleave_A, + ) + + +def schedule_conv2d_NHWC(cfg, outs, interleave_A): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + # Vectorize the output and then inline all the rest + out = outs[0] + n, h, w, c = out.op.axis + n_h_fused = s[out].fuse(n, h) + _, inner = s[out].split(c, 4) + s[out].vectorize(inner) + s[out].parallel(n_h_fused) + + def _callback(op): + """Traverse operators from computation graph""" + if op.name == "conv2d_gemm_output": + conv_out = op.output(0) + if interleave_A: + schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out) + else: + schedule_conv2d_gemm_native(cfg, s, conv_out, out) + if out != conv_out: + s[conv_out].compute_at(s[out], inner) + else: + C = conv_out.op.input_tensors[0] + if interleave_A: + s[C].compute_at(s[out], inner) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_NHWC_fp32_hybrid.arm_cpu") +def compute_conv2d_NHWC_fp32_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_fp32_hybrid""" + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) + + +@autotvm.register_topi_compute("conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu") +def compute_conv2d_NHWC_fp32_hybrid_without_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels +): + """Interface for hybrid compute_conv2d_NHWC_fp32_hybrid_without_transform""" + return compute_conv2d_NHWC_without_transform( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, + False, + ) + + +@autotvm.register_topi_schedule("conv2d_NHWC_fp32_hybrid.arm_cpu") +def schedule_conv2d_NHWC_fp32_hybrid(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_fp32_hybrid""" + return schedule_conv2d_NHWC(cfg, outs, False) + + +@autotvm.register_topi_schedule("conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu") +def schedule_conv2d_NHWC_fp32_hybrid_without_transform(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_fp32_hybrid""" + return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index 1c30e1f3b650..8984d1aafa69 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -32,15 +32,15 @@ from ..x86.conv2d import _get_default_config as _get_x86_default_config from ..x86.conv2d_int8 import _get_default_config_int8 from .conv2d_int8 import is_int8_hw_support -from .arm_utils import get_tiling_B_interleaved_t, get_conv2d_weights_padding +from .arm_utils import get_tiling_B_transformed, get_conv2d_weights_padding from ..generic.conv2d import conv2d_alter_int8_common from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word logger = logging.getLogger("topi") -def interleave_transpose_weights(inputs, data, kernel, interleave_A): - """Transform the weight matrix by reshaping, interleaving and transposing it +def transform_weights(inputs, data, kernel, interleave_A): + """Transform the weight matrix by tiling, interleaving (and transposing it) Parameters ---------- @@ -59,29 +59,28 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A): new_kernel_expr : tvm.relay.Expr The relay expression of the weights """ - assert ( - data.dtype == "int8" - and kernel.dtype == "int8" - or data.dtype == "uint8" - and kernel.dtype == "uint8" - ) KH, KW, IC, OC = get_const_tuple(kernel.shape) K = KH * KW * IC N = OC - # Get tiling information for the interleaved transposed version of B - tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A) - pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_rows_B, tile_cols_B) + # Get tiling information for the transformed version of B + tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype) + pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_N, tile_K) N_padded = N + pad_N K_padded = K + pad_K - new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform( - inputs[1], tile_rows_B, tile_cols_B - ) - new_kernel = te.placeholder( - (N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B), kernel.dtype - ) + new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_N, tile_K) + if data.dtype in ["int8", "uint8"]: + new_kernel = te.placeholder( + (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K), + kernel.dtype, + ) + else: + new_kernel = te.placeholder( + (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N), + kernel.dtype, + ) return new_kernel, new_kernel_expr @@ -149,6 +148,20 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)), **new_attrs ) + if topi_tmpl == "conv2d_NHWC_fp32_hybrid.arm_cpu": + assert data_layout == "NHWC" and kernel_layout == "HWIO" + KH, KW, _, OC = get_const_tuple(kernel.shape) + new_workload_name = "conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu" + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) + new_workload = autotvm.task.args_to_workload( + [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], + new_workload_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_gemm_without_weight_transform( + inputs[0], new_kernel_expr, **new_attrs + ) + # Only microTVM does layout alteration for NHWC layout with real data types if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]: return None @@ -431,9 +444,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu" - new_kernel, new_kernel_expr = interleave_transpose_weights( - inputs, data, kernel, interleave_A=True - ) + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=True) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], new_workload_name, @@ -447,9 +458,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu" - new_kernel, new_kernel_expr = interleave_transpose_weights( - inputs, data, kernel, interleave_A=False - ) + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], new_workload_name, diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 90e02c5ab043..649b6d99aef5 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -70,6 +70,7 @@ def compute_conv2d_gemm_without_weight_transform( """Compute conv2d by transforming the input, executing GEMM and transforming the output back""" batches, IH, IW, IC = get_const_tuple(data.shape) + in_dtype = data.dtype KH, KW = get_const_tuple(kernel_size) OC = get_const_int(output_channels) @@ -90,7 +91,7 @@ def compute_conv2d_gemm_without_weight_transform( OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 - if pad_top or pad_left: + if pad_top or pad_left or pad_down or pad_right: data_pad = nn.pad( data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad" ) @@ -119,8 +120,7 @@ def compute_conv2d_gemm_without_weight_transform( # Pad if necessary N_transformed = B_interleaved_t.shape[0] - tile_rows_B = B_interleaved_t.shape[2] - tile_cols_B = B_interleaved_t.shape[3] + tile_N = B_interleaved_t.shape[2] if in_dtype in ["int8", "uint8"] else B_interleaved_t.shape[3] # Select the tiling strategy for A. # The tiling information is chosen to maximize register usage during @@ -134,34 +134,41 @@ def compute_conv2d_gemm_without_weight_transform( # In order to have more information # target = Target.current(allow_none=False) - if target.features.has_matmul_i8: - # If smmla/ummla is enabled, we are loading 8 rows from A. Each row - # will contain 8 elements - tile_rows_A = 8 - tile_cols_A = 8 - elif target.features.has_dotprod and interleave_A: - # If dot product has been enabled, and we are interleaving A - # tile size should be 8x4 - tile_rows_A = 8 - tile_cols_A = 4 + if in_dtype in ["int8", "uint8"]: + if target.features.has_matmul_i8: + # If smmla/ummla is enabled, we are loading 8 rows from A. Each row + # will contain 8 elements + tile_M = 8 + tile_K = 8 + elif target.features.has_dotprod and interleave_A: + # If dot product has been enabled, and we are interleaving A + # tile size should be 8x4 + tile_M = 8 + tile_K = 4 + else: + # If either there is no dot product or if we are using a native strategy + # tile size should be 4x16 + tile_M = 4 + tile_K = 16 else: - # If either there is no dot product or if we are using a native strategy - # tile size should be 4x16 - tile_rows_A = 4 - tile_cols_A = 16 + # In non-quantized cases, A is not interleaved. + # We are loading 4 rows from A. + # Each row will contain 4 elements, along the dimension of reduction + tile_M = 4 + tile_K = 4 pad_M = 0 pad_K = 0 - if M % tile_rows_A != 0: - pad_M = tile_rows_A - (M % tile_rows_A) + if M % tile_M != 0: + pad_M = tile_M - (M % tile_M) - if K % tile_cols_A != 0: - pad_K = tile_cols_A - (K % tile_cols_A) + if K % tile_K != 0: + pad_K = tile_K - (K % tile_K) M_padded = M + pad_M K_padded = K + pad_K - N_padded = N_transformed * tile_rows_B + N_padded = N_transformed * tile_N pad_before = (0, 0, 0) pad_after = (0, pad_M, pad_K) @@ -174,131 +181,158 @@ def compute_conv2d_gemm_without_weight_transform( idxm = tvm.tir.indexmod k = te.reduce_axis((0, K_padded), "k") - if interleave_A: - # Configuration space - configure_knobs(cfg, M_padded, K_padded, target) + if in_dtype in ["int8", "uint8"]: + if interleave_A: + # Configuration space + configure_knobs(cfg, M_padded, K_padded, target) - # Pack the input data - A_interleaved = te.compute( - (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A), - lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y], - name="A_interleaved", - ) - target = Target.current(allow_none=False) - if target.features.has_matmul_i8: - # Execute GEMM. In the case of mmla, we need to enforce the tiling - # from the compute. This is because mmla is doing a tiled computation - # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles - # generated by mmla. In theory we could make the tile 2x2 and - # fuse and split during scheduling, but this would not work - # because of possible padding - C_interleaved = te.compute( + # Pack the input data + A_interleaved = te.compute( ( batches, - M_padded // tile_rows_A, - N_transformed, - tile_rows_A // 2, - tile_rows_B // 2, - 2, - 2, + M_padded // tile_M, + K_padded // tile_K, + tile_M, + tile_K, ), - lambda b, x, y, w, z, s, t: te.sum( - A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype( - "int32" - ) - * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype( - "int32" - ), - axis=k, - ), - name="C_interleaved", - ) - # Ensure the padding needed for tensorize does not get removed during tir passes - # by adding a dummy reference to the specific padded area of the result - zero = ( - tvm.tir.const(1, C_interleaved.dtype) - * C_interleaved[ - batches - 1, - M // tile_rows_A, - N_transformed - 1, - idxm(M, tile_rows_A) // 2, - tile_rows_B // 2 - 1, - 1, - 1, - ] - - tvm.tir.const(1, C_interleaved.dtype) - * C_interleaved[ - batches - 1, - M // tile_rows_A, - N_transformed - 1, - idxm(M, tile_rows_A) // 2, - tile_rows_B // 2 - 1, - 1, - 1, - ] + lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K * y], + name="A_interleaved", ) - # Unpack the result - C = te.compute( - (batches, M, N), - lambda b, x, y: ( - C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A) // 2, - idxm(y, tile_rows_B) // 2, - idxm(idxm(x, tile_rows_A), 2), - idxm(idxm(y, tile_rows_B), 2), + target = Target.current(allow_none=False) + if target.features.has_matmul_i8: + # Execute GEMM. In the case of mmla, we need to enforce the tiling + # from the compute. This is because mmla is doing a tiled computation + # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles + # generated by mmla. In theory we could make the tile 2x2 and + # fuse and split during scheduling, but this would not work + # because of possible padding + C_interleaved = te.compute( + ( + batches, + M_padded // tile_M, + N_transformed, + tile_M // 2, + tile_N // 2, + 2, + 2, + ), + lambda b, x, y, w, z, s, t: te.sum( + A_interleaved[b, x, k // tile_K, 2 * w + s, idxm(k, tile_K)].astype("int32") + * B_interleaved_t[y, k // tile_K, 2 * z + t, idxm(k, tile_K)].astype( + "int32" + ), + axis=k, + ), + name="C_interleaved", + ) + # Ensure the padding needed for tensorize does not get removed during tir passes + # by adding a dummy reference to the specific padded area of the result + zero = ( + tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_M, + N_transformed - 1, + idxm(M, tile_M) // 2, + tile_N // 2 - 1, + 1, + 1, ] - + zero - ).astype(out_dtype), - name="C", - ) + - tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_M, + N_transformed - 1, + idxm(M, tile_M) // 2, + tile_N // 2 - 1, + 1, + 1, + ] + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: ( + C_interleaved[ + b, + x // tile_M, + y // tile_N, + idxm(x, tile_M) // 2, + idxm(y, tile_N) // 2, + idxm(idxm(x, tile_M), 2), + idxm(idxm(y, tile_N), 2), + ] + + zero + ).astype(out_dtype), + name="C", + ) + else: + # Execute GEMM + C_interleaved = te.compute( + (batches, M_padded // tile_M, N_transformed, tile_M, tile_N), + lambda b, x, y, w, z: te.sum( + A_interleaved[b, x, k // tile_K, w, idxm(k, tile_K)].astype("int32") + * B_interleaved_t[y, k // tile_K, z, idxm(k, tile_K)].astype("int32"), + axis=k, + ), + name="C_interleaved", + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: C_interleaved[ + b, + x // tile_M, + y // tile_N, + idxm(x, tile_M), + idxm(y, tile_N), + ].astype(out_dtype), + name="C", + ) + zero = tvm.tir.const(0) else: - # Execute GEMM - C_interleaved = te.compute( - (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B), - lambda b, x, y, w, z: te.sum( - A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32") - * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"), + # No need to pack/unpack, execute GEMM directly + C = te.compute( + (batches, M_padded, N_padded), + lambda b, x, y: te.sum( + A[b, x, k].astype("int32") + * B_interleaved_t[ + y // tile_N, + k // tile_K, + idxm(y, tile_N), + idxm(k, tile_K), + ].astype("int32"), axis=k, ), - name="C_interleaved", - ) - # Unpack the result - C = te.compute( - (batches, M, N), - lambda b, x, y: C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A), - idxm(y, tile_rows_B), - ].astype(out_dtype), name="C", ) - zero = tvm.tir.const(0) + + # We need to ensure that infer bound pass does not remove the padding + # which is necessary for the tensorizations to work. So we need to + # add a dummy reference to the padding area of the result + zero = ( + tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] + - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] + ) else: - # No need to pack/unpack, execute GEMM directly + # Configuration space + configure_knobs(cfg, M_padded, K_padded, target) + C = te.compute( (batches, M_padded, N_padded), lambda b, x, y: te.sum( - A[b, x, k].astype("int32") + A[b, x, k].astype("float32") * B_interleaved_t[ - y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B) - ].astype("int32"), + y // tile_N, + k // tile_K, + idxm(k, tile_K), + idxm(y, tile_N), + ].astype("float32"), axis=k, ), name="C", ) - - # We need to ensure that infer bound pass does not remove the padding - # which is necessary for the tensorizations to work. So we need to - # add a dummy reference to the padding area of the result - zero = ( - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - ) + zero = tvm.tir.const(0) # Reshape the result into a convolution output out_shape = (batches, OH, OW, OC) @@ -417,14 +451,35 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): # Computation b, x, y = C.op.axis (k,) = C.op.reduce_axis - k_outer, k_inner = s[C].split(k, 16) - y_tile_size = 16 - x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) - s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner) - gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1) - s[C].unroll(x_inner) - s[C].tensorize(y_inner, gemm_acc) - s[C].parallel(x_outer) + + if in_type in ["int8", "uint8"]: + k_outer, k_inner = s[C].split(k, 16) + y_tile_size = 16 + x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) + s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner) + gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1) + s[C].unroll(x_inner) + s[C].tensorize(y_inner, gemm_acc) + s[C].parallel(x_outer) + else: + k_outer, k_inner = s[C].split(k, 4) + y_tile_size = 16 + x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) + y_inner_outer, y_inner_inner = s[C].split(y_inner, 4) + b_x_outer_fused = s[C].fuse(b, x_outer) + s[C].parallel(b_x_outer_fused) + s[C].reorder( + b_x_outer_fused, + y_outer, + k_outer, + k_inner, + y_inner_outer, + x_inner, + y_inner_inner, + ) + s[C].unroll(y_inner_outer) + s[C].unroll(x_inner) + s[C].vectorize(y_inner_inner) # Input transform if A.op.name == "A_padded_K" or A.op.name == "A_padded_M": @@ -450,7 +505,11 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): split_factor = 16 n_size = data_im2col.shape[2] - if n_size % split_factor != 0: + if n_size % 16 == 0: + split_factor = 16 + elif n_size % 8 == 0: + split_factor = 8 + else: # Split by kernel area (KH * KW) to ensure proper vectorization ic = data_im2col.op.input_tensors[0].shape[3] split_factor = n_size // ic @@ -466,6 +525,13 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): else: s[data_im2col].compute_at(s[C], x_inner) + A_pad = data_im2col.op.input_tensors[0] + if A_pad.op.name == "data_pad": + n, h, w, c = A_pad.op.axis + n_h_fused = s[A_pad].fuse(n, h) + s[A_pad].parallel(n_h_fused) + s[A_pad].vectorize(c) + # Output transform if out != final_out: n, h, w, c = out.op.axis diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 6b2c9527a400..721385c189e7 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -25,12 +25,7 @@ from ..x86.conv2d_int8 import _pack_data from ..nn.utils import get_pad_tuple from .tensor_intrin import dot_int8_int8_int32_neon_82, dot_int8_int8_int32_neon -from .conv2d_gemm import ( - compute_conv2d_gemm_without_weight_transform, - schedule_conv2d_gemm_interleaved, - schedule_conv2d_gemm_native, -) -from .arm_utils import get_tiling_B_interleaved_t +from .conv2d import compute_conv2d_NHWC, compute_conv2d_NHWC_without_transform, schedule_conv2d_NHWC def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype): @@ -208,75 +203,6 @@ def schedule_conv2d_nchw_int8(outs): return schedule_conv2d_NCHWc_int8(outs) -def _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A -): - N, IH, IW, IC = get_const_tuple(data.shape) - KH, KW, _, OC = get_const_tuple(kernel.shape) - tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A) - - kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows_B, tile_cols_B) - return compute_conv2d_gemm_without_weight_transform( - cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A - ) - - -def _compute_conv2d_NHWC_quantized_without_transform( - cfg, - data, - B, - strides, - padding, - dilation, - out_dtype, - kernel_size=None, - output_channels=None, - interleave_A=False, -): - return compute_conv2d_gemm_without_weight_transform( - cfg, - data, - B, - strides, - padding, - dilation, - out_dtype, - kernel_size, - output_channels, - interleave_A, - ) - - -def _schedule_conv2d_NHWC_quantized(cfg, outs, interleave_A): - """Create schedule for tensors""" - s = te.create_schedule([x.op for x in outs]) - # Vectorize the output and then inline all the rest - out = outs[0] - n, h, w, c = out.op.axis - n_h_fused = s[out].fuse(n, h) - outer, inner = s[out].split(c, 4) - s[out].vectorize(inner) - s[out].parallel(n_h_fused) - - def _callback(op): - """Traverse operators from computation graph""" - if op.name == "conv2d_gemm_output": - conv_out = op.output(0) - if interleave_A: - schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out) - else: - schedule_conv2d_gemm_native(cfg, s, conv_out, out) - if out != conv_out: - s[conv_out].compute_at(s[out], inner) - else: - C = conv_out.op.input_tensors[0] - if interleave_A: - s[C].compute_at(s[out], inner) - - traverse_inline(s, outs[0].op, _callback) - return s - - # Interleaved schedules: those schedule will interleave the input data. The # weights are interleaved and transposed @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved.arm_cpu") @@ -284,9 +210,7 @@ def compute_conv2d_NHWC_quantized_interleaved( cfg, data, kernel, strides, padding, dilation, out_dtype ): """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved""" - return _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, True - ) + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, True) @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu") @@ -294,7 +218,7 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved_without_transform""" - return _compute_conv2d_NHWC_quantized_without_transform( + return compute_conv2d_NHWC_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels, True ) @@ -302,13 +226,13 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved.arm_cpu") def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs): """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, True) + return schedule_conv2d_NHWC(cfg, outs, True) @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu") def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs): """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, True) + return schedule_conv2d_NHWC(cfg, outs, True) # Native schedules: those schedule won't interleave A (which is left in its native form). @@ -316,9 +240,7 @@ def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs): @autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu") def compute_conv2d_NHWC_quantized_native(cfg, data, kernel, strides, padding, dilation, out_dtype): """Interface for native compute_conv2d_NHWC_quantized""" - return _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, False - ) + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) @autotvm.register_topi_compute("conv2d_NHWC_quantized_native_without_transform.arm_cpu") @@ -326,7 +248,7 @@ def compute_conv2d_NHWC_quantized_native_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): """Interface for compute_conv2d_NHWC_quantized_native_without_transform""" - return _compute_conv2d_NHWC_quantized_without_transform( + return compute_conv2d_NHWC_without_transform( cfg, data, kernel, @@ -343,10 +265,10 @@ def compute_conv2d_NHWC_quantized_native_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native.arm_cpu") def schedule_conv2d_NHWC_quantized_native(cfg, outs): """Interface for native schedule_conv2d_NHWC_quantized""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, False) + return schedule_conv2d_NHWC(cfg, outs, False) @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu") def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs): """Interface for native schedule_conv2d_NHWC_quantized""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, False) + return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 75f72ee93d4d..7516bff702f4 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -615,17 +615,17 @@ def conv2d_NCHWc_int8( ) -def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): +def conv2d_gemm_weight_transform(kernel, tile_N, tile_K): """Weight transformation for winograd Parameters ---------- kernel: Tensor The raw kernel tensor with layout "NHWC". - tile_rows: int - Tile rows of the weight transformation for ConvGemm. - tile_cols: int - Tile columns of the weight transformation for ConvGemm. + tile_N: int + Tile size across N axis of the weight transformation for ConvGemm. (N = OC) + tile_K: int + Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC) Returns ------- @@ -640,7 +640,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten" ) - pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_rows, tile_cols) + pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K) N_padded = N + pad_N K_padded = K + pad_K @@ -650,11 +650,19 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding" ) - return te.compute( - (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), - lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x], - name="weight_block_reshape", - ) + if kernel.dtype in ["int8", "uint8"]: + B_inter_t = te.compute( + (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K), + lambda x, y, z, w: kernel_flat[w + tile_K * y, z + tile_N * x], + name="weight_block_reshape", + ) + else: + B_inter_t = te.compute( + (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N), + lambda x, y, z, w: kernel_flat[z + tile_K * y, w + tile_N * x], + name="weight_block_reshape", + ) + return B_inter_t def conv2d_winograd_weight_transform(kernel, tile_size): diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 13c7f74c7ecd..fe28bc50f80e 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -43,10 +43,10 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_ return Call(op, {weight}, Attrs(attrs), {}); } -Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) { +Expr MakeConvGemmWeightTransform(Expr weight, int tile_N, int tile_K, std::string op_name) { auto attrs = make_object(); - attrs->tile_rows = tile_rows; - attrs->tile_cols = tile_cols; + attrs->tile_N = tile_N; + attrs->tile_K = tile_K; const Op& op = Op::Get(op_name); return Call(op, {weight}, Attrs(attrs), {}); } @@ -1472,13 +1472,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform") TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); // Gemm convolution shape relations -// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. -// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and -// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] +// In order to run GEMM we need to transform the K x N weights matrix W. +// +// For integer datatypes, the high level idea is to subdivide W in tiles of tile_K x tile_N, and +// transpose and interleave them. The final output is a [N//tile_N, K//tile_K, tile_N, tile_K] // matrix that we call W_interleaved_t. // -// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed -// for tile_rows = 4 and tile_cols = 16 +// In the following picture, we show how the first [tile_K,tile_N] block of W is transformed +// for tile_N = 4 and tile_K = 16 // // W[0,0,:,:] W_interleaved_t[0,0,:,:] // +-------------------------------+ +----------------------------------- + @@ -1490,9 +1491,31 @@ TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); // |W[15,0] W[15,1] W[15,2] W[15,3]| // +-------------------------------+ // -// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements -// at the time, we should set tile_cols = k. -// Tile rows is connected with the number of registers available for the given target. +// Alternatively, for floating point datatypes, we subdivide W in tiles of tile_K x tile_N size, +// then interleave these tiles, without transposing. The final output is a [N//tile_N, K//tile_K, +// tile_K, tile_N] matrix called W_interleaved. +// +// In the following illustration, we show how the tiles are interleaved. +// Note that the inside of each tile is kept unchanged during this tranformation. +// +// W[:,:,:,:] W_interleaved[:,:,:,:] +// +--------+--------+--------+ +--------+--------+ +// | | | | | | | +// | tile_1 | tile_2 | tile_3 | | tile_1 | tile_4 | +// | | | | --\ | | | +// +--------+--------+--------+ --/ +--------+--------+ +// | | | | | | | +// | tile_4 | tile_5 | tile_6 | | tile_2 | tile_5 | +// | | | | | | | +// +--------+--------+--------+ +--------+--------+ +// | | | +// | tile_3 | tile_6 | +// | | | +// +--------+--------+ +// +// Tile K is the direction of the reduction in both cases. So, if our target can reduce k elements +// at the time, we should set tile_K = k. +// Tile N is connected with the number of registers available for the given target. // bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -1502,8 +1525,8 @@ bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, cons const ConvGemmWeightTransformAttrs* param = attrs.as(); ICHECK(param != nullptr); - int n = param->tile_rows; - int k = param->tile_cols; + int n = param->tile_N; + int k = param->tile_K; ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout"; @@ -1519,12 +1542,21 @@ bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, cons const auto N_padded = N + pad_N; const auto K_padded = K + pad_K; - Array oshape{ - indexdiv(N_padded, n), - indexdiv(K_padded, k), - n, - k, - }; + Array oshape; + if (weight->dtype.is_int() || weight->dtype.is_uint()) + oshape = { + indexdiv(N_padded, n), + indexdiv(K_padded, k), + n, + k, + }; + else + oshape = { + indexdiv(N_padded, n), + indexdiv(K_padded, k), + k, + n, + }; reporter->Assign(types[1], TensorType(oshape, weight->dtype)); return true; diff --git a/tests/python/integration/test_arm_aprofile.py b/tests/python/integration/test_arm_aprofile.py index c38217a1b1c0..006ad5f359f4 100644 --- a/tests/python/integration/test_arm_aprofile.py +++ b/tests/python/integration/test_arm_aprofile.py @@ -49,6 +49,7 @@ def test_conv2d(dtype): invar, weight, kernel_size=kernel_size, + channels=2, strides=(1, 1), padding=(0, 0), dilation=(1, 1), diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index e60cf12aa83e..a8922ad11936 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -45,6 +45,19 @@ "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), } +device = tvm.testing.parameter( + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu", + topi.arm_cpu.conv2d_nhwc_spatial_pack, + topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, + ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", + topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid, + topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid, + ), +) + dtype = tvm.testing.parameter("float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( @@ -77,6 +90,26 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd return a_np, w_np, b_np +def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation): + a_np, w_np, b_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + + target, compute, schedule = device + dev = tvm.device(target, 0) + + with tvm.target.Target(target): + B = compute(A, W, stride, padding, dilation, dtype) + s = schedule([B]) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + func = tvm.build(s, [A, W, B], target) + func(a, w, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation): a_np, w_np, b_np = ref_data From 1498fe267531248a01cd450911fb096476ae58ad Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 20 Nov 2023 16:48:20 +0000 Subject: [PATCH 2/6] Fix dotprod native schedule --- python/tvm/topi/arm_cpu/conv2d_gemm.py | 43 +++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 649b6d99aef5..a363ea817e72 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -120,7 +120,12 @@ def compute_conv2d_gemm_without_weight_transform( # Pad if necessary N_transformed = B_interleaved_t.shape[0] - tile_N = B_interleaved_t.shape[2] if in_dtype in ["int8", "uint8"] else B_interleaved_t.shape[3] + if in_dtype in ["int8", "uint8"]: + tile_N = B_interleaved_t.shape[2] + tile_K_B = B_interleaved_t.shape[3] + else: + tile_N = B_interleaved_t.shape[3] + tile_K_B = B_interleaved_t.shape[2] # Select the tiling strategy for A. # The tiling information is chosen to maximize register usage during @@ -139,23 +144,23 @@ def compute_conv2d_gemm_without_weight_transform( # If smmla/ummla is enabled, we are loading 8 rows from A. Each row # will contain 8 elements tile_M = 8 - tile_K = 8 + tile_K_A = 8 elif target.features.has_dotprod and interleave_A: # If dot product has been enabled, and we are interleaving A # tile size should be 8x4 tile_M = 8 - tile_K = 4 + tile_K_A = 4 else: # If either there is no dot product or if we are using a native strategy # tile size should be 4x16 tile_M = 4 - tile_K = 16 + tile_K_A = 16 else: # In non-quantized cases, A is not interleaved. # We are loading 4 rows from A. # Each row will contain 4 elements, along the dimension of reduction tile_M = 4 - tile_K = 4 + tile_K_A = 4 pad_M = 0 pad_K = 0 @@ -163,8 +168,8 @@ def compute_conv2d_gemm_without_weight_transform( if M % tile_M != 0: pad_M = tile_M - (M % tile_M) - if K % tile_K != 0: - pad_K = tile_K - (K % tile_K) + if K % tile_K_A != 0: + pad_K = tile_K_A - (K % tile_K_A) M_padded = M + pad_M K_padded = K + pad_K @@ -191,11 +196,11 @@ def compute_conv2d_gemm_without_weight_transform( ( batches, M_padded // tile_M, - K_padded // tile_K, + K_padded // tile_K_A, tile_M, - tile_K, + tile_K_A, ), - lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K * y], + lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K_A * y], name="A_interleaved", ) target = Target.current(allow_none=False) @@ -217,8 +222,10 @@ def compute_conv2d_gemm_without_weight_transform( 2, ), lambda b, x, y, w, z, s, t: te.sum( - A_interleaved[b, x, k // tile_K, 2 * w + s, idxm(k, tile_K)].astype("int32") - * B_interleaved_t[y, k // tile_K, 2 * z + t, idxm(k, tile_K)].astype( + A_interleaved[b, x, k // tile_K_A, 2 * w + s, idxm(k, tile_K_A)].astype( + "int32" + ) + * B_interleaved_t[y, k // tile_K_B, 2 * z + t, idxm(k, tile_K_B)].astype( "int32" ), axis=k, @@ -271,8 +278,8 @@ def compute_conv2d_gemm_without_weight_transform( C_interleaved = te.compute( (batches, M_padded // tile_M, N_transformed, tile_M, tile_N), lambda b, x, y, w, z: te.sum( - A_interleaved[b, x, k // tile_K, w, idxm(k, tile_K)].astype("int32") - * B_interleaved_t[y, k // tile_K, z, idxm(k, tile_K)].astype("int32"), + A_interleaved[b, x, k // tile_K_A, w, idxm(k, tile_K_A)].astype("int32") + * B_interleaved_t[y, k // tile_K_B, z, idxm(k, tile_K_B)].astype("int32"), axis=k, ), name="C_interleaved", @@ -298,9 +305,9 @@ def compute_conv2d_gemm_without_weight_transform( A[b, x, k].astype("int32") * B_interleaved_t[ y // tile_N, - k // tile_K, + k // tile_K_B, idxm(y, tile_N), - idxm(k, tile_K), + idxm(k, tile_K_B), ].astype("int32"), axis=k, ), @@ -324,8 +331,8 @@ def compute_conv2d_gemm_without_weight_transform( A[b, x, k].astype("float32") * B_interleaved_t[ y // tile_N, - k // tile_K, - idxm(k, tile_K), + k // tile_K_B, + idxm(k, tile_K_B), idxm(y, tile_N), ].astype("float32"), axis=k, From 2d9d81704ba0520a69aee82442b3777e408ff582 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 20 Nov 2023 16:48:20 +0000 Subject: [PATCH 3/6] Address code review comments --- python/tvm/relay/op/strategy/arm_cpu.py | 95 +++++++------ python/tvm/topi/arm_cpu/arm_utils.py | 8 +- python/tvm/topi/arm_cpu/conv2d.py | 24 ++-- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 4 +- .../strategy/test_select_implementation.py | 128 ++++++++++++++---- tests/python/topi/test_topi_conv2d_nhwc.py | 10 +- 6 files changed, 182 insertions(+), 87 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 36afe6957ff0..ee80736a2f4b 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -215,7 +215,17 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 - if data.dtype in ["int8", "uint8"]: + if not is_aarch64: + # TODO(@giuseros) + # This strategy errors out for quantized data types when tuning. + # Let's use this only for non-aarch64 or non-quantized cases + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.arm_cpu", + ) + elif data.dtype in ["int8", "uint8"]: + # Quantized cases if has_matmul_i8: strategy.add_implementation( wrap_compute_conv2d( @@ -232,7 +242,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native), name="conv2d_NHWC_quantized_native.arm_cpu", ) - if is_aarch64 and has_asimd: + if has_asimd: strategy.add_implementation( wrap_compute_conv2d( topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved @@ -242,20 +252,12 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) - if is_aarch64 and data.dtype not in ["int8", "uint8"]: - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid), - name="conv2d_NHWC_fp32_hybrid.arm_cpu", - ) else: - # TODO(@giuseros) - # This strategy errors out for quantized data types when tuning. - # Let's use this only for non-aarch64 or non-quantized cases + # Non-quantized cases strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), - name="conv2d_nhwc_spatial_pack.arm_cpu", + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_float_hybrid), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid), + name="conv2d_NHWC_float_hybrid.arm_cpu", ) else: raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC") @@ -497,38 +499,51 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform - if layout == "NHWC" and data.dtype in ["int8", "uint8"]: - if has_matmul_i8: - strategy.add_implementation( - wrap_compute_conv2d_gemm(interleaved_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform - ), - name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", - ) - if has_dot_prod: + if layout == "NHWC": + if not is_aarch64: + # Non-AArch64 cases + raise RuntimeError(f"Unsupported non-AArch64 conv2d_NHWC_without_transform") + elif data.dtype in ["int8", "uint8"]: + # Quantized cases + if has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule( + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform + ), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + if has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d_gemm(native_compute), + wrap_topi_schedule( + topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform + ), + name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", + ) + if has_asimd: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule( + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform + ), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + else: + # Non-quantized cases strategy.add_implementation( - wrap_compute_conv2d_gemm(native_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform + wrap_compute_conv2d_gemm( + topi.arm_cpu.compute_conv2d_NHWC_float_hybrid_without_transform ), - name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", - ) - if is_aarch64 and has_asimd: - strategy.add_implementation( - wrap_compute_conv2d_gemm(interleaved_compute), wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform + topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid_without_transform ), - name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + name="conv2d_NHWC_float_hybrid_without_transform.arm_cpu", ) else: - strategy.add_implementation( - wrap_compute_conv2d_gemm( - topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid_without_transform - ), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid_without_transform), - name="conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu", + raise RuntimeError( + f"Unsupported conv2d_NHWC_without_transform layout {layout}" + f"with datatype {data.dtype}" ) return strategy diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 50f570f17f47..0a67aa1c6b83 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -36,10 +36,10 @@ def get_tiling_B_transformed(interleave_A, in_dtype): Parameters ---------- - interleave_A : bool - determines if A is expected to be interleaved - in_dtype : str - input datatype + interleave_A : bool + determines if A is expected to be interleaved + in_dtype : str + input datatype Returns diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 67b2ed8d86a6..33ddedb9279f 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -585,17 +585,17 @@ def _callback(op): return s -@autotvm.register_topi_compute("conv2d_NHWC_fp32_hybrid.arm_cpu") -def compute_conv2d_NHWC_fp32_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): - """Interface for hybrid compute_conv2d_NHWC_fp32_hybrid""" +@autotvm.register_topi_compute("conv2d_NHWC_float_hybrid.arm_cpu") +def compute_conv2d_NHWC_float_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_float_hybrid""" return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) -@autotvm.register_topi_compute("conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu") -def compute_conv2d_NHWC_fp32_hybrid_without_transform( +@autotvm.register_topi_compute("conv2d_NHWC_float_hybrid_without_transform.arm_cpu") +def compute_conv2d_NHWC_float_hybrid_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): - """Interface for hybrid compute_conv2d_NHWC_fp32_hybrid_without_transform""" + """Interface for hybrid compute_conv2d_NHWC_float_hybrid_without_transform""" return compute_conv2d_NHWC_without_transform( cfg, data, @@ -610,13 +610,13 @@ def compute_conv2d_NHWC_fp32_hybrid_without_transform( ) -@autotvm.register_topi_schedule("conv2d_NHWC_fp32_hybrid.arm_cpu") -def schedule_conv2d_NHWC_fp32_hybrid(cfg, outs): - """Interface for hybrid schedule_conv2d_NHWC_fp32_hybrid""" +@autotvm.register_topi_schedule("conv2d_NHWC_float_hybrid.arm_cpu") +def schedule_conv2d_NHWC_float_hybrid(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_float_hybrid""" return schedule_conv2d_NHWC(cfg, outs, False) -@autotvm.register_topi_schedule("conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu") -def schedule_conv2d_NHWC_fp32_hybrid_without_transform(cfg, outs): - """Interface for hybrid schedule_conv2d_NHWC_fp32_hybrid""" +@autotvm.register_topi_schedule("conv2d_NHWC_float_hybrid_without_transform.arm_cpu") +def schedule_conv2d_NHWC_float_hybrid_without_transform(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_float_hybrid""" return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index 8984d1aafa69..b346a05d23a4 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -148,10 +148,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)), **new_attrs ) - if topi_tmpl == "conv2d_NHWC_fp32_hybrid.arm_cpu": + if topi_tmpl == "conv2d_NHWC_float_hybrid.arm_cpu": assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) - new_workload_name = "conv2d_NHWC_fp32_hybrid_without_transform.arm_cpu" + new_workload_name = "conv2d_NHWC_float_hybrid_without_transform.arm_cpu" new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index d7dd0abbc4d7..a98d85d3c58c 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -57,6 +57,39 @@ def test_concatenate(target, expected_implementation): assert impl.name == expected_implementation +def _get_conv2d_impl(dtype, target): + """Returns selected conv2d implementation for a given datatype and target""" + data_shape = (1, 1, 1, 4) + weight_shape = (1, 1, 4, 4) + data_layout = "NHWC" + kernel_layout = "HWIO" + channels = 4 + kernel_size = (1, 1) + + out = relay.nn.conv2d( + relay.var("data", shape=data_shape, dtype=dtype), + relay.var("weight", shape=weight_shape, dtype=dtype), + kernel_size=kernel_size, + channels=channels, + data_layout=data_layout, + kernel_layout=kernel_layout, + out_dtype=dtype, + ) + + with target: + out = run_opt_pass(out, relay.transform.AlterOpLayout()) + impl, _ = relay.backend.te_compiler.select_implementation( + out.op, + out.attrs, + [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], + out.checked_type, + target, + use_autotvm=False, + ) + + return impl.name + + @pytest.mark.parametrize( "target,expected_impl", [ @@ -93,37 +126,78 @@ def test_concatenate(target, expected_implementation): ) def test_int8_conv2d(target, expected_impl): target = tvm.target.Target(target) - dtype = "int8" - data_shape = (1, 1, 1, 4) - weight_shape = (1, 1, 4, 4) - data_layout = "NHWC" - kernel_layout = "HWIO" - channels = 4 - kernel_size = (1, 1) - out = relay.nn.conv2d( - relay.var("data", shape=data_shape, dtype=dtype), - relay.var("weight", shape=weight_shape, dtype=dtype), - kernel_size=kernel_size, - channels=channels, - data_layout=data_layout, - kernel_layout=kernel_layout, - out_dtype=dtype, - ) + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl - with target: - out = run_opt_pass(out, relay.transform.AlterOpLayout()) - impl, _ = relay.backend.te_compiler.select_implementation( - out.op, - out.attrs, - [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], - out.checked_type, - target, - use_autotvm=False, - ) - assert impl.name == expected_impl +@pytest.mark.parametrize( + "target,expected_impl", + [ + ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"), + ( + "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon", + "conv2d_nhwc_spatial_pack.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ], +) +def test_fp32_conv2d(target, expected_impl): + target = tvm.target.Target(target) + dtype = "float32" + + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl + + +@pytest.mark.parametrize( + "target,expected_impl", + [ + ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"), + ( + "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon", + "conv2d_nhwc_spatial_pack.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + ), + ], +) +def test_fp16_conv2d(target, expected_impl): + target = tvm.target.Target(target) + dtype = "float16" + + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl @pytest.mark.parametrize( diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index a8922ad11936..589ae0f7d4ff 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -16,6 +16,7 @@ # under the License. """Example code to do convolution.""" import os +import platform import numpy as np import tvm from tvm import te @@ -53,8 +54,8 @@ ), ( "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", - topi.arm_cpu.compute_conv2d_NHWC_fp32_hybrid, - topi.arm_cpu.schedule_conv2d_NHWC_fp32_hybrid, + topi.arm_cpu.compute_conv2d_NHWC_float_hybrid, + topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid, ), ) @@ -106,6 +107,11 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio w = tvm.nd.array(w_np, dev) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) func = tvm.build(s, [A, W, B], target) + + build_only = platform.machine() != "aarch64" + if build_only: + return + func(a, w, b) tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) From 98f6c11dec3030c56e54416efbd7b8bc6dece745 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 20 Nov 2023 16:48:20 +0000 Subject: [PATCH 4/6] Rename schedule and restrict usage to fp32 or fp16 --- python/tvm/relay/op/strategy/arm_cpu.py | 54 +++++++++---------- python/tvm/topi/arm_cpu/conv2d.py | 24 ++++----- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 4 +- python/tvm/topi/arm_cpu/conv2d_gemm.py | 4 +- src/relay/op/nn/convolution.cc | 2 +- .../strategy/test_select_implementation.py | 16 +++--- tests/python/topi/test_topi_conv2d_nhwc.py | 4 +- 7 files changed, 52 insertions(+), 56 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index ee80736a2f4b..91059940288b 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -214,18 +214,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 - - if not is_aarch64: - # TODO(@giuseros) - # This strategy errors out for quantized data types when tuning. - # Let's use this only for non-aarch64 or non-quantized cases - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), - name="conv2d_nhwc_spatial_pack.arm_cpu", - ) - elif data.dtype in ["int8", "uint8"]: - # Quantized cases + # Quantized cases + if is_aarch64 and data.dtype in ["int8", "uint8"]: if has_matmul_i8: strategy.add_implementation( wrap_compute_conv2d( @@ -252,12 +242,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) - else: - # Non-quantized cases + # Non-quantized cases + if is_aarch64 and data.dtype in ["float32", "float16"]: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid), + name="conv2d_NHWC_hybrid.arm_cpu", + ) + if (not is_aarch64) or (data.dtype not in ["int8", "uint8", "float32", "float16"]): + # TODO(@giuseros) + # This strategy errors out for quantized data types when tuning. + # Let's use this only for non-aarch64 or non-quantized cases strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_float_hybrid), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid), - name="conv2d_NHWC_float_hybrid.arm_cpu", + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.arm_cpu", ) else: raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC") @@ -499,11 +498,12 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform - if layout == "NHWC": + if layout == "NHWC" and data.dtype in ["int8", "uint8", "float32", "float16"]: + # Non-AArch64 cases if not is_aarch64: - # Non-AArch64 cases - raise RuntimeError(f"Unsupported non-AArch64 conv2d_NHWC_without_transform") - elif data.dtype in ["int8", "uint8"]: + raise RuntimeError("Unsupported non-AArch64 conv2d_NHWC_without_transform") + # AArch64 cases + if data.dtype in ["int8", "uint8"]: # Quantized cases if has_matmul_i8: strategy.add_implementation( @@ -529,16 +529,12 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ ), name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) - else: + elif data.dtype in ["float32", "float16"]: # Non-quantized cases strategy.add_implementation( - wrap_compute_conv2d_gemm( - topi.arm_cpu.compute_conv2d_NHWC_float_hybrid_without_transform - ), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid_without_transform - ), - name="conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform), + name="conv2d_NHWC_hybrid_without_transform.arm_cpu", ) else: raise RuntimeError( diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 33ddedb9279f..90e199f36a03 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -585,17 +585,17 @@ def _callback(op): return s -@autotvm.register_topi_compute("conv2d_NHWC_float_hybrid.arm_cpu") -def compute_conv2d_NHWC_float_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): - """Interface for hybrid compute_conv2d_NHWC_float_hybrid""" +@autotvm.register_topi_compute("conv2d_NHWC_hybrid.arm_cpu") +def compute_conv2d_NHWC_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_hybrid""" return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) -@autotvm.register_topi_compute("conv2d_NHWC_float_hybrid_without_transform.arm_cpu") -def compute_conv2d_NHWC_float_hybrid_without_transform( +@autotvm.register_topi_compute("conv2d_NHWC_hybrid_without_transform.arm_cpu") +def compute_conv2d_NHWC_hybrid_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): - """Interface for hybrid compute_conv2d_NHWC_float_hybrid_without_transform""" + """Interface for hybrid compute_conv2d_NHWC_hybrid_without_transform""" return compute_conv2d_NHWC_without_transform( cfg, data, @@ -610,13 +610,13 @@ def compute_conv2d_NHWC_float_hybrid_without_transform( ) -@autotvm.register_topi_schedule("conv2d_NHWC_float_hybrid.arm_cpu") -def schedule_conv2d_NHWC_float_hybrid(cfg, outs): - """Interface for hybrid schedule_conv2d_NHWC_float_hybrid""" +@autotvm.register_topi_schedule("conv2d_NHWC_hybrid.arm_cpu") +def schedule_conv2d_NHWC_hybrid(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_hybrid""" return schedule_conv2d_NHWC(cfg, outs, False) -@autotvm.register_topi_schedule("conv2d_NHWC_float_hybrid_without_transform.arm_cpu") -def schedule_conv2d_NHWC_float_hybrid_without_transform(cfg, outs): - """Interface for hybrid schedule_conv2d_NHWC_float_hybrid""" +@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_without_transform.arm_cpu") +def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_hybrid""" return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index b346a05d23a4..fe4569ceb1ad 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -148,10 +148,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)), **new_attrs ) - if topi_tmpl == "conv2d_NHWC_float_hybrid.arm_cpu": + if topi_tmpl == "conv2d_NHWC_hybrid.arm_cpu": assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) - new_workload_name = "conv2d_NHWC_float_hybrid_without_transform.arm_cpu" + new_workload_name = "conv2d_NHWC_hybrid_without_transform.arm_cpu" new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index a363ea817e72..e08775dcf3b5 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -328,13 +328,13 @@ def compute_conv2d_gemm_without_weight_transform( C = te.compute( (batches, M_padded, N_padded), lambda b, x, y: te.sum( - A[b, x, k].astype("float32") + A[b, x, k].astype(in_dtype) * B_interleaved_t[ y // tile_N, k // tile_K_B, idxm(k, tile_K_B), idxm(y, tile_N), - ].astype("float32"), + ].astype(in_dtype), axis=k, ), name="C", diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index fe28bc50f80e..e895c98df2c0 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -1543,7 +1543,7 @@ bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, cons const auto K_padded = K + pad_K; Array oshape; - if (weight->dtype.is_int() || weight->dtype.is_uint()) + if (weight->dtype.bits() == 8 && (weight->dtype.is_int() || weight->dtype.is_uint())) oshape = { indexdiv(N_padded, n), indexdiv(K_padded, k), diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index a98d85d3c58c..f9b1a002a8b6 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -142,19 +142,19 @@ def test_int8_conv2d(target, expected_impl): ), ( "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ], ) @@ -176,19 +176,19 @@ def test_fp32_conv2d(target, expected_impl): ), ( "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ( "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", - "conv2d_NHWC_float_hybrid_without_transform.arm_cpu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), ], ) diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 589ae0f7d4ff..05f9cb9c0570 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -54,8 +54,8 @@ ), ( "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", - topi.arm_cpu.compute_conv2d_NHWC_float_hybrid, - topi.arm_cpu.schedule_conv2d_NHWC_float_hybrid, + topi.arm_cpu.compute_conv2d_NHWC_hybrid, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid, ), ) From d56a96bf77e632a30d003a1d9a56064754b0c834 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 20 Nov 2023 16:48:20 +0000 Subject: [PATCH 5/6] Add `spatial_pack` implementation with low plevel for fp32/fp16 --- python/tvm/relay/op/strategy/arm_cpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 91059940288b..f7771bc64dae 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -249,7 +249,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid), name="conv2d_NHWC_hybrid.arm_cpu", ) - if (not is_aarch64) or (data.dtype not in ["int8", "uint8", "float32", "float16"]): + if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]): # TODO(@giuseros) # This strategy errors out for quantized data types when tuning. # Let's use this only for non-aarch64 or non-quantized cases @@ -257,6 +257,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), name="conv2d_nhwc_spatial_pack.arm_cpu", + plevel=5, ) else: raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC") From 0063880a49dc7f84beab2042efea8261b27326bc Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Wed, 22 Nov 2023 10:41:06 +0000 Subject: [PATCH 6/6] Rewrite `arm_cpu` conv2d quantized strategy selection --- python/tvm/relay/op/strategy/arm_cpu.py | 74 ++++++++++++++----------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index f7771bc64dae..1f9a6fc41e16 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -211,35 +211,41 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) elif kernel_layout == "HWIO": is_aarch64 = target.features.is_aarch64 - has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 + interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved + interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved + native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native + native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native # Quantized cases if is_aarch64 and data.dtype in ["int8", "uint8"]: - if has_matmul_i8: + if has_matmul_i8 and has_dot_prod: strategy.add_implementation( - wrap_compute_conv2d( - topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved - ), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved - ), + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) - if has_dot_prod: strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native), + wrap_compute_conv2d(native_compute), + wrap_topi_schedule(native_schedule), name="conv2d_NHWC_quantized_native.arm_cpu", ) - if has_asimd: + elif has_matmul_i8: strategy.add_implementation( - wrap_compute_conv2d( - topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved - ), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved - ), + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved.arm_cpu", + ) + elif has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d(native_compute), + wrap_topi_schedule(native_schedule), + name="conv2d_NHWC_quantized_native.arm_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) # Non-quantized cases @@ -493,12 +499,13 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ data = inputs[0] strategy = _op.OpStrategy() is_aarch64 = target.features.is_aarch64 - has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform + interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform + native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform if layout == "NHWC" and data.dtype in ["int8", "uint8", "float32", "float16"]: # Non-AArch64 cases if not is_aarch64: @@ -506,28 +513,33 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ # AArch64 cases if data.dtype in ["int8", "uint8"]: # Quantized cases - if has_matmul_i8: + if has_matmul_i8 and has_dot_prod: strategy.add_implementation( wrap_compute_conv2d_gemm(interleaved_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform - ), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) - if has_dot_prod: strategy.add_implementation( wrap_compute_conv2d_gemm(native_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform - ), + wrap_topi_schedule(native_schedule), name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", ) - if has_asimd: + elif has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + elif has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d_gemm(native_compute), + wrap_topi_schedule(native_schedule), + name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", + ) + else: strategy.add_implementation( wrap_compute_conv2d_gemm(interleaved_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform - ), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) elif data.dtype in ["float32", "float16"]: