From 98568b226127a2a2989b7269bf67774e829bfec6 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 8 Jul 2022 16:59:11 +0100 Subject: [PATCH] [TFLite] Enable int64 biases for int16 quantized operators This enables int64 biases for quantized fully connected, requantize and transpose convolution in TFLite networks. It goes on top of existing int16 support for TFLite frontend. Add a test case using DS_CNN int16 quantized. Change-Id: I3006ee76f5037fb6f915818358c9aada2faf40bf --- python/tvm/relay/frontend/tflite.py | 6 +- src/relay/qnn/op/convolution_transpose.cc | 10 +- src/relay/qnn/op/dense.cc | 10 +- src/relay/qnn/op/requantize.cc | 5 +- .../test_ethosn/test_convert_equivalents.py | 4 +- tests/python/frontend/tflite/test_forward.py | 23 + tests/python/relay/test_op_qnn_requantize.py | 495 ++++++++++-------- 7 files changed, 329 insertions(+), 224 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1915eb9322ff..3d2f4a2f25e6 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1966,7 +1966,7 @@ def convert_fully_connected(self, op): input_scale=input_tensor.qnn_params["scale"], kernel_scale=weight_tensor.qnn_params["scale"], units=weight_shape[0], - out_dtype="int32", + out_dtype="int64" if output_tensor_type_str == "int16" else "int32", ) else: out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0]) @@ -1977,7 +1977,7 @@ def convert_fully_connected(self, op): if bias_tensor.tensor_idx != -1: bias_tensor_type = bias_tensor.tensor.Type() # bias tensor type should be INT32 (quantization) or FLOAT32 - assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32) + assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32) bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) if self.has_expr(bias_tensor.tensor_idx): bias_expr = self.get_expr(bias_tensor.tensor_idx) @@ -3175,7 +3175,7 @@ def convert_transpose_conv(self, op): bias_tensor = input_tensors[3] bias_tensor_type = bias_tensor.tensor.Type() # bias tensor type should be INT32 (quantization) or FLOAT32 - assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32) + assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32) bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) if self.has_expr(bias_tensor.tensor_idx): bias_expr = self.get_expr(bias_tensor.tensor_idx) diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc index 6163e1c20429..951c1bdfb051 100644 --- a/src/relay/qnn/op/convolution_transpose.cc +++ b/src/relay/qnn/op/convolution_transpose.cc @@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array& types, int num_inputs, const Attrs if (data == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr."; - ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8)) - << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype; + ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) || + data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16)) + << "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype; ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; - ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32)) - << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype; + ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) || + data->dtype == DataType::Int(64)) + << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype; ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; // Check the types of scale and zero points. diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index adaf509e7daf..09d51e3c9ce7 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -47,12 +47,14 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (data == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr."; - ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8)) - << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; + ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) || + data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16)) + << "Expected quantized dense type(int8, uint8, int16, uint16) for input but was " + << data->dtype; ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype; - ICHECK(param->out_dtype == DataType::Int(32)) - << "Expected quantized dense type(int32) for output but was " << param->out_dtype; + ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64)) + << "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype; // Check the types of scale and zero points. for (size_t i = 2; i < 5; ++i) { diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 1614652719c6..e199ea27f1e4 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -480,8 +480,9 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, } const auto in_dtype = data->dtype; ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) || - in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64)) - << "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype; + in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) || + in_dtype == DataType::Int(64)) + << "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype; const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py index 77777293729c..a3e48f4424ad 100644 --- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py +++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py @@ -227,7 +227,7 @@ def expected(): @requires_ethosn @pytest.mark.parametrize( "dtype,shape,constant_shape", - [("int16", (1, 16, 12, 4), None)], + [("float32", (1, 16, 12, 4), None)], ) def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, constant_shape): """ @@ -445,7 +445,7 @@ def expected(): @pytest.mark.parametrize( "dtype,shape,constant_shape", [ - ("int16", (1, 16, 12, 4), None), + ("float32", (1, 16, 12, 4), None), ], ) def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7b2bd60d8a20..877406ae2a64 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4878,6 +4878,28 @@ def representative_dataset(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +def test_forward_ds_cnn_int16(): + """Test DS_CNN int16 quantized model""" + tflite_model_file = download_testdata( + "https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/" + "models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true", + "ds_cnn_quantized_int16.tflite", + ) + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = np.random.uniform(size=(1, 490)).astype("int16") + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0") + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + ####################################################################### # Unidirectional Sequence LSTM # --------------------- @@ -5250,3 +5272,4 @@ def test_forward_nms_v5(): test_forward_tflite_float16() test_forward_tflite_int16() + test_forward_ds_cnn_int16() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 64306476dfe9..1dee1f5b619c 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -23,6 +23,7 @@ roundings = ["UPWARD", "TONEAREST"] compute_dtypes = ["float32", "float64", "int64"] +out_dtypes = ["int8", "int16"] def verify(mod, goldens, target="llvm"): @@ -83,17 +84,18 @@ def test_same_scale(): golden_output = golden_data for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(200,), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - compute_dtype=compute_dtype, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(200,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_scalar_same_scale(): @@ -102,75 +104,77 @@ def test_scalar_same_scale(): golden_output = golden_data for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - compute_dtype=compute_dtype, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_downscale(): for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - rounding=rounding, - compute_dtype=compute_dtype, - ) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=16, + rounding=rounding, + compute_dtype=compute_dtype, + ) - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - else: - golden_output = np.repeat([0, -1, -2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=4, - rounding=rounding, - ) + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=4, + rounding=rounding, + ) - # Try positive values - # 2I corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) - verify(mod, (golden_data, golden_output)) + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(mod, (golden_data, golden_output)) - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] - ) - else: - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] - ) - verify(mod, (golden_data, golden_output)) + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] + ) + else: + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] + ) + verify(mod, (golden_data, golden_output)) # Try uint8 out_dtype mod = get_mod( @@ -208,74 +212,76 @@ def test_downscale(): def test_upscale(): for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=2, - output_scale=1, - rounding=rounding, - compute_dtype=compute_dtype, - ) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=2, + output_scale=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) def test_non_power_of_two(): for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=3, - rounding=rounding, - compute_dtype=compute_dtype, - ) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=3, + rounding=rounding, + compute_dtype=compute_dtype, + ) - # Try positive values - golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) - golden_output = np.arange(0, 32, 1) - verify(mod, (golden_data, golden_output)) + # Try positive values + golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) + golden_output = np.arange(0, 32, 1) + verify(mod, (golden_data, golden_output)) - # Try negative values - golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) - golden_output = np.arange(0, -32, -1) - verify(mod, (golden_data, golden_output)) + # Try negative values + golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) + golden_output = np.arange(0, -32, -1) + verify(mod, (golden_data, golden_output)) - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=3, - output_scale=1, - rounding=rounding, - ) + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=3, + output_scale=1, + rounding=rounding, + ) - # Try positive values - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) + # Try positive values + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) - # Try negative values - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) + # Try negative values + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) -def test_saturation(): +def test_saturation_int8(): for compute_dtype in compute_dtypes: for rounding in roundings: mod = get_mod( @@ -322,6 +328,70 @@ def test_saturation(): verify(mod, (golden_data, golden_output)) +def test_saturation_int16(): + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(16,), + data_dtype="int32", + out_dtype="int16", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + golden_data = np.arange(0, 16, 1).astype("int32") + golden_data = np.add(32760, golden_data) + output = np.array( + [ + 32760, + 32761, + 32762, + 32763, + 32764, + 32765, + 32766, + 32767, + 32767, + 32767, + 32767, + 32767, + 32767, + 32767, + 32767, + 32767, + ] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype("int32") + golden_data = np.add(-32760, golden_data) + output = np.array( + [ + -32760, + -32761, + -32762, + -32763, + -32764, + -32765, + -32766, + -32767, + -32768, + -32768, + -32768, + -32768, + -32768, + -32768, + -32768, + -32768, + ] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) + + def test_zero_point(): # Output zero point for compute_dtype in compute_dtypes: @@ -357,31 +427,32 @@ def test_zero_point(): # Input zero point for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - input_zero_point=16, - rounding=rounding, - compute_dtype=compute_dtype, - ) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=16, + input_zero_point=16, + rounding=rounding, + compute_dtype=compute_dtype, + ) - # Try positive values - golden_data = np.arange(32, 64, 1).astype("int32") - golden_output = np.repeat([2, 3, 4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) + # Try positive values + golden_data = np.arange(32, 64, 1).astype("int32") + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) - # Try negative values - golden_data = np.arange(-32, -64, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) - else: - golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) + # Try negative values + golden_data = np.arange(-32, -64, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) def test_per_channel_same_scale(): @@ -390,17 +461,18 @@ def test_per_channel_same_scale(): golden_output = golden_data for compute_dtype in compute_dtypes: for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.5], - output_scale=0.5, - axis=1, - rounding=rounding, - compute_dtype=compute_dtype, - ) - verify(mod, (golden_data, golden_output)) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Change axis golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5)) @@ -480,88 +552,93 @@ def test_per_channel_different_scale(): def test_default_cfg_and_no_args(): - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - ) - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - verify(mod, (golden_data, golden_output)) + for qnn_out_dtype in out_dtypes: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=16, + ) + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + verify(mod, (golden_data, golden_output)) def test_non_default_cfg_and_no_args(): for rounding_cfg in roundings: - with relay.qnn.op.requantize_config(rounding=rounding_cfg): - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - ) + for qnn_out_dtype in out_dtypes: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=16, + ) - golden_data = np.arange(0, -32, -1).astype("int32") + golden_data = np.arange(0, -32, -1).astype("int32") - if rounding_cfg == "UPWARD": - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - else: - golden_output = np.repeat([0, -1, -2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) + if rounding_cfg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) def test_default_cfg_and_args(): for rounding in roundings: - with relay.qnn.op.requantize_config(rounding="UPWARD"): - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - golden_data = np.arange(0, -32, -1).astype("int32") - - if rounding == "UPWARD": - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - else: - golden_output = np.repeat([0, -1, -2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - -def test_non_default_cfg_and_args(): - for rounding_arg in roundings: - for rounding_cfg in roundings: - with relay.qnn.op.requantize_config(rounding=rounding_cfg): + for qnn_out_dtype in out_dtypes: + with relay.qnn.op.requantize_config(rounding="UPWARD"): mod = get_mod( data_shape=(32,), data_dtype="int32", - out_dtype="int8", + out_dtype=qnn_out_dtype, input_scale=1, output_scale=16, - rounding=rounding_arg, + rounding=rounding, ) golden_data = np.arange(0, -32, -1).astype("int32") - if rounding_arg == "UPWARD": + if rounding == "UPWARD": golden_output = np.repeat([0, -1, -2], [9, 16, 7]) else: golden_output = np.repeat([0, -1, -2], [8, 16, 8]) verify(mod, (golden_data, golden_output)) +def test_non_default_cfg_and_args(): + for rounding_arg in roundings: + for rounding_cfg in roundings: + for qnn_out_dtype in out_dtypes: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype=qnn_out_dtype, + input_scale=1, + output_scale=16, + rounding=rounding_arg, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding_arg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + if __name__ == "__main__": test_same_scale() test_scalar_same_scale() test_downscale() test_upscale() test_non_power_of_two() - test_saturation() + test_saturation_int8() + test_saturation_int16() test_zero_point() test_per_channel_same_scale() test_per_channel_different_scale()