From 17c049df50b69bcb5b0aa0d364d0fe629f6a278e Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Mon, 29 Apr 2024 09:35:59 +0000 Subject: [PATCH] [TOPI] Revert unification of conv2d NHWC hybrid scheduling for `arm_cpu` targets This patch partly reverts the unification of scalable and non-scalable scheduling of conv2d NHWC for `arm_cpu` targets introduced in #16899. The non-scalable schedule for float32 splits the N axis (corresponding to number of output channels) by 16 in both the unified and the nonunified schedule versions, and then additionally splits the inner partitions by 4 in only the nonunified version to which this patch is reverting (first added in #16106). The two versions' behaviour would be equivalent if none of the padding on the N axis was removed during lowering, however we allow for that to happen as it proved to increase performance for very small convolutions. As it stands, there seems to be a regression in cases where the datatype is float32 and the number of output channels is greater than 16, a multiple of 4, and not a multiple of 16, because even with the removed padding the nonunified schedule is able to vectorise over 4 elements, while the unified version cannot vectorise over 16 elements anymore. Since all of the conv2d NHWC hybrid topi test cases used numbers of output channels either less than 16 or divisible by 16, this patch also adds a new case which falls in the aforementioned regression area. --- python/tvm/topi/arm_cpu/conv2d_gemm.py | 21 ++++++++++++++++++++- tests/python/topi/test_topi_conv2d_nhwc.py | 1 + 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 26a65f0f224d..5ff2ccb2c137 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -456,7 +456,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): s[C].unroll(x_inner) s[C].tensorize(y_inner, gemm_acc) s[C].parallel(x_outer) - else: + elif use_scalable_vectors: k_outer, k_inner = s[C].split(k, factor=tile_K) x_outer, x_inner = s[C].split(x, factor=tile_M) y_outer, y_inner = s[C].split(y, factor=tile_N, disable_predication=use_scalable_vectors) @@ -472,6 +472,25 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): ) s[C].unroll(x_inner) s[C].vectorize(y_inner) + else: + k_outer, k_inner = s[C].split(k, factor=tile_K) + x_outer, x_inner = s[C].split(x, factor=tile_M) + y_outer, y_inner = s[C].split(y, factor=tile_N) + y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4) + b_x_outer_fused = s[C].fuse(b, x_outer) + s[C].parallel(b_x_outer_fused) + s[C].reorder( + b_x_outer_fused, + y_outer, + k_outer, + k_inner, + y_inner_outer, + x_inner, + y_inner_inner, + ) + s[C].unroll(y_inner_outer) + s[C].unroll(x_inner) + s[C].vectorize(y_inner_inner) # Input transform if A.op.name == "A_padded_K" or A.op.name == "A_padded_M": diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index e9e532ef4c6d..6ff844de088f 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -81,6 +81,7 @@ (1, 7, 4, 16, 3, 1, "SAME", 1), # Pad N (1, 2, 4, 15, 4, 1, "SAME", 1), + (1, 2, 4, 20, 1, 1, "SAME", 1), # Large workloads (1, 256, 32, 256, 3, 1, "SAME", 1), (4, 128, 16, 128, 5, 2, "SAME", 1),