From 6a2b82cf45a5dca50321d720989796582df10972 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 07:15:13 +0900 Subject: [PATCH 1/8] [ARM] pass correct n_elems in NCHWc compute definition --- python/tvm/topi/arm_cpu/conv2d_int8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 91e3e79cf8c7..2e4b11c37562 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -63,6 +63,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out # data is nchw, implicitly treat it as nchw1c n, in_channel, ih, iw = get_const_tuple(data.shape) num_filter, _, kh, kw = get_const_tuple(kernel.shape) + n_elems = 1 # Define autotvm tuning space is_kernel_1x1 = kh == 1 and kw == 1 @@ -104,7 +105,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out data, kernel = _pack_data(cfg, data, kernel) return nn.conv2d_NCHWc_int8( - data, kernel, strides, padding, dilation, layout, out_layout, out_dtype + data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems ) @@ -135,6 +136,7 @@ def schedule_conv2d_NCHWc_int8(cfg, outs): def _callback(op): if "conv2d_NCHWc_int8" in op.tag: + return conv_out = op.output(0) kernel_vec = conv_out.op.input_tensors[1] data_vec = conv_out.op.input_tensors[0] From 329a7b42e454fc863d3fe22e72cfc3c31dbad056 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 08:54:15 +0900 Subject: [PATCH 2/8] schedule with dot product intrin computes correct output --- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 8 ++++++-- python/tvm/topi/arm_cpu/conv2d_int8.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index eb719dd66777..9a3e3b3c3410 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -29,7 +29,7 @@ from ..x86.conv2d import _get_default_config as _get_x86_default_config from ..x86.conv2d_int8 import _get_default_config_int8 from .conv2d_int8 import is_int8_hw_support -from .arm_utils import get_tiling_B_interleaved_t +from .arm_utils import get_tiling_B_interleaved_t, is_dotprod_available, is_neon_available from ..generic.conv2d import conv2d_alter_int8_common logger = logging.getLogger("topi") @@ -347,7 +347,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - n_elems = 8 + if is_dotprod_available(): + n_elems = 4 + else: + assert is_neon_available(), "Neon required for ARM int8 NCHWc conv2d" + n_elems = 8 if cfg.is_fallback: _get_default_config_int8( diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 2e4b11c37562..4ecf5688b90f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -136,7 +136,6 @@ def schedule_conv2d_NCHWc_int8(cfg, outs): def _callback(op): if "conv2d_NCHWc_int8" in op.tag: - return conv_out = op.output(0) kernel_vec = conv_out.op.input_tensors[1] data_vec = conv_out.op.input_tensors[0] From b9396208031921b787e65ff5a686cee2f2eb9172 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 12:35:20 +0900 Subject: [PATCH 3/8] [ARM] Fix int8 NCHWc compute and tensor intrin for non dot product path --- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 8 ++------ python/tvm/topi/arm_cpu/conv2d_int8.py | 3 ++- python/tvm/topi/arm_cpu/tensor_intrin.py | 21 +++++++++++---------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index 9a3e3b3c3410..728e0db102fe 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -29,7 +29,7 @@ from ..x86.conv2d import _get_default_config as _get_x86_default_config from ..x86.conv2d_int8 import _get_default_config_int8 from .conv2d_int8 import is_int8_hw_support -from .arm_utils import get_tiling_B_interleaved_t, is_dotprod_available, is_neon_available +from .arm_utils import get_tiling_B_interleaved_t from ..generic.conv2d import conv2d_alter_int8_common logger = logging.getLogger("topi") @@ -347,11 +347,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - if is_dotprod_available(): - n_elems = 4 - else: - assert is_neon_available(), "Neon required for ARM int8 NCHWc conv2d" - n_elems = 8 + n_elems = 4 if cfg.is_fallback: _get_default_config_int8( diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 4ecf5688b90f..5d98e562b443 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -150,7 +150,8 @@ def _callback(op): args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]] # int8 conv kernel is 7-dim - _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape) + _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) + assert n_elems == 4 dtype = "uint" if data.dtype == "uint8" else "int" if is_dotprod_available(): intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype) diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py index d6b6f225890a..e27d00f17617 100644 --- a/python/tvm/topi/arm_cpu/tensor_intrin.py +++ b/python/tvm/topi/arm_cpu/tensor_intrin.py @@ -614,21 +614,22 @@ def _instr(index): ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl))) return ib.get() - def pairwise_add_mul(idx): - # this broadcasts data to the vector size - a_int8 = ins[0].vload([0], "int8x4") - re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) - vec_ai32 = re_int32.astype("int32x2") - vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) + # this broadcasts data to the vector size + a_int8 = ins[0].vload([0], "int8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) + vec_ai32 = re_int32.astype("int32x2") + vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) - vec_b = ins[1].vload([idx * 2, 0], int_8xl) # we take two inputs at a time + vec_b = ins[1].vload([0, 0], "int8x16") + def pairwise_add_mul(extract_half): + vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b) multiply = tvm.tir.call_llvm_pure_intrin( "int16x8", "llvm.aarch64.neon.smull.v8i16", # saturating pairwise multiplication tvm.tir.const(2, "uint32"), vec_a, - vec_b, + vec_b_half, ) pairwise_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", @@ -638,8 +639,8 @@ def pairwise_add_mul(idx): ) return pairwise_reduction - pair_1 = pairwise_add_mul(0) - pair_2 = pairwise_add_mul(1) + pair_1 = pairwise_add_mul("tir.vectorlow") + pair_2 = pairwise_add_mul("tir.vectorhigh") quad_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", "llvm.aarch64.neon.addp.v4i32", From bb5f1dea447bc318a0bcd57fd519f2aad71140a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 12:44:32 +0900 Subject: [PATCH 4/8] enable neon test on aarch64 CI --- tests/python/topi/python/test_topi_conv2d_int8.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 96457d9b08e6..411d7878ad1e 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -21,7 +21,6 @@ import tvm from tvm import te from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity from tvm import topi import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize @@ -34,6 +33,7 @@ from common import Int8Fallback import tvm.testing import pytest +import platform def compile_conv2d_NHWC_gemm_int8_arm( @@ -299,7 +299,6 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - print("Running on target: %s" % target) with tvm.target.Target(target): C = compute( A, @@ -342,6 +341,8 @@ def get_ref_data(): if build_only: return + print("Running on target: %s" % target) + func(*run_args) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) @@ -364,7 +365,7 @@ def get_ref_data(): # ), ] - # TODO(tvm-team): Properly run ARM code on CI aarch64 environment + # TODO(tvm-team): Figure out ARM dot product availability on CI aarch64 environment targets.append( ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod", @@ -376,13 +377,14 @@ def get_ref_data(): ) if in_dtype == "int8": + build_only_aarch64 = platform.machine() != "aarch64" targets.append( ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon", topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, - True, + build_only_aarch64 ) ) From 1cc8acaaa2d1cd556e9db81e6af2b541376da762 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 13:06:18 +0900 Subject: [PATCH 5/8] lint --- tests/python/topi/python/test_topi_conv2d_int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 411d7878ad1e..0f9e38024970 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -384,7 +384,7 @@ def get_ref_data(): topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, - build_only_aarch64 + build_only_aarch64, ) ) From 05d50b305b2b545a3a6db31dc5fb4320ebcd8ab3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 16:03:14 +0900 Subject: [PATCH 6/8] Correctly account for n_elems when input is NCHW --- python/tvm/topi/arm_cpu/conv2d_int8.py | 5 +++-- tests/python/topi/python/test_topi_conv2d_int8.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 5d98e562b443..b6ab89de8b0a 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -57,13 +57,12 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn else: # data is nchw, implicitly treat it as nchw1c n, in_channel, ih, iw = get_const_tuple(data.shape) num_filter, _, kh, kw = get_const_tuple(kernel.shape) - n_elems = 1 # Define autotvm tuning space is_kernel_1x1 = kh == 1 and kw == 1 @@ -104,6 +103,8 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out if len(data.shape) == 4: data, kernel = _pack_data(cfg, data, kernel) + n_elems = int(kernel.shape[-1]) + return nn.conv2d_NCHWc_int8( data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems ) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 0f9e38024970..8cbbabdb84de 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -310,8 +310,6 @@ def get_ref_data(): "NCHW", out_dtype, ) - print(C.shape) - print(bias.shape) if add_bias: C = topi.add(C, bias) if add_relu: From b6729a582104cf18bf33062bc08d380a91362657 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Mar 2022 18:53:48 +0900 Subject: [PATCH 7/8] fixed pack_data --- python/tvm/topi/nn/conv2d.py | 1 - python/tvm/topi/x86/conv2d_int8.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 68eb4eb6f01b..c27ea81144ac 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -486,7 +486,6 @@ def conv2d_NCHWc_int8( oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple( kernel.shape ) - num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py index b0edb02b0804..048d9468051b 100644 --- a/python/tvm/topi/x86/conv2d_int8.py +++ b/python/tvm/topi/x86/conv2d_int8.py @@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel): kernel = te.compute( (oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems), lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[ - occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w + occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w ], name="kernel_vec", ) From 53ff53eea8e0445d37635987e1f30d9150cfb92c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Apr 2022 04:17:27 +0900 Subject: [PATCH 8/8] try run dot product schedule on CI --- tests/python/topi/python/test_topi_conv2d_int8.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 8cbbabdb84de..860118531e51 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -363,19 +363,19 @@ def get_ref_data(): # ), ] - # TODO(tvm-team): Figure out ARM dot product availability on CI aarch64 environment + build_only_aarch64 = platform.machine() != "aarch64" + targets.append( ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod", topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, - True, + build_only_aarch64, ) ) if in_dtype == "int8": - build_only_aarch64 = platform.machine() != "aarch64" targets.append( ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",