diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 1b9abb0948b5..6b8d000c6664 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -359,6 +359,10 @@ def qnn_conv2d(expr):
     kernel_typ = args[1].checked_type
     if len(kernel_typ.shape) != 4 or kernel_typ.dtype not in qnn_dtypes:
         return False
+    if is_per_channel_quantization(
+        zero_point=args[2], scale=args[4]
+    ) or is_per_channel_quantization(zero_point=args[3], scale=args[5]):
+        return False
     is_depthwise = is_depthwise_conv2d(
         data_typ.shape,
         attrs["data_layout"],
@@ -422,6 +426,10 @@ def qnn_dense(expr):
         return False
     if attrs.out_dtype != "int32":
         return False
+    if is_per_channel_quantization(
+        zero_point=args[2], scale=args[4]
+    ) or is_per_channel_quantization(zero_point=args[3], scale=args[5]):
+        return False
     return True
 
 
@@ -514,10 +522,24 @@ def qnn_add(expr):
     for typ in [args[0].checked_type, args[1].checked_type]:
         if typ.dtype not in ["int8", "uint8"]:
             return False
-
+    if (
+        is_per_channel_quantization(zero_point=args[3], scale=args[2])
+        or is_per_channel_quantization(zero_point=args[5], scale=args[4])
+        or is_per_channel_quantization(zero_point=args[7], scale=args[6])
+    ):
+        return False
     return True
 
 
+def is_per_channel_quantization(zero_point, scale):
+    """Check if the quantization is per-channel"""
+    for value in [zero_point, scale]:
+        shape = value.checked_type.shape
+        if len(shape) != 0 and shape[0] != 1:
+            return True
+    return False
+
+
 class OpAttrContext(object):
     """Temporarily changes the attr of an op."""
 
diff --git a/tests/python/contrib/test_arm_compute_lib/test_add.py b/tests/python/contrib/test_arm_compute_lib/test_add.py
index ee6fcf603cb0..319105bb5fd9 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_add.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_add.py
@@ -17,6 +17,7 @@
 """Arm Compute Library integration reshape tests."""
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
 
@@ -134,6 +135,34 @@ def test_codegen_add():
     verify_codegen(func, exp_codegen, 1)
 
 
+@pytest.mark.parametrize(
+    "param, param_type",
+    [
+        ("lhs_scale", "float32"),
+        ("lhs_zero_point", "int32"),
+        ("rhs_scale", "float32"),
+        ("rhs_zero_point", "int32"),
+    ],
+)
+def test_codegen_add_per_channel_quantization(param, param_type):
+    if skip_codegen_test():
+        return
+
+    qnn_params = _qnn_params
+    qnn_params[param] = relay.const([1, 2], param_type)
+
+    dtype = "int8"
+    op_name = "qnn.add"
+    op = relay.qnn.op.add
+    inputs = {"a", "b"}
+
+    for shape in [(1, 3, 3, 2)]:
+        func = _get_model(shape, dtype, iter(inputs), op, qnn_params)
+        exp_codegen = _get_expected_codegen(shape, dtype, op_name, qnn_params)
+        verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=1)
+
+
 if __name__ == "__main__":
-    test_codegen_add()
     test_runtime_add()
+    test_codegen_add()
+    test_codegen_add_per_channel_quantization()
diff --git a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py
index df708020bf0f..b4fa49ffa288 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py
@@ -615,8 +615,58 @@ def test_codegen_qnn_conv2d(trial, dtype):
     verify_codegen(func, exp_codegen, 1)
 
 
+@pytest.mark.parametrize(
+    "param",
+    ["kernel_sc", "kernel_zp"],
+)
+def test_codegen_qnn_conv2d_per_channel_quantization(param):
+    if skip_codegen_test():
+        return
+
+    dtype = "int8"
+    kernel_h = 2
+    kernel_w = 2
+    pad = (1, 1)
+    stride = (1, 1)
+    dilation = (1, 1)
+    out_channels = 4
+    shape = (1, 10, 10, 14)
+    composite = (False, False, False)
+    groups = 1
+    inputs = {"a"}
+
+    qnn_params = {
+        "input_zp": 1,
+        "input_sc": 1,
+        "kernel_zp": 1,
+        "kernel_sc": 1,
+        "output_zp": 1,
+        "output_sc": 1,
+    }
+    qnn_params[param] = [1, 1, 1, 1]
+
+    args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
+
+    func, params = _get_qnn_model(
+        *args,
+        input_zp=qnn_params["input_zp"],
+        input_sc=qnn_params["input_sc"],
+        kernel_zp=qnn_params["kernel_zp"],
+        kernel_sc=qnn_params["kernel_sc"],
+        output_zp=qnn_params["output_zp"],
+        output_sc=qnn_params["output_sc"],
+        var_names=iter(inputs),
+        has_pad=composite[0],
+        has_bias=composite[1],
+        has_activation=composite[2],
+    )
+    exp_codegen = _get_expected_codegen(*args, has_bias=composite[1], has_activation=composite[2])
+    verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=2)
+
+
 if __name__ == "__main__":
     test_conv2d()
     test_qnn_conv2d()
     test_codegen_conv2d()
     test_codegen_qnn_conv2d()
+    test_codegen_qnn_conv2d_per_channel_quantization()
diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py
index bbcfc4abe6a9..411f790f347d 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_dense.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py
@@ -380,8 +380,51 @@ def test_codegen_qnn_dense(dtype):
     verify_codegen(func, exp_codegen)
 
 
+@pytest.mark.parametrize(
+    "param",
+    ["kernel_sc", "kernel_zp"],
+)
+def test_codegen_qnn_dense_per_channel_quantization(param):
+    if skip_codegen_test():
+        return
+
+    np.random.seed(0)
+    dtype = "int8"
+    shape = (1, 2)
+    weight_shape = (2, 2)
+    units = 2
+    composite = True
+    inputs = {"a"}
+    args = (shape, weight_shape, units, dtype)
+
+    qnn_params = {
+        "input_zp": 1,
+        "input_sc": 1,
+        "kernel_zp": 1,
+        "kernel_sc": 1,
+        "output_zp": 1,
+        "output_sc": 1,
+    }
+    qnn_params[param] = [1, 1]
+
+    func, _ = _get_qnn_model(
+        *args,
+        var_names=iter(inputs),
+        input_zp=qnn_params["input_zp"],
+        input_sc=qnn_params["input_sc"],
+        kernel_zp=qnn_params["kernel_zp"],
+        kernel_sc=qnn_params["kernel_sc"],
+        output_zp=qnn_params["output_zp"],
+        output_sc=qnn_params["output_sc"],
+        has_bias=composite,
+    )
+    exp_codegen = _get_expected_codegen(*args, has_bias=composite)
+    verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=3)
+
+
 if __name__ == "__main__":
     test_dense()
     test_qnn_dense()
     test_codegen_dense()
     test_codegen_qnn_dense()
+    test_codegen_qnn_dense_per_channel_quantization()