diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index eb719dd66777..728e0db102fe 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -347,7 +347,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - n_elems = 8 + n_elems = 4 if cfg.is_fallback: _get_default_config_int8( diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 91e3e79cf8c7..b6ab89de8b0a 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -57,7 +57,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn else: # data is nchw, implicitly treat it as nchw1c @@ -103,8 +103,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out if len(data.shape) == 4: data, kernel = _pack_data(cfg, data, kernel) + n_elems = int(kernel.shape[-1]) + return nn.conv2d_NCHWc_int8( - data, kernel, strides, padding, dilation, layout, out_layout, out_dtype + data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems ) @@ -149,7 +151,8 @@ def _callback(op): args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]] # int8 conv kernel is 7-dim - _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape) + _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) + assert n_elems == 4 dtype = "uint" if data.dtype == "uint8" else "int" if is_dotprod_available(): intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype) diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py index d6b6f225890a..e27d00f17617 100644 --- a/python/tvm/topi/arm_cpu/tensor_intrin.py +++ b/python/tvm/topi/arm_cpu/tensor_intrin.py @@ -614,21 +614,22 @@ def _instr(index): ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl))) return ib.get() - def pairwise_add_mul(idx): - # this broadcasts data to the vector size - a_int8 = ins[0].vload([0], "int8x4") - re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) - vec_ai32 = re_int32.astype("int32x2") - vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) + # this broadcasts data to the vector size + a_int8 = ins[0].vload([0], "int8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) + vec_ai32 = re_int32.astype("int32x2") + vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) - vec_b = ins[1].vload([idx * 2, 0], int_8xl) # we take two inputs at a time + vec_b = ins[1].vload([0, 0], "int8x16") + def pairwise_add_mul(extract_half): + vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b) multiply = tvm.tir.call_llvm_pure_intrin( "int16x8", "llvm.aarch64.neon.smull.v8i16", # saturating pairwise multiplication tvm.tir.const(2, "uint32"), vec_a, - vec_b, + vec_b_half, ) pairwise_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", @@ -638,8 +639,8 @@ def pairwise_add_mul(idx): ) return pairwise_reduction - pair_1 = pairwise_add_mul(0) - pair_2 = pairwise_add_mul(1) + pair_1 = pairwise_add_mul("tir.vectorlow") + pair_2 = pairwise_add_mul("tir.vectorhigh") quad_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", "llvm.aarch64.neon.addp.v4i32", diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 68eb4eb6f01b..c27ea81144ac 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -486,7 +486,6 @@ def conv2d_NCHWc_int8( oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple( kernel.shape ) - num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py index b0edb02b0804..048d9468051b 100644 --- a/python/tvm/topi/x86/conv2d_int8.py +++ b/python/tvm/topi/x86/conv2d_int8.py @@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel): kernel = te.compute( (oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems), lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[ - occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w + occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w ], name="kernel_vec", ) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 96457d9b08e6..860118531e51 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -21,7 +21,6 @@ import tvm from tvm import te from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity from tvm import topi import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize @@ -34,6 +33,7 @@ from common import Int8Fallback import tvm.testing import pytest +import platform def compile_conv2d_NHWC_gemm_int8_arm( @@ -299,7 +299,6 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - print("Running on target: %s" % target) with tvm.target.Target(target): C = compute( A, @@ -311,8 +310,6 @@ def get_ref_data(): "NCHW", out_dtype, ) - print(C.shape) - print(bias.shape) if add_bias: C = topi.add(C, bias) if add_relu: @@ -342,6 +339,8 @@ def get_ref_data(): if build_only: return + print("Running on target: %s" % target) + func(*run_args) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) @@ -364,14 +363,15 @@ def get_ref_data(): # ), ] - # TODO(tvm-team): Properly run ARM code on CI aarch64 environment + build_only_aarch64 = platform.machine() != "aarch64" + targets.append( ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod", topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, - True, + build_only_aarch64, ) ) @@ -382,7 +382,7 @@ def get_ref_data(): topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, - True, + build_only_aarch64, ) )