diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 26a65f0f224d..5ff2ccb2c137 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -456,7 +456,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): s[C].unroll(x_inner) s[C].tensorize(y_inner, gemm_acc) s[C].parallel(x_outer) - else: + elif use_scalable_vectors: k_outer, k_inner = s[C].split(k, factor=tile_K) x_outer, x_inner = s[C].split(x, factor=tile_M) y_outer, y_inner = s[C].split(y, factor=tile_N, disable_predication=use_scalable_vectors) @@ -472,6 +472,25 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): ) s[C].unroll(x_inner) s[C].vectorize(y_inner) + else: + k_outer, k_inner = s[C].split(k, factor=tile_K) + x_outer, x_inner = s[C].split(x, factor=tile_M) + y_outer, y_inner = s[C].split(y, factor=tile_N) + y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4) + b_x_outer_fused = s[C].fuse(b, x_outer) + s[C].parallel(b_x_outer_fused) + s[C].reorder( + b_x_outer_fused, + y_outer, + k_outer, + k_inner, + y_inner_outer, + x_inner, + y_inner_inner, + ) + s[C].unroll(y_inner_outer) + s[C].unroll(x_inner) + s[C].vectorize(y_inner_inner) # Input transform if A.op.name == "A_padded_K" or A.op.name == "A_padded_M": diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index e9e532ef4c6d..6ff844de088f 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -81,6 +81,7 @@ (1, 7, 4, 16, 3, 1, "SAME", 1), # Pad N (1, 2, 4, 15, 4, 1, "SAME", 1), + (1, 2, 4, 20, 1, 1, "SAME", 1), # Large workloads (1, 256, 32, 256, 3, 1, "SAME", 1), (4, 128, 16, 128, 5, 2, "SAME", 1),