24 changes: 23 additions & 1 deletion python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -359,6 +359,10 @@ def qnn_conv2d(expr):
kernel_typ = args[1].checked_type
if len(kernel_typ.shape) != 4 or kernel_typ.dtype not in qnn_dtypes:
return False
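# ACL handles only per-tensor quantization: args[2]/args[4] are the input
# zero point/scale, args[3]/args[5] the kernel zero point/scale.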
if is_per_channel_quantization(
zero_point=args[2], scale=args[4]
) or is_per_channel_quantization(zero_point=args[3], scale=args[5]):
return False
is_depthwise = is_depthwise_conv2d(
data_typ.shape,
attrs["data_layout"],
@@ -422,6 +426,10 @@ def qnn_dense(expr):
return False
if attrs.out_dtype != "int32":
return False
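# As with qnn.conv2d, reject per-channel quantization of the input
# (args[2]/args[4]) or the kernel (args[3]/args[5]).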
if is_per_channel_quantization(
zero_point=args[2], scale=args[4]
) or is_per_channel_quantization(zero_point=args[3], scale=args[5]):
return False
return True


@@ -514,10 +522,24 @@ def qnn_add(expr):
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype not in ["int8", "uint8"]:
return False

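# qnn.add passes scales before zero points: the lhs (args[2]/args[3]),
# rhs (args[4]/args[5]) and output (args[6]/args[7]) must all be per-tensor.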
if (
is_per_channel_quantization(zero_point=args[3], scale=args[2])
or is_per_channel_quantization(zero_point=args[5], scale=args[4])
or is_per_channel_quantization(zero_point=args[7], scale=args[6])
):
return False
return True


def is_per_channel_quantization(zero_point, scale):
"""Check if the quantization is per-channel"""
for value in [zero_point, scale]:
shape = value.checked_type.shape
if len(shape) != 0 and shape[0] != 1:
return True
return False
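A minimal sketch (not part of this diff) of how the helper classifies constants. Note that checked_type is only populated after type inference, so the snippet runs InferType over a trivial function first; the assumed shape rule is that scalars and length-1 vectors count as per-tensor:

    import numpy as np
    import tvm
    from tvm import relay

    def _typed(expr):
        # Run type inference so expr.checked_type is available.
        mod = tvm.IRModule.from_expr(relay.Function([], expr))
        return relay.transform.InferType()(mod)["main"].body

    per_tensor = _typed(relay.const(0.25, "float32"))         # shape ()
    per_channel = _typed(relay.const(np.ones(4, "float32")))  # shape (4,)

    assert not is_per_channel_quantization(per_tensor, per_tensor)
    assert is_per_channel_quantization(per_channel, per_tensor)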


class OpAttrContext(object):
"""Temporarily changes the attr of an op."""

31 changes: 30 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/test_add.py
@@ -17,6 +17,7 @@
"""Arm Compute Library integration reshape tests."""

import numpy as np
import pytest

import tvm
import tvm.testing
@@ -134,6 +135,34 @@ def test_codegen_add():
verify_codegen(func, exp_codegen, 1)


@pytest.mark.parametrize(
"param, param_type",
[
("lhs_scale", "float32"),
("lhs_zero_point", "int32"),
("rhs_scale", "float32"),
("rhs_zero_point", "int32"),
],
)
def test_codegen_add_per_channel_quantization(param, param_type):
if skip_codegen_test():
return

qnn_params = _qnn_params.copy()  # copy so the shared template is not mutated across parametrized runs
qnn_params[param] = relay.const([1, 2], param_type)

dtype = "int8"
op_name = "qnn.add"
op = relay.qnn.op.add
inputs = {"a", "b"}

for shape in [(1, 3, 3, 2)]:
func = _get_model(shape, dtype, iter(inputs), op, qnn_params)
exp_codegen = _get_expected_codegen(shape, dtype, op_name, qnn_params)
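# Per-channel qnn.add must not be offloaded: zero ACL modules, one TVM op.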
verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=1)


if __name__ == "__main__":
test_codegen_add()
test_runtime_add()
test_codegen_add_per_channel_quantization()
50 changes: 50 additions & 0 deletions tests/python/contrib/test_arm_compute_lib/test_conv2d.py
@@ -615,8 +615,58 @@ def test_codegen_qnn_conv2d(trial, dtype):
verify_codegen(func, exp_codegen, 1)


@pytest.mark.parametrize(
"param",
["kernel_sc", "kernel_zp"],
)
def test_codegen_qnn_conv2d_per_channel_quantization(param):
if skip_codegen_test():
return

dtype = "int8"
kernel_h = 2
kernel_w = 2
pad = (1, 1)
stride = (1, 1)
dilation = (1, 1)
out_channels = 4
shape = (1, 10, 10, 14)
composite = (False, False, False)
groups = 1
inputs = {"a"}

qnn_params = {
"input_zp": 1,
"input_sc": 1,
"kernel_zp": 1,
"kernel_sc": 1,
"output_zp": 1,
"output_sc": 1,
}
qnn_params[param] = [1, 1, 1, 1]

args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)

func, params = _get_qnn_model(
*args,
input_zp=qnn_params["input_zp"],
input_sc=qnn_params["input_sc"],
kernel_zp=qnn_params["kernel_zp"],
kernel_sc=qnn_params["kernel_sc"],
output_zp=qnn_params["output_zp"],
output_sc=qnn_params["output_sc"],
var_names=iter(inputs),
has_pad=composite[0],
has_bias=composite[1],
has_activation=composite[2],
)
exp_codegen = _get_expected_codegen(*args, has_bias=composite[1], has_activation=composite[2])
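# Expect no ACL modules; the two ops of the qnn pattern stay in TVM.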
verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=2)


if __name__ == "__main__":
test_conv2d()
test_qnn_conv2d()
test_codegen_conv2d()
test_codegen_qnn_conv2d()
test_codegen_qnn_conv2d_per_channel_quantization()
43 changes: 43 additions & 0 deletions tests/python/contrib/test_arm_compute_lib/test_dense.py
@@ -380,8 +380,51 @@ def test_codegen_qnn_dense(dtype):
verify_codegen(func, exp_codegen)


@pytest.mark.parametrize(
"param",
["kernel_sc", "kernel_zp"],
)
def test_codegen_qnn_dense_per_channel_quantization(param):
if skip_codegen_test():
return

np.random.seed(0)
dtype = "int8"
shape = (1, 2)
weight_shape = (2, 2)
units = 2
composite = True
inputs = {"a"}
args = (shape, weight_shape, units, dtype)

qnn_params = {
"input_zp": 1,
"input_sc": 1,
"kernel_zp": 1,
"kernel_sc": 1,
"output_zp": 1,
"output_sc": 1,
}
qnn_params[param] = [1, 1]

func, _ = _get_qnn_model(
*args,
var_names=iter(inputs),
input_zp=qnn_params["input_zp"],
input_sc=qnn_params["input_sc"],
kernel_zp=qnn_params["kernel_zp"],
kernel_sc=qnn_params["kernel_sc"],
output_zp=qnn_params["output_zp"],
output_sc=qnn_params["output_sc"],
has_bias=composite,
)
exp_codegen = _get_expected_codegen(*args, has_bias=composite)
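# Expect no ACL modules; all three ops of the composite (with bias) stay in TVM.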
verify_codegen(func, exp_codegen, num_acl_modules=0, tvm_ops=3)


if __name__ == "__main__":
test_dense()
test_qnn_dense()
test_codegen_dense()
test_codegen_qnn_dense()
test_codegen_qnn_dense_per_channel_quantization()
Contributor:

No need to add now, we should probably convert to the:

if __name__ == "__main__":
    tvm.testing.main()

format when we get another chance.
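(tvm.testing.main() runs pytest on the current file, so the explicit call list would no longer need to be kept in sync by hand.)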