From 4b1c203630f19995dd6b505fa66f516624864806 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 12 Oct 2020 08:03:39 +0000 Subject: [PATCH 1/5] [QNN] Optimize requantize for power of 2 and bug in dequantize --- src/relay/qnn/op/dequantize.cc | 4 +- src/relay/qnn/op/requantize.cc | 19 +++++---- src/relay/qnn/utils.cc | 15 +++++++ src/relay/qnn/utils.h | 8 ++++ tests/python/relay/test_op_qnn_dequantize.py | 18 ++++++++ tests/python/relay/test_op_qnn_requantize.py | 43 ++++++++++++++++++++ 6 files changed, 97 insertions(+), 10 deletions(-) diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 2e7a28624e26..2fe075c7e64b 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis}); } - auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point); - auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale); + auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point); + auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale); return scaled_output; } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 8e9b31e6fc39..3c231561ab21 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -155,17 +155,20 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, if (!IsEqualScalar(input_scale, output_scale)) { int32_t fixed_point_multiplier, shift; std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); - const bool is_upward_rounding = (param->rounding == "UPWARD"); - // When using upward rounding (i.e., x.5 rounded to x+1), leverage - // the FixedPointMultiply operator - scaled_int32_t = - (is_upward_rounding - ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) - : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); + if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) { + // Power of 2 + scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1); + } else { + // When using upward rounding (i.e., x.5 rounded to x+1), leverage + // the FixedPointMultiply operator + scaled_int32_t = + (is_upward_rounding + ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) + : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); + } } - } else { // This is per-channel (per=axis) quantization. std::vector double_multipliers; diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index 982efa0a61c1..55d9a290b438 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -56,6 +56,21 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie return std::make_pair(significand, exponent); } +Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) { + Expr out; + if (exp > 0) { + // power of 2 is greater than 0, apply left shift. + out = LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp)); + } else { + // power of 2 is less than 0, round and then apply right shift. + exp = -exp; + auto rounding_factor = 1 << (exp - 1); + auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor)); + out = RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp)); + } + return out; +} + Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, const Array& input_shape) { // Choose high precision datatype to be int64. This is for avoiding overflow diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index ab5c9a4fbbe2..5c5027298e47 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -136,6 +136,14 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { */ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, const Array& input_shape); +/* + * \brief Mutiply an integer datatype tensor by a power of two. + * \param tensor The quantized input tensor of dtype int32. + * \param exp The exp or the power of 2 representing the number to be multiplied. + * \return The sequence of Relay ops for power of two multiplication. + + */ +Expr PowerOfTwoMultiply(Expr tensor, int32_t exp); /* * \brief Fixed point multiplication between integer tensor with floating point diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 6598e2bb2062..e1416622c236 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -101,8 +101,26 @@ def test_channelwise_axis_1(): ) +def test_channelwise_axis_0(): + data = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5)) + output = ( + np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) + .astype("float32") + .reshape((2, 5)) + ) + quant_args = { + "in_zero_point": np.array([127, 123]).astype("int32"), + "in_scale": np.array([0.5, 0.25]).astype("float32"), + } + + dequantize_test_driver( + in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=0 + ) + + if __name__ == "__main__": test_uint8_to_float32() test_int8_to_float32() test_int32_to_float32() test_channelwise_axis_1() + test_channelwise_axis_0() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index f152a4ebf840..f40a08711451 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -204,6 +204,48 @@ def test_upscale(): verify(mod, (golden_data, golden_output)) +def test_non_power_of_two(): + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=3, + rounding=rounding, + ) + + # Try positive values + golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) + golden_output = np.arange(0, 32, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) + golden_output = np.arange(0, -32, -1) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=3, + output_scale=1, + rounding=rounding, + ) + + # Try positive values + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + def test_saturation(): for rounding in roundings: mod = get_mod( @@ -397,6 +439,7 @@ def test_per_channel_different_scale(): test_same_scale() test_downscale() test_upscale() + test_non_power_of_two() test_saturation() test_zero_point() test_per_channel_same_scale() From fbf52869f054708538b6def0779a62c09d52b6f2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 20 Oct 2020 23:58:32 +0000 Subject: [PATCH 2/5] Comments --- src/relay/qnn/op/requantize.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 3c231561ab21..9fc671832838 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -158,7 +158,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const bool is_upward_rounding = (param->rounding == "UPWARD"); if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) { - // Power of 2 + // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2, + // fixed point multiplier will represent a float value of 0.5. In fixed point, this is + // represented by 1 << 30. scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1); } else { // When using upward rounding (i.e., x.5 rounded to x+1), leverage From 4c86a86e2451603e5d4521637cf88f1b162d7894 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 21 Oct 2020 17:16:47 +0000 Subject: [PATCH 3/5] Docs --- src/relay/qnn/utils.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index 5c5027298e47..917864079021 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -141,7 +141,6 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, * \param tensor The quantized input tensor of dtype int32. * \param exp The exp or the power of 2 representing the number to be multiplied. * \return The sequence of Relay ops for power of two multiplication. - */ Expr PowerOfTwoMultiply(Expr tensor, int32_t exp); From 2a9fce4439e49c3cb211864696a286c3633f5d9e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Oct 2020 00:44:03 +0000 Subject: [PATCH 4/5] Comments --- src/relay/qnn/op/requantize.cc | 21 +++----- src/relay/qnn/utils.cc | 15 ------ src/relay/qnn/utils.h | 7 --- src/target/intrin_rule.cc | 90 ++++++++++++++++++++++------------ 4 files changed, 67 insertions(+), 66 deletions(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 9fc671832838..8e9b31e6fc39 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -155,22 +155,17 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, if (!IsEqualScalar(input_scale, output_scale)) { int32_t fixed_point_multiplier, shift; std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); + const bool is_upward_rounding = (param->rounding == "UPWARD"); - if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) { - // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2, - // fixed point multiplier will represent a float value of 0.5. In fixed point, this is - // represented by 1 << 30. - scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1); - } else { - // When using upward rounding (i.e., x.5 rounded to x+1), leverage - // the FixedPointMultiply operator - scaled_int32_t = - (is_upward_rounding - ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) - : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); - } + // When using upward rounding (i.e., x.5 rounded to x+1), leverage + // the FixedPointMultiply operator + scaled_int32_t = + (is_upward_rounding + ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) + : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); } + } else { // This is per-channel (per=axis) quantization. std::vector double_multipliers; diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index 55d9a290b438..982efa0a61c1 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -56,21 +56,6 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie return std::make_pair(significand, exponent); } -Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) { - Expr out; - if (exp > 0) { - // power of 2 is greater than 0, apply left shift. - out = LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp)); - } else { - // power of 2 is less than 0, round and then apply right shift. - exp = -exp; - auto rounding_factor = 1 << (exp - 1); - auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor)); - out = RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp)); - } - return out; -} - Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, const Array& input_shape) { // Choose high precision datatype to be int64. This is for avoiding overflow diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index 917864079021..ab5c9a4fbbe2 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -136,13 +136,6 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { */ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, const Array& input_shape); -/* - * \brief Mutiply an integer datatype tensor by a power of two. - * \param tensor The quantized input tensor of dtype int32. - * \param exp The exp or the power of 2 representing the number to be multiplied. - * \return The sequence of Relay ops for power of two multiplication. - */ -Expr PowerOfTwoMultiply(Expr tensor, int32_t exp); /* * \brief Fixed point multiplication between integer tensor with floating point diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 0808d237fc28..eeffc10fe604 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -128,37 +128,65 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") PrimExpr q = call->args[2]; PrimExpr s = call->args[3]; - // Only int32 types are supported (any number of lanes is allowed) - ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); - ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); - - DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); - - // 1) Calculating the integer multiplier and integer shift - PrimExpr zero = make_const(s.dtype(), 0); - PrimExpr left_shift = tir::Select(s > zero, s, zero); - PrimExpr right_shift = tir::Select(s > zero, zero, -s); - - // 2) Cast and Multiply the integer multiplier - PrimExpr one = make_const(hp_dtype, 1); - x = cast(hp_dtype, x); - y = cast(hp_dtype, y); - x = tir::Select(left_shift != zero, x << left_shift, x); - - // 3) Perform the multiplication in higher precision. - x = x * y; - - // 4) Find the rounding scalar - PrimExpr total_right_shift = right_shift + q; - PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); - x = x + pos_rounding_value; - - // 5) Simply right shift the result to get the final output. - x = x >> total_right_shift; - - // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. - *rv = cast(lp_dtype, x); + // Lambda function to extract the int value from PrimExpr + auto get_int_value = [](const PrimExpr node) { + auto broadcast_node = node.as(); + CHECK(broadcast_node != nullptr); + auto int_node = broadcast_node->value.as(); + CHECK(int_node != nullptr); + return int_node->value; + }; + // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2, + // fixed point multiplier will represent a float value of 0.5. In fixed point, this is + // represented by 1 << 30. + if (get_int_value(y) == (1 << 30)) { + PrimExpr exp = s - 1; + int exp_val = get_int_value(s) - 1; + if (exp_val > 0) { + // power of 2 is greater than 0, apply left shift. + *rv = x << exp; + } else { + // power of 2 is less than 0, round and then apply right shift. + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + PrimExpr one = make_const(lp_dtype, 1); + exp = -exp; + PrimExpr rounding_factor = one << (exp - 1); + PrimExpr rounded_t = x + rounding_factor; + *rv = rounded_t >> exp; + } + } else { + // Only int32 types are supported (any number of lanes is allowed) + ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); + + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + + // 1) Calculating the integer multiplier and integer shift + PrimExpr zero = make_const(s.dtype(), 0); + PrimExpr left_shift = tir::Select(s > zero, s, zero); + PrimExpr right_shift = tir::Select(s > zero, zero, -s); + + // 2) Cast and Multiply the integer multiplier + PrimExpr one = make_const(hp_dtype, 1); + x = cast(hp_dtype, x); + y = cast(hp_dtype, y); + x = tir::Select(left_shift != zero, x << left_shift, x); + + // 3) Perform the multiplication in higher precision. + x = x * y; + + // 4) Find the rounding scalar + PrimExpr total_right_shift = right_shift + q; + PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); + x = x + pos_rounding_value; + + // 5) Simply right shift the result to get the final output. + x = x >> total_right_shift; + + // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. + *rv = cast(lp_dtype, x); + } }); } // namespace intrin From 2402661408958b78d103d199e62f81954655c7fd Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Oct 2020 16:29:50 +0000 Subject: [PATCH 5/5] Ethos --- src/target/intrin_rule.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index eeffc10fe604..f8f4d0ef5414 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -130,6 +130,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") // Lambda function to extract the int value from PrimExpr auto get_int_value = [](const PrimExpr node) { + if (auto int_node = node.as()) { + return int_node->value; + } auto broadcast_node = node.as(); CHECK(broadcast_node != nullptr); auto int_node = broadcast_node->value.as();