From bf6a26b45f511a82c0e66c0ac50affc03fcfe2ac Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Wed, 6 Sep 2023 15:58:45 +0000 Subject: [PATCH] [Bugfix][Strategy] Fix `arm_cpu` int8 conv2d strategy for dotprod and i8mm targets Whenever both dotprod and i8mm were available together on a target (e.g. `"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod,+i8mm"`), the native int8 conv2d implementation corresponding to the `+dotprod` attribute would be selected, but the compute definition of the conv2d operation would be constructed for the `+i8mm` attribute and its related interleaved schedule instead. The reason for this was that the conditional statements were ordered differently in two separate files: - `arm_cpu.py`: when selecting the conv2d implementation, the program first checked for `dotprod` support and, if present, chose the native schedule. - `conv2d_gemm.py`: when constructing the compute definition, the program checked for `i8mm` support first, then `dotprod`. To fix this, I modified the int8 conv2d strategy to also prioritize `i8mm` over `dotprod` when both are available. 
--- python/tvm/relay/op/strategy/arm_cpu.py | 49 ++++++++++++++----- .../strategy/test_select_implementation.py | 8 +++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index b64c541863f7..a23ccf8f6932 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -213,19 +213,35 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): is_aarch64 = target.features.is_aarch64 has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod + has_matmul_i8 = target.features.has_matmul_i8 - if has_dot_prod and data.dtype in ["int8", "uint8"]: - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native), - name="conv2d_NHWC_quantized_native.arm_cpu", - ) - if is_aarch64 and has_asimd and data.dtype in ["int8", "uint8"]: - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved), - name="conv2d_NHWC_quantized_interleaved.arm_cpu", - ) + if data.dtype in ["int8", "uint8"]: + if has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d( + topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved + ), + wrap_topi_schedule( + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved + ), + name="conv2d_NHWC_quantized_interleaved.arm_cpu", + ) + if has_dot_prod: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native), + name="conv2d_NHWC_quantized_native.arm_cpu", + ) + if is_aarch64 and has_asimd: + strategy.add_implementation( + wrap_compute_conv2d( + topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved + ), + wrap_topi_schedule( + 
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved + ), + name="conv2d_NHWC_quantized_interleaved.arm_cpu", + ) if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]): # TODO(@giuseros) # This strategy errors out for quantized data types when tuning. @@ -471,10 +487,19 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ is_aarch64 = target.features.is_aarch64 has_asimd = target.features.has_asimd has_dot_prod = target.features.has_dotprod + has_matmul_i8 = target.features.has_matmul_i8 interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform if layout == "NHWC" and data.dtype in ["int8", "uint8"]: + if has_matmul_i8: + strategy.add_implementation( + wrap_compute_conv2d_gemm(interleaved_compute), + wrap_topi_schedule( + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform + ), + name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ) if has_dot_prod: strategy.add_implementation( wrap_compute_conv2d_gemm(native_compute), diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index 906ef2d161b0..d7dd0abbc4d7 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -81,6 +81,14 @@ def test_concatenate(target, expected_implementation): "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm", "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod,+i8mm", + "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ), ], ) def test_int8_conv2d(target, expected_impl):