From 8b08d01e9c6af6b181765c611f467b3200aa7683 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 6 Nov 2023 15:53:50 +0000 Subject: [PATCH] [TOPI] Reduce code redundancy in conv2d weights transformation Refactored out a piece of common functionality from the `conv2d_gemm_weight_transform` and `interleave_transpose_weights` functions, which has previously led to bugs stemming from changes made to only one but not the other, like in #15584. Determining the necessary padding for the interleaved and transposed weights matrix has now been separated into a new utility function, allowing future changes to be reflected in both callers. --- python/tvm/topi/arm_cpu/arm_utils.py | 38 ++++++++++++++++++++++ python/tvm/topi/arm_cpu/conv2d_alter_op.py | 18 ++-------- python/tvm/topi/nn/conv2d.py | 15 +-------- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 1b2efc61ea56..9c519cbb936c 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -73,3 +73,41 @@ def get_tiling_B_interleaved_t(interleave_A): tile_cols_B = 16 return tile_rows_B, tile_cols_B + + +def get_conv2d_weights_padding(N, K, tile_rows, tile_cols): + """Compute the necessary padding for matrix B', where B' + is the transposed and interleaved version of matrix B in C=A*B. 
+ + Parameters + ---------- + N : int + Number of rows in B' = OC + K : int + Number of columns in B' = KW * KH * IC + tile_rows : int + tile rows of B' + tile_cols : int + tile columns of B' + + Returns + ------- + pad_N : padding for N axis + pad_K : padding for K axis + """ + pad_N = 0 + pad_K = 0 + + if N % tile_rows != 0: + pad_N = tile_rows - (N % tile_rows) + + # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such + # that the number of columns is a multiple of 4 + column_multiplier = 4 + tile_cols_multiplied = tile_cols * column_multiplier + K_misalignment = K % tile_cols_multiplied + + if K_misalignment != 0: + pad_K = tile_cols_multiplied - K_misalignment + + return pad_N, pad_K diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index 6c3dbb48c9f2..1c30e1f3b650 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -32,7 +32,7 @@ from ..x86.conv2d import _get_default_config as _get_x86_default_config from ..x86.conv2d_int8 import _get_default_config_int8 from .conv2d_int8 import is_int8_hw_support -from .arm_utils import get_tiling_B_interleaved_t +from .arm_utils import get_tiling_B_interleaved_t, get_conv2d_weights_padding from ..generic.conv2d import conv2d_alter_int8_common from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word @@ -72,21 +72,7 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A): # Get tiling information for the interleaved transposed version of B tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A) - - pad_K = 0 - pad_N = 0 - - if N % tile_rows_B != 0: - pad_N = tile_rows_B - (N % tile_rows_B) - - # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such - # that the columns is multiple of 4 - column_multiplier = 4 - tile_cols_multiplied = tile_cols_B * column_multiplier - K_misalignment = K % tile_cols_multiplied - 
- if K_misalignment != 0: - pad_K = tile_cols_multiplied - K_misalignment + pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_rows_B, tile_cols_B) N_padded = N + pad_N K_padded = K + pad_K diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index f70d749e0f3c..8fdcb0dc1a55 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -617,20 +617,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten" ) - pad_K = 0 - pad_N = 0 - - if N % tile_rows != 0: - pad_N = tile_rows - (N % tile_rows) - - # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such - # that the columns is multiple of 4 - column_multiplier = 4 - tile_cols_multiplied = tile_cols * column_multiplier - K_misalignment = K % tile_cols_multiplied - - if K_misalignment != 0: - pad_K = tile_cols_multiplied - K_misalignment + pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_rows, tile_cols) N_padded = N + pad_N K_padded = K + pad_K