From 0541a0cf2901e2edafb6d03d42c73a55fefd70d7 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Fri, 10 Dec 2021 13:01:58 +0000 Subject: [PATCH 1/2] [CMSIS-NN] Fixed return data type from pattern callback function Change-Id: If87e853e094b72e6c34e030cd1ae3b690982332f --- python/tvm/relay/op/contrib/cmsisnn.py | 4 ++-- tests/python/contrib/test_cmsisnn/test_conv2d.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index b80da2c8ccd0..f2099321c930 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -135,8 +135,8 @@ def check_qnn_conv2d(pattern): return ( conv2d.attrs.out_dtype == "int32" - and conv2d.attrs.padding[2] == 0 - and conv2d.attrs.padding[3] == 0 + and int(conv2d.attrs.padding[2]) == 0 + and int(conv2d.attrs.padding[3]) == 0 and conv2d_input.checked_type.dtype == "int8" and conv2d_weight.checked_type.dtype == "int8" and pattern.checked_type.dtype == "int8" diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 8d62763aec52..7eefcc8d481b 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -69,6 +69,8 @@ def make_model( kernel_w = kernel_shape[w_index] invar = relay.var("input", shape=shape, dtype=dtype) p = (0, 0, 0, 0) + if padding == "INVALID": + p = [1, 2, 1, 2] if padding == "SAME": p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) invar = relay.nn.pad( @@ -351,15 +353,16 @@ def parameterize_for_invalid_model(test): in_dtype = ["uint8", "int8"] kernel_dtype = ["uint8", "int8"] kernel_zero_point = [-33, 10, 0] - all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) + padding = ["SAME", "INVALID"] + all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point, padding) all_combinations = filter( lambda parameters: not ( - parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 + parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 and parameters[3] == "SAME" ), all_combinations, ) return pytest.mark.parametrize( - ["in_dtype", "kernel_dtype", "kernel_zero_point"], + ["in_dtype", "kernel_dtype", "kernel_zero_point", "padding"], all_combinations, )(test) @@ -370,6 +373,7 @@ def test_invalid_parameters( in_dtype, kernel_dtype, kernel_zero_point, + padding, ): ifm_shape = (1, 28, 28, 12) out_channels = 2 @@ -400,7 +404,7 @@ def test_invalid_parameters( kernel_scale=kernel_scale, output_zero_point=output_zero_point, output_scale=output_scale, - padding="SAME", + padding=padding, strides=(1, 1), dilation=(1, 1), groups=1, From 8ed26fcf48879cedc2899016b5dfab4ff2cbaf24 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Fri, 10 Dec 2021 16:17:10 +0000 Subject: [PATCH 2/2] Fixed lint error Change-Id: Ib014ab6829d67b9d05f06fa6ed58f5e775daa5dd --- tests/python/contrib/test_cmsisnn/test_conv2d.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 7eefcc8d481b..ed389a661699 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -357,7 +357,10 @@ def parameterize_for_invalid_model(test): all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point, padding) all_combinations = filter( lambda parameters: not ( - parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 and parameters[3] == "SAME" + parameters[0] == "int8" + and parameters[1] == "int8" + and parameters[2] == 0 + and parameters[3] == "SAME" ), all_combinations, )