diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index e58c73dc7354..58edb9df8b97 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -197,12 +197,14 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { - int tile_rows; - int tile_cols; + int tile_N; + int tile_K; TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") { - TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm."); - TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm."); + TVM_ATTR_FIELD(tile_N).describe( + "Tile size across N axis of the weight transformation for ConvGemm. (N = OC)"); + TVM_ATTR_FIELD(tile_K).describe( + "Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)"); } }; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6acaf43fe7d2..a03907f071fd 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -798,7 +798,7 @@ def mirror_pad_func(attrs, inputs, _): @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform") def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv2d_gemm_weight_transform""" - out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols) + out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_N, attrs.tile_K) return [out] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 89953eb1dfb3..8cb66ecaa9a2 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2741,7 +2741,7 @@ def contrib_conv2d_winograd_weight_transform(weight, tile_size): return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) -def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols): +def contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K): r"""Weight Transformation part for 2D convolution with gemm algorithm. We separate this as a single op to enable pre-compute for inference. @@ -2751,17 +2751,17 @@ def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols): ---------- weights : tvm.relay.Expr The weight expressions. - tile_rows: int - Tile rows of the weight transformation for ConvGemm. - tile_cols: int - Tile columns of the weight transformation for ConvGemm. + tile_N: int + Tile size across N axis of the weight transformation for ConvGemm. (N = OC) + tile_K: int + Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC) Returns ------- result : tvm.relay.Expr The computed result. 
""" - return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols) + return _make.contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K) def contrib_conv3d_winograd_weight_transform(weight, tile_size): diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index a23ccf8f6932..1f9a6fc41e16 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -211,37 +211,50 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) elif kernel_layout == "HWIO": is_aarch64 = target.features.is_aarch64 - has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 - - if data.dtype in ["int8", "uint8"]: - if has_matmul_i8: + interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved + interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved + native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native + native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native + # Quantized cases + if is_aarch64 and data.dtype in ["int8", "uint8"]: + if has_matmul_i8 and has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved.arm_cpu", + ) strategy.add_implementation( - wrap_compute_conv2d( - topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved - ), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved - ), + wrap_compute_conv2d(native_compute), + wrap_topi_schedule(native_schedule), + name="conv2d_NHWC_quantized_native.arm_cpu", + ) + elif has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) - if has_dot_prod: + elif has_dot_prod: strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native), + wrap_compute_conv2d(native_compute), + wrap_topi_schedule(native_schedule), name="conv2d_NHWC_quantized_native.arm_cpu", ) - if is_aarch64 and has_asimd: + else: strategy.add_implementation( - wrap_compute_conv2d( - topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved - ), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved - ), + wrap_compute_conv2d(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved.arm_cpu", ) + # Non-quantized cases + if is_aarch64 and data.dtype in ["float32", "float16"]: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid), + name="conv2d_NHWC_hybrid.arm_cpu", + ) if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]): # TODO(@giuseros) # This strategy errors out for quantized data types when tuning. 
@@ -250,6 +263,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), name="conv2d_nhwc_spatial_pack.arm_cpu", + plevel=5, ) else: raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC") @@ -485,40 +499,59 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ data = inputs[0] strategy = _op.OpStrategy() is_aarch64 = target.features.is_aarch64 - has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod has_matmul_i8 = target.features.has_matmul_i8 interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform + interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform - if layout == "NHWC" and data.dtype in ["int8", "uint8"]: - if has_matmul_i8: - strategy.add_implementation( - wrap_compute_conv2d_gemm(interleaved_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform - ), - name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", - ) - if has_dot_prod: - strategy.add_implementation( - wrap_compute_conv2d_gemm(native_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform - ), - name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", - ) - if is_aarch64 and has_asimd: + native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform + if layout == "NHWC" and data.dtype in ["int8", "uint8", "float32", "float16"]: + # Non-AArch64 cases + if not is_aarch64: + raise RuntimeError("Unsupported non-AArch64 conv2d_NHWC_without_transform") + # AArch64 cases + if data.dtype in ["int8", "uint8"]: + # Quantized cases + if has_matmul_i8 and has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + strategy.add_implementation( + wrap_compute_conv2d_gemm(native_compute), + wrap_topi_schedule(native_schedule), + name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", + ) + elif has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + elif has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d_gemm(native_compute), + wrap_topi_schedule(native_schedule), + name="conv2d_NHWC_quantized_native_without_transform.arm_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule(interleaved_schedule), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) + elif data.dtype in ["float32", "float16"]: + # Non-quantized cases strategy.add_implementation( - wrap_compute_conv2d_gemm(interleaved_compute), - wrap_topi_schedule( - topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform - ), - name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform), + name="conv2d_NHWC_hybrid_without_transform.arm_cpu", ) else: raise 
RuntimeError( - f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}" + f"Unsupported conv2d_NHWC_without_transform layout {layout}" f"with datatype {data.dtype}" ) return strategy diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 9c519cbb936c..0a67aa1c6b83 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -20,9 +20,9 @@ from tvm.target import Target -def get_tiling_B_interleaved_t(interleave_A): +def get_tiling_B_transformed(interleave_A, in_dtype): """Compute the tiling information for matrix B', where B' - is the transposed and interleaved version of matrix B in C=A*B. + is the tiled, interleaved (and transposed) version of matrix B in C=A*B. The tiling information is chosen to maximize register usage during the tile computation. @@ -36,59 +36,68 @@ def get_tiling_B_interleaved_t(interleave_A): Parameters ---------- - interleave_A: bool - determines if A is expected to be interleaved + interleave_A : bool + determines if A is expected to be interleaved + in_dtype : str + input datatype + Returns ---------- - tile_rows_B: the output tile rows of B' - tile_cols_B: the output tile columns of B' + tile_N: the output tile size of B' on N axis (N = OC) + tile_K: the output tile size of B' on K axis (K = KW * KH * IC) """ target = Target.current(allow_none=False) - - if target.features.has_matmul_i8: - # If smmla/ummla is available, A must be interleaved. - # Each load from B' will contain 8 elements - # and we are loading 12 rows of B' (i.e., 12 columns of B) - tile_rows_B = 12 - tile_cols_B = 8 - elif target.features.has_dotprod: - # The number of tile rows of B' vary depending on the - # strategy: - # * If we are interleaving A, then we select 12 columns from B'(i.e., - # 12 rows from B). - # * If we are not interleaving A, then we select 16 columns from B'(i.e., - # 16 rows from B). - tile_rows_B = 12 if interleave_A else 16 - - # Dot product instruction groups 2 (u)int16x8 vectors in - # groups of 4 and compute the dot product among those groups - # This means that the number of columns in a tile of B' (i.e., the - # rows of the original matrix B) need to be 4. - tile_cols_B = 4 + if in_dtype in ["int8", "uint8"]: + if target.features.has_matmul_i8: + # If smmla/ummla is available, A must be interleaved. + # Each load from B' will contain 8 elements + # and we are loading 12 rows of B' (i.e., 12 columns of B) + tile_N = 12 + tile_K = 8 + elif target.features.has_dotprod: + # The number of tile rows of B' vary depending on the + # strategy: + # * If we are interleaving A, then we select 12 columns from B'(i.e., + # 12 rows from B). + # * If we are not interleaving A, then we select 16 columns from B'(i.e., + # 16 rows from B). + tile_N = 12 if interleave_A else 16 + + # Dot product instruction groups 2 (u)int16x8 vectors in + # groups of 4 and compute the dot product among those groups + # This means that the number of columns in a tile of B' (i.e., the + # rows of the original matrix B) need to be 4. + tile_K = 4 + else: + # If no acceleration is available, A must be interleaved. In this case + # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements + tile_N = 4 + tile_K = 16 else: - # If no acceleration is available, A must be interleaved. In this case - # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements - tile_rows_B = 4 - tile_cols_B = 16 + # In non-quantized cases, A is not interleaved. 
+ # Each load from B' contains 16 elements (i.e. 16 columns from B) + # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B) + tile_N = 16 + tile_K = 4 - return tile_rows_B, tile_cols_B + return tile_N, tile_K -def get_conv2d_weights_padding(N, K, tile_rows, tile_cols): +def get_conv2d_weights_padding(N, K, tile_N, tile_K): """Compute the necessary padding for matrix B', where B' - is the transposed and interleaved version of matrix B in C=A*B. + is the transformed version of matrix B in C=A*B. Parameters ---------- N : int - Number of rows in B' = OC + Number of columns in B = OC K : int - Number of columns in B' = KW * KH * IC - tile_rows : int - tile rows of B' - tile_cols : int - tile columns of B' + Number of rows in B = KW * KH * IC + tile_N : int + tile size of B' on N axis + tile_K : int + tile size of B' on K axis Returns ---------- @@ -98,16 +107,16 @@ def get_conv2d_weights_padding(N, K, tile_rows, tile_cols): pad_N = 0 pad_K = 0 - if N % tile_rows != 0: - pad_N = tile_rows - (N % tile_rows) + if N % tile_N != 0: + pad_N = tile_N - (N % tile_N) - # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such - # that the columns is multiple of 4 - column_multiplier = 4 - tile_cols_multiplied = tile_cols * column_multiplier - K_misalignment = K % tile_cols_multiplied + # Tensorize will later make use of 4 tiles at once across the K axis so make sure we pad such + # that K is multiple of 4 + K_multiplier = 4 + tile_K_multiplied = tile_K * K_multiplier + K_misalignment = K % tile_K_multiplied if K_misalignment != 0: - pad_K = tile_cols_multiplied - K_misalignment + pad_K = tile_K_multiplied - K_misalignment return pad_N, pad_K diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index a478818084d5..90e199f36a03 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -27,12 +27,18 @@ from .. 
import nn from ..nn.utils import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices +from .arm_utils import get_tiling_B_transformed from .conv2d_spatial_pack import ( conv2d_spatial_pack_nchw, conv2d_spatial_pack_nhwc, schedule_conv2d_spatial_pack_nchw, schedule_conv2d_spatial_pack_nhwc, ) +from .conv2d_gemm import ( + compute_conv2d_gemm_without_weight_transform, + schedule_conv2d_gemm_interleaved, + schedule_conv2d_gemm_native, +) from .mprofile.dsp.conv2d import conv2d_nhwc_dsp_compute, conv2d_nhwc_dsp_schedule @@ -509,3 +515,108 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): def schedule_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return conv2d_nhwc_dsp_schedule(cfg, outs) + + +def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A): + N, IH, IW, IC = get_const_tuple(data.shape) + KH, KW, _, OC = get_const_tuple(kernel.shape) + tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype) + + kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K) + return compute_conv2d_gemm_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A + ) + + +def compute_conv2d_NHWC_without_transform( + cfg, + data, + B, + strides, + padding, + dilation, + out_dtype, + kernel_size=None, + output_channels=None, + interleave_A=False, +): + """Compute conv2d NHWC without weight transform""" + return compute_conv2d_gemm_without_weight_transform( + cfg, + data, + B, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, + interleave_A, + ) + + +def schedule_conv2d_NHWC(cfg, outs, interleave_A): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + # Vectorize the output and then inline all the rest + out = outs[0] + n, h, w, c = out.op.axis + n_h_fused = s[out].fuse(n, h) + _, inner = s[out].split(c, 4) + s[out].vectorize(inner) + s[out].parallel(n_h_fused) + + def _callback(op): + """Traverse operators from computation graph""" + if op.name == "conv2d_gemm_output": + conv_out = op.output(0) + if interleave_A: + schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out) + else: + schedule_conv2d_gemm_native(cfg, s, conv_out, out) + if out != conv_out: + s[conv_out].compute_at(s[out], inner) + else: + C = conv_out.op.input_tensors[0] + if interleave_A: + s[C].compute_at(s[out], inner) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_NHWC_hybrid.arm_cpu") +def compute_conv2d_NHWC_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_hybrid""" + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) + + +@autotvm.register_topi_compute("conv2d_NHWC_hybrid_without_transform.arm_cpu") +def compute_conv2d_NHWC_hybrid_without_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels +): + """Interface for hybrid compute_conv2d_NHWC_hybrid_without_transform""" + return compute_conv2d_NHWC_without_transform( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, + False, + ) + + +@autotvm.register_topi_schedule("conv2d_NHWC_hybrid.arm_cpu") +def schedule_conv2d_NHWC_hybrid(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_hybrid""" + return schedule_conv2d_NHWC(cfg, outs, False) + + 
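# --- Illustrative sketch (not part of the patch) ---
# Reference semantics, in numpy, of the non-interleaved GEMM computed by the hybrid
# (float32/float16) path: A is the im2col'd input of shape (batches, M, K) and B_t is
# the transformed kernel of shape (N // tile_N, K // tile_K, tile_K, tile_N), with M
# and K assumed already padded to tile multiples. tile_N = 16 and tile_K = 4 follow
# get_tiling_B_transformed for the float path.
import numpy as np

def gemm_hybrid_reference(A, B_t, tile_N=16, tile_K=4):
    # Undo the blocking of B_t back to a plain (K, N) matrix; a single matmul then
    # reproduces C[b, x, y] = sum_k A[b, x, k] * B_t[y // tile_N, k // tile_K, k % tile_K, y % tile_N]
    n_blk, k_blk = B_t.shape[0], B_t.shape[1]
    B = B_t.transpose(1, 2, 0, 3).reshape(k_blk * tile_K, n_blk * tile_N)
    return A @ B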
+@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_without_transform.arm_cpu") +def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs): + """Interface for hybrid schedule_conv2d_NHWC_hybrid""" + return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index 1c30e1f3b650..fe4569ceb1ad 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -32,15 +32,15 @@ from ..x86.conv2d import _get_default_config as _get_x86_default_config from ..x86.conv2d_int8 import _get_default_config_int8 from .conv2d_int8 import is_int8_hw_support -from .arm_utils import get_tiling_B_interleaved_t, get_conv2d_weights_padding +from .arm_utils import get_tiling_B_transformed, get_conv2d_weights_padding from ..generic.conv2d import conv2d_alter_int8_common from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word logger = logging.getLogger("topi") -def interleave_transpose_weights(inputs, data, kernel, interleave_A): - """Transform the weight matrix by reshaping, interleaving and transposing it +def transform_weights(inputs, data, kernel, interleave_A): + """Transform the weight matrix by tiling, interleaving (and transposing it) Parameters ---------- @@ -59,29 +59,28 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A): new_kernel_expr : tvm.relay.Expr The relay expression of the weights """ - assert ( - data.dtype == "int8" - and kernel.dtype == "int8" - or data.dtype == "uint8" - and kernel.dtype == "uint8" - ) KH, KW, IC, OC = get_const_tuple(kernel.shape) K = KH * KW * IC N = OC - # Get tiling information for the interleaved transposed version of B - tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A) - pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_rows_B, tile_cols_B) + # Get tiling information for the transformed version of B + tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype) + pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_N, tile_K) N_padded = N + pad_N K_padded = K + pad_K - new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform( - inputs[1], tile_rows_B, tile_cols_B - ) - new_kernel = te.placeholder( - (N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B), kernel.dtype - ) + new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_N, tile_K) + if data.dtype in ["int8", "uint8"]: + new_kernel = te.placeholder( + (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K), + kernel.dtype, + ) + else: + new_kernel = te.placeholder( + (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N), + kernel.dtype, + ) return new_kernel, new_kernel_expr @@ -149,6 +148,20 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)), **new_attrs ) + if topi_tmpl == "conv2d_NHWC_hybrid.arm_cpu": + assert data_layout == "NHWC" and kernel_layout == "HWIO" + KH, KW, _, OC = get_const_tuple(kernel.shape) + new_workload_name = "conv2d_NHWC_hybrid_without_transform.arm_cpu" + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) + new_workload = autotvm.task.args_to_workload( + [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], + new_workload_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_gemm_without_weight_transform( + inputs[0], new_kernel_expr, **new_attrs + ) + # Only microTVM does 
layout alteration for NHWC layout with real data types if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]: return None @@ -431,9 +444,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu" - new_kernel, new_kernel_expr = interleave_transpose_weights( - inputs, data, kernel, interleave_A=True - ) + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=True) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], new_workload_name, @@ -447,9 +458,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): assert data_layout == "NHWC" and kernel_layout == "HWIO" KH, KW, _, OC = get_const_tuple(kernel.shape) new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu" - new_kernel, new_kernel_expr = interleave_transpose_weights( - inputs, data, kernel, interleave_A=False - ) + new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], new_workload_name, diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 90e02c5ab043..e08775dcf3b5 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -70,6 +70,7 @@ def compute_conv2d_gemm_without_weight_transform( """Compute conv2d by transforming the input, executing GEMM and transforming the output back""" batches, IH, IW, IC = get_const_tuple(data.shape) + in_dtype = data.dtype KH, KW = get_const_tuple(kernel_size) OC = get_const_int(output_channels) @@ -90,7 +91,7 @@ def compute_conv2d_gemm_without_weight_transform( OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 - if pad_top or pad_left: + if pad_top or pad_left or pad_down or pad_right: data_pad = nn.pad( data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad" ) @@ -119,8 +120,12 @@ def compute_conv2d_gemm_without_weight_transform( # Pad if necessary N_transformed = B_interleaved_t.shape[0] - tile_rows_B = B_interleaved_t.shape[2] - tile_cols_B = B_interleaved_t.shape[3] + if in_dtype in ["int8", "uint8"]: + tile_N = B_interleaved_t.shape[2] + tile_K_B = B_interleaved_t.shape[3] + else: + tile_N = B_interleaved_t.shape[3] + tile_K_B = B_interleaved_t.shape[2] # Select the tiling strategy for A. # The tiling information is chosen to maximize register usage during @@ -134,34 +139,41 @@ def compute_conv2d_gemm_without_weight_transform( # In order to have more information # target = Target.current(allow_none=False) - if target.features.has_matmul_i8: - # If smmla/ummla is enabled, we are loading 8 rows from A. Each row - # will contain 8 elements - tile_rows_A = 8 - tile_cols_A = 8 - elif target.features.has_dotprod and interleave_A: - # If dot product has been enabled, and we are interleaving A - # tile size should be 8x4 - tile_rows_A = 8 - tile_cols_A = 4 + if in_dtype in ["int8", "uint8"]: + if target.features.has_matmul_i8: + # If smmla/ummla is enabled, we are loading 8 rows from A. 
Each row + # will contain 8 elements + tile_M = 8 + tile_K_A = 8 + elif target.features.has_dotprod and interleave_A: + # If dot product has been enabled, and we are interleaving A + # tile size should be 8x4 + tile_M = 8 + tile_K_A = 4 + else: + # If either there is no dot product or if we are using a native strategy + # tile size should be 4x16 + tile_M = 4 + tile_K_A = 16 else: - # If either there is no dot product or if we are using a native strategy - # tile size should be 4x16 - tile_rows_A = 4 - tile_cols_A = 16 + # In non-quantized cases, A is not interleaved. + # We are loading 4 rows from A. + # Each row will contain 4 elements, along the dimension of reduction + tile_M = 4 + tile_K_A = 4 pad_M = 0 pad_K = 0 - if M % tile_rows_A != 0: - pad_M = tile_rows_A - (M % tile_rows_A) + if M % tile_M != 0: + pad_M = tile_M - (M % tile_M) - if K % tile_cols_A != 0: - pad_K = tile_cols_A - (K % tile_cols_A) + if K % tile_K_A != 0: + pad_K = tile_K_A - (K % tile_K_A) M_padded = M + pad_M K_padded = K + pad_K - N_padded = N_transformed * tile_rows_B + N_padded = N_transformed * tile_N pad_before = (0, 0, 0) pad_after = (0, pad_M, pad_K) @@ -174,131 +186,160 @@ def compute_conv2d_gemm_without_weight_transform( idxm = tvm.tir.indexmod k = te.reduce_axis((0, K_padded), "k") - if interleave_A: - # Configuration space - configure_knobs(cfg, M_padded, K_padded, target) + if in_dtype in ["int8", "uint8"]: + if interleave_A: + # Configuration space + configure_knobs(cfg, M_padded, K_padded, target) - # Pack the input data - A_interleaved = te.compute( - (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A), - lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y], - name="A_interleaved", - ) - target = Target.current(allow_none=False) - if target.features.has_matmul_i8: - # Execute GEMM. In the case of mmla, we need to enforce the tiling - # from the compute. This is because mmla is doing a tiled computation - # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles - # generated by mmla. 
In theory we could make the tile 2x2 and - # fuse and split during scheduling, but this would not work - # because of possible padding - C_interleaved = te.compute( + # Pack the input data + A_interleaved = te.compute( ( batches, - M_padded // tile_rows_A, - N_transformed, - tile_rows_A // 2, - tile_rows_B // 2, - 2, - 2, - ), - lambda b, x, y, w, z, s, t: te.sum( - A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype( - "int32" - ) - * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype( - "int32" - ), - axis=k, + M_padded // tile_M, + K_padded // tile_K_A, + tile_M, + tile_K_A, ), - name="C_interleaved", + lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K_A * y], + name="A_interleaved", ) - # Ensure the padding needed for tensorize does not get removed during tir passes - # by adding a dummy reference to the specific padded area of the result - zero = ( - tvm.tir.const(1, C_interleaved.dtype) - * C_interleaved[ - batches - 1, - M // tile_rows_A, - N_transformed - 1, - idxm(M, tile_rows_A) // 2, - tile_rows_B // 2 - 1, - 1, - 1, - ] - - tvm.tir.const(1, C_interleaved.dtype) - * C_interleaved[ - batches - 1, - M // tile_rows_A, - N_transformed - 1, - idxm(M, tile_rows_A) // 2, - tile_rows_B // 2 - 1, - 1, - 1, - ] - ) - # Unpack the result - C = te.compute( - (batches, M, N), - lambda b, x, y: ( - C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A) // 2, - idxm(y, tile_rows_B) // 2, - idxm(idxm(x, tile_rows_A), 2), - idxm(idxm(y, tile_rows_B), 2), + target = Target.current(allow_none=False) + if target.features.has_matmul_i8: + # Execute GEMM. In the case of mmla, we need to enforce the tiling + # from the compute. This is because mmla is doing a tiled computation + # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles + # generated by mmla. 
In theory we could make the tile 2x2 and + # fuse and split during scheduling, but this would not work + # because of possible padding + C_interleaved = te.compute( + ( + batches, + M_padded // tile_M, + N_transformed, + tile_M // 2, + tile_N // 2, + 2, + 2, + ), + lambda b, x, y, w, z, s, t: te.sum( + A_interleaved[b, x, k // tile_K_A, 2 * w + s, idxm(k, tile_K_A)].astype( + "int32" + ) + * B_interleaved_t[y, k // tile_K_B, 2 * z + t, idxm(k, tile_K_B)].astype( + "int32" + ), + axis=k, + ), + name="C_interleaved", + ) + # Ensure the padding needed for tensorize does not get removed during tir passes + # by adding a dummy reference to the specific padded area of the result + zero = ( + tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_M, + N_transformed - 1, + idxm(M, tile_M) // 2, + tile_N // 2 - 1, + 1, + 1, ] - + zero - ).astype(out_dtype), - name="C", - ) + - tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_M, + N_transformed - 1, + idxm(M, tile_M) // 2, + tile_N // 2 - 1, + 1, + 1, + ] + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: ( + C_interleaved[ + b, + x // tile_M, + y // tile_N, + idxm(x, tile_M) // 2, + idxm(y, tile_N) // 2, + idxm(idxm(x, tile_M), 2), + idxm(idxm(y, tile_N), 2), + ] + + zero + ).astype(out_dtype), + name="C", + ) + else: + # Execute GEMM + C_interleaved = te.compute( + (batches, M_padded // tile_M, N_transformed, tile_M, tile_N), + lambda b, x, y, w, z: te.sum( + A_interleaved[b, x, k // tile_K_A, w, idxm(k, tile_K_A)].astype("int32") + * B_interleaved_t[y, k // tile_K_B, z, idxm(k, tile_K_B)].astype("int32"), + axis=k, + ), + name="C_interleaved", + ) + # Unpack the result + C = te.compute( + (batches, M, N), + lambda b, x, y: C_interleaved[ + b, + x // tile_M, + y // tile_N, + idxm(x, tile_M), + idxm(y, tile_N), + ].astype(out_dtype), + name="C", + ) + zero = tvm.tir.const(0) else: - # Execute GEMM - C_interleaved = te.compute( - (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B), - lambda b, x, y, w, z: te.sum( - A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32") - * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"), + # No need to pack/unpack, execute GEMM directly + C = te.compute( + (batches, M_padded, N_padded), + lambda b, x, y: te.sum( + A[b, x, k].astype("int32") + * B_interleaved_t[ + y // tile_N, + k // tile_K_B, + idxm(y, tile_N), + idxm(k, tile_K_B), + ].astype("int32"), axis=k, ), - name="C_interleaved", - ) - # Unpack the result - C = te.compute( - (batches, M, N), - lambda b, x, y: C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A), - idxm(y, tile_rows_B), - ].astype(out_dtype), name="C", ) - zero = tvm.tir.const(0) + + # We need to ensure that infer bound pass does not remove the padding + # which is necessary for the tensorizations to work. 
So we need to + # add a dummy reference to the padding area of the result + zero = ( + tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] + - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] + ) else: - # No need to pack/unpack, execute GEMM directly + # Configuration space + configure_knobs(cfg, M_padded, K_padded, target) + C = te.compute( (batches, M_padded, N_padded), lambda b, x, y: te.sum( - A[b, x, k].astype("int32") + A[b, x, k].astype(in_dtype) * B_interleaved_t[ - y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B) - ].astype("int32"), + y // tile_N, + k // tile_K_B, + idxm(k, tile_K_B), + idxm(y, tile_N), + ].astype(in_dtype), axis=k, ), name="C", ) - - # We need to ensure that infer bound pass does not remove the padding - # which is necessary for the tensorizations to work. So we need to - # add a dummy reference to the padding area of the result - zero = ( - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - ) + zero = tvm.tir.const(0) # Reshape the result into a convolution output out_shape = (batches, OH, OW, OC) @@ -417,14 +458,35 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): # Computation b, x, y = C.op.axis (k,) = C.op.reduce_axis - k_outer, k_inner = s[C].split(k, 16) - y_tile_size = 16 - x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) - s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner) - gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1) - s[C].unroll(x_inner) - s[C].tensorize(y_inner, gemm_acc) - s[C].parallel(x_outer) + + if in_type in ["int8", "uint8"]: + k_outer, k_inner = s[C].split(k, 16) + y_tile_size = 16 + x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) + s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner) + gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1) + s[C].unroll(x_inner) + s[C].tensorize(y_inner, gemm_acc) + s[C].parallel(x_outer) + else: + k_outer, k_inner = s[C].split(k, 4) + y_tile_size = 16 + x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size) + y_inner_outer, y_inner_inner = s[C].split(y_inner, 4) + b_x_outer_fused = s[C].fuse(b, x_outer) + s[C].parallel(b_x_outer_fused) + s[C].reorder( + b_x_outer_fused, + y_outer, + k_outer, + k_inner, + y_inner_outer, + x_inner, + y_inner_inner, + ) + s[C].unroll(y_inner_outer) + s[C].unroll(x_inner) + s[C].vectorize(y_inner_inner) # Input transform if A.op.name == "A_padded_K" or A.op.name == "A_padded_M": @@ -450,7 +512,11 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): split_factor = 16 n_size = data_im2col.shape[2] - if n_size % split_factor != 0: + if n_size % 16 == 0: + split_factor = 16 + elif n_size % 8 == 0: + split_factor = 8 + else: # Split by kernel area (KH * KW) to ensure proper vectorization ic = data_im2col.op.input_tensors[0].shape[3] split_factor = n_size // ic @@ -466,6 +532,13 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): else: s[data_im2col].compute_at(s[C], x_inner) + A_pad = data_im2col.op.input_tensors[0] + if A_pad.op.name == "data_pad": + n, h, w, c = A_pad.op.axis + n_h_fused = s[A_pad].fuse(n, h) + s[A_pad].parallel(n_h_fused) + s[A_pad].vectorize(c) + # Output transform if out != final_out: n, h, w, c = out.op.axis diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 6b2c9527a400..721385c189e7 100644 --- 
a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -25,12 +25,7 @@ from ..x86.conv2d_int8 import _pack_data from ..nn.utils import get_pad_tuple from .tensor_intrin import dot_int8_int8_int32_neon_82, dot_int8_int8_int32_neon -from .conv2d_gemm import ( - compute_conv2d_gemm_without_weight_transform, - schedule_conv2d_gemm_interleaved, - schedule_conv2d_gemm_native, -) -from .arm_utils import get_tiling_B_interleaved_t +from .conv2d import compute_conv2d_NHWC, compute_conv2d_NHWC_without_transform, schedule_conv2d_NHWC def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype): @@ -208,75 +203,6 @@ def schedule_conv2d_nchw_int8(outs): return schedule_conv2d_NCHWc_int8(outs) -def _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A -): - N, IH, IW, IC = get_const_tuple(data.shape) - KH, KW, _, OC = get_const_tuple(kernel.shape) - tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A) - - kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows_B, tile_cols_B) - return compute_conv2d_gemm_without_weight_transform( - cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A - ) - - -def _compute_conv2d_NHWC_quantized_without_transform( - cfg, - data, - B, - strides, - padding, - dilation, - out_dtype, - kernel_size=None, - output_channels=None, - interleave_A=False, -): - return compute_conv2d_gemm_without_weight_transform( - cfg, - data, - B, - strides, - padding, - dilation, - out_dtype, - kernel_size, - output_channels, - interleave_A, - ) - - -def _schedule_conv2d_NHWC_quantized(cfg, outs, interleave_A): - """Create schedule for tensors""" - s = te.create_schedule([x.op for x in outs]) - # Vectorize the output and then inline all the rest - out = outs[0] - n, h, w, c = out.op.axis - n_h_fused = s[out].fuse(n, h) - outer, inner = s[out].split(c, 4) - s[out].vectorize(inner) - s[out].parallel(n_h_fused) - - def _callback(op): - """Traverse operators from computation graph""" - if op.name == "conv2d_gemm_output": - conv_out = op.output(0) - if interleave_A: - schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out) - else: - schedule_conv2d_gemm_native(cfg, s, conv_out, out) - if out != conv_out: - s[conv_out].compute_at(s[out], inner) - else: - C = conv_out.op.input_tensors[0] - if interleave_A: - s[C].compute_at(s[out], inner) - - traverse_inline(s, outs[0].op, _callback) - return s - - # Interleaved schedules: those schedule will interleave the input data. 
The # weights are interleaved and transposed @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved.arm_cpu") @@ -284,9 +210,7 @@ def compute_conv2d_NHWC_quantized_interleaved( cfg, data, kernel, strides, padding, dilation, out_dtype ): """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved""" - return _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, True - ) + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, True) @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu") @@ -294,7 +218,7 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved_without_transform""" - return _compute_conv2d_NHWC_quantized_without_transform( + return compute_conv2d_NHWC_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels, True ) @@ -302,13 +226,13 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved.arm_cpu") def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs): """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, True) + return schedule_conv2d_NHWC(cfg, outs, True) @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu") def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs): """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, True) + return schedule_conv2d_NHWC(cfg, outs, True) # Native schedules: those schedule won't interleave A (which is left in its native form). 
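# --- Illustrative sketch (not part of the patch) ---
# Worked example of the weight padding both the quantized and the new hybrid paths
# rely on (get_conv2d_weights_padding): N is padded to a multiple of tile_N, K to a
# multiple of 4 * tile_K so that tensorization can consume 4 K-tiles at once.
# Assumed float-path tile sizes: tile_N = 16, tile_K = 4.
N, K = 35, 3 * 3 * 8            # N = OC, K = KH * KW * IC
tile_N, tile_K = 16, 4
pad_N = (tile_N - N % tile_N) % tile_N                    # 13 -> N_padded = 48
pad_K = (4 * tile_K - K % (4 * tile_K)) % (4 * tile_K)    # 8  -> K_padded = 80
print(pad_N, pad_K)  # 13 8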
@@ -316,9 +240,7 @@ def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs): @autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu") def compute_conv2d_NHWC_quantized_native(cfg, data, kernel, strides, padding, dilation, out_dtype): """Interface for native compute_conv2d_NHWC_quantized""" - return _compute_conv2d_NHWC_quantized( - cfg, data, kernel, strides, padding, dilation, out_dtype, False - ) + return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False) @autotvm.register_topi_compute("conv2d_NHWC_quantized_native_without_transform.arm_cpu") @@ -326,7 +248,7 @@ def compute_conv2d_NHWC_quantized_native_without_transform( cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels ): """Interface for compute_conv2d_NHWC_quantized_native_without_transform""" - return _compute_conv2d_NHWC_quantized_without_transform( + return compute_conv2d_NHWC_without_transform( cfg, data, kernel, @@ -343,10 +265,10 @@ def compute_conv2d_NHWC_quantized_native_without_transform( @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native.arm_cpu") def schedule_conv2d_NHWC_quantized_native(cfg, outs): """Interface for native schedule_conv2d_NHWC_quantized""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, False) + return schedule_conv2d_NHWC(cfg, outs, False) @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu") def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs): """Interface for native schedule_conv2d_NHWC_quantized""" - return _schedule_conv2d_NHWC_quantized(cfg, outs, False) + return schedule_conv2d_NHWC(cfg, outs, False) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 75f72ee93d4d..7516bff702f4 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -615,17 +615,17 @@ def conv2d_NCHWc_int8( ) -def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): +def conv2d_gemm_weight_transform(kernel, tile_N, tile_K): """Weight transformation for winograd Parameters ---------- kernel: Tensor The raw kernel tensor with layout "NHWC". - tile_rows: int - Tile rows of the weight transformation for ConvGemm. - tile_cols: int - Tile columns of the weight transformation for ConvGemm. + tile_N: int + Tile size across N axis of the weight transformation for ConvGemm. (N = OC) + tile_K: int + Tile size across K axis of the weight transformation for ConvGemm. 
(K = KW * KH * IC) Returns ------- @@ -640,7 +640,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten" ) - pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_rows, tile_cols) + pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K) N_padded = N + pad_N K_padded = K + pad_K @@ -650,11 +650,19 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding" ) - return te.compute( - (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), - lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x], - name="weight_block_reshape", - ) + if kernel.dtype in ["int8", "uint8"]: + B_inter_t = te.compute( + (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K), + lambda x, y, z, w: kernel_flat[w + tile_K * y, z + tile_N * x], + name="weight_block_reshape", + ) + else: + B_inter_t = te.compute( + (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N), + lambda x, y, z, w: kernel_flat[z + tile_K * y, w + tile_N * x], + name="weight_block_reshape", + ) + return B_inter_t def conv2d_winograd_weight_transform(kernel, tile_size): diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 13c7f74c7ecd..e895c98df2c0 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -43,10 +43,10 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_ return Call(op, {weight}, Attrs(attrs), {}); } -Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) { +Expr MakeConvGemmWeightTransform(Expr weight, int tile_N, int tile_K, std::string op_name) { auto attrs = make_object(); - attrs->tile_rows = tile_rows; - attrs->tile_cols = tile_cols; + attrs->tile_N = tile_N; + attrs->tile_K = tile_K; const Op& op = Op::Get(op_name); return Call(op, {weight}, Attrs(attrs), {}); } @@ -1472,13 +1472,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform") TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); // Gemm convolution shape relations -// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. -// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and -// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] +// In order to run GEMM we need to transform the K x N weights matrix W. +// +// For integer datatypes, the high level idea is to subdivide W in tiles of tile_K x tile_N, and +// transpose and interleave them. The final output is a [N//tile_N, K//tile_K, tile_N, tile_K] // matrix that we call W_interleaved_t. // -// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed -// for tile_rows = 4 and tile_cols = 16 +// In the following picture, we show how the first [tile_K,tile_N] block of W is transformed +// for tile_N = 4 and tile_K = 16 // // W[0,0,:,:] W_interleaved_t[0,0,:,:] // +-------------------------------+ +----------------------------------- + @@ -1490,9 +1491,31 @@ TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); // |W[15,0] W[15,1] W[15,2] W[15,3]| // +-------------------------------+ // -// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements -// at the time, we should set tile_cols = k. 
-// Tile rows is connected with the number of registers available for the given target. +// Alternatively, for floating point datatypes, we subdivide W in tiles of tile_K x tile_N size, +// then interleave these tiles, without transposing. The final output is a [N//tile_N, K//tile_K, +// tile_K, tile_N] matrix called W_interleaved. +// +// In the following illustration, we show how the tiles are interleaved. +// Note that the inside of each tile is kept unchanged during this tranformation. +// +// W[:,:,:,:] W_interleaved[:,:,:,:] +// +--------+--------+--------+ +--------+--------+ +// | | | | | | | +// | tile_1 | tile_2 | tile_3 | | tile_1 | tile_4 | +// | | | | --\ | | | +// +--------+--------+--------+ --/ +--------+--------+ +// | | | | | | | +// | tile_4 | tile_5 | tile_6 | | tile_2 | tile_5 | +// | | | | | | | +// +--------+--------+--------+ +--------+--------+ +// | | | +// | tile_3 | tile_6 | +// | | | +// +--------+--------+ +// +// Tile K is the direction of the reduction in both cases. So, if our target can reduce k elements +// at the time, we should set tile_K = k. +// Tile N is connected with the number of registers available for the given target. // bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -1502,8 +1525,8 @@ bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, cons const ConvGemmWeightTransformAttrs* param = attrs.as(); ICHECK(param != nullptr); - int n = param->tile_rows; - int k = param->tile_cols; + int n = param->tile_N; + int k = param->tile_K; ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout"; @@ -1519,12 +1542,21 @@ bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, cons const auto N_padded = N + pad_N; const auto K_padded = K + pad_K; - Array oshape{ - indexdiv(N_padded, n), - indexdiv(K_padded, k), - n, - k, - }; + Array oshape; + if (weight->dtype.bits() == 8 && (weight->dtype.is_int() || weight->dtype.is_uint())) + oshape = { + indexdiv(N_padded, n), + indexdiv(K_padded, k), + n, + k, + }; + else + oshape = { + indexdiv(N_padded, n), + indexdiv(K_padded, k), + k, + n, + }; reporter->Assign(types[1], TensorType(oshape, weight->dtype)); return true; diff --git a/tests/python/integration/test_arm_aprofile.py b/tests/python/integration/test_arm_aprofile.py index c38217a1b1c0..006ad5f359f4 100644 --- a/tests/python/integration/test_arm_aprofile.py +++ b/tests/python/integration/test_arm_aprofile.py @@ -49,6 +49,7 @@ def test_conv2d(dtype): invar, weight, kernel_size=kernel_size, + channels=2, strides=(1, 1), padding=(0, 0), dilation=(1, 1), diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index d7dd0abbc4d7..f9b1a002a8b6 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -57,6 +57,39 @@ def test_concatenate(target, expected_implementation): assert impl.name == expected_implementation +def _get_conv2d_impl(dtype, target): + """Returns selected conv2d implementation for a given datatype and target""" + data_shape = (1, 1, 1, 4) + weight_shape = (1, 1, 4, 4) + data_layout = "NHWC" + kernel_layout = "HWIO" + channels = 4 + kernel_size = (1, 1) + + out = relay.nn.conv2d( + relay.var("data", shape=data_shape, dtype=dtype), + relay.var("weight", shape=weight_shape, dtype=dtype), + kernel_size=kernel_size, + channels=channels, + data_layout=data_layout, + 
kernel_layout=kernel_layout, + out_dtype=dtype, + ) + + with target: + out = run_opt_pass(out, relay.transform.AlterOpLayout()) + impl, _ = relay.backend.te_compiler.select_implementation( + out.op, + out.attrs, + [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], + out.checked_type, + target, + use_autotvm=False, + ) + + return impl.name + + @pytest.mark.parametrize( "target,expected_impl", [ @@ -93,37 +126,78 @@ def test_concatenate(target, expected_implementation): ) def test_int8_conv2d(target, expected_impl): target = tvm.target.Target(target) - dtype = "int8" - data_shape = (1, 1, 1, 4) - weight_shape = (1, 1, 4, 4) - data_layout = "NHWC" - kernel_layout = "HWIO" - channels = 4 - kernel_size = (1, 1) - out = relay.nn.conv2d( - relay.var("data", shape=data_shape, dtype=dtype), - relay.var("weight", shape=weight_shape, dtype=dtype), - kernel_size=kernel_size, - channels=channels, - data_layout=data_layout, - kernel_layout=kernel_layout, - out_dtype=dtype, - ) + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl - with target: - out = run_opt_pass(out, relay.transform.AlterOpLayout()) - impl, _ = relay.backend.te_compiler.select_implementation( - out.op, - out.attrs, - [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], - out.checked_type, - target, - use_autotvm=False, - ) - assert impl.name == expected_impl +@pytest.mark.parametrize( + "target,expected_impl", + [ + ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"), + ( + "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon", + "conv2d_nhwc_spatial_pack.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ], +) +def test_fp32_conv2d(target, expected_impl): + target = tvm.target.Target(target) + dtype = "float32" + + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl + + +@pytest.mark.parametrize( + "target,expected_impl", + [ + ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"), + ( + "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon", + "conv2d_nhwc_spatial_pack.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ], +) +def test_fp16_conv2d(target, expected_impl): + target = tvm.target.Target(target) + dtype = "float16" + + selected_impl = _get_conv2d_impl(dtype, target) + assert selected_impl == expected_impl @pytest.mark.parametrize( diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index e60cf12aa83e..05f9cb9c0570 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ 
-16,6 +16,7 @@ # under the License. """Example code to do convolution.""" import os +import platform import numpy as np import tvm from tvm import te @@ -45,6 +46,19 @@ "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), } +device = tvm.testing.parameter( + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu", + topi.arm_cpu.conv2d_nhwc_spatial_pack, + topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, + ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", + topi.arm_cpu.compute_conv2d_NHWC_hybrid, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid, + ), +) + dtype = tvm.testing.parameter("float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( @@ -77,6 +91,31 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd return a_np, w_np, b_np +def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation): + a_np, w_np, b_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + + target, compute, schedule = device + dev = tvm.device(target, 0) + + with tvm.target.Target(target): + B = compute(A, W, stride, padding, dilation, dtype) + s = schedule([B]) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + func = tvm.build(s, [A, W, B], target) + + build_only = platform.machine() != "aarch64" + if build_only: + return + + func(a, w, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation): a_np, w_np, b_np = ref_data
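# --- Illustrative sketch (not part of the patch) ---
# Numpy model of the float-path layout produced by conv2d_gemm_weight_transform:
# the HWIO kernel is flattened to (K, N) with K = KH * KW * IC and N = OC, padded,
# and blocked into (N // tile_N, K // tile_K, tile_K, tile_N) tiles (tiled but not
# transposed, unlike the int8/uint8 path). tile_N = 16 and tile_K = 4 are assumed.
import numpy as np

def gemm_weight_transform_f32(kernel_hwio, tile_N=16, tile_K=4):
    KH, KW, IC, OC = kernel_hwio.shape
    K, N = KH * KW * IC, OC
    flat = kernel_hwio.reshape(K, N)                 # row k = (kh, kw, ic), col n = oc
    pad_K = (4 * tile_K - K % (4 * tile_K)) % (4 * tile_K)
    pad_N = (tile_N - N % tile_N) % tile_N
    flat = np.pad(flat, ((0, pad_K), (0, pad_N)))
    Kp, Np = flat.shape
    # blocked[x, y, z, w] == flat[z + tile_K * y, w + tile_N * x]
    return flat.reshape(Kp // tile_K, tile_K, Np // tile_N, tile_N).transpose(2, 0, 1, 3)

w = np.random.rand(3, 3, 8, 32).astype("float32")
print(gemm_weight_transform_f32(w).shape)  # (2, 20, 4, 16)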