diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index bf9621bb404a..e3bc3498f82a 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -336,6 +336,14 @@ inline Expr ZerosLike(Expr e) { return CallNode::make(op, {e}); } +inline Expr Zeros(Array shape, DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("zeros"); + return CallNode::make(op, {}, Attrs(attrs), {}); +} + inline Expr OnesLike(Expr e) { static const Op& op = Op::Get("ones_like"); return CallNode::make(op, {e}); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index cf5f316a1c97..85d8dc3609f8 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs); // Lowering of qnn.requantize op -/* - * \brief Convert FP32 representation into fixed point representation. - * \param double_multplier The input FP32 number. - * \return The pair of multiplier and shift for fixed point representation. - * \note Converts a floating point number so that it can be represented by - * integers. The representation is - * float_number = (significand) * 2^(exponent) - * - * The significand is a number between 0.5 and 1. This is represented by - * an integer number. For example, if it is int32, then the decimal point - * exists between bit 31 and 30 from LSB (or between first and second bit - * from the left). - * - * Some examples are - * 0.25 = (0.5) * 2^(-1) - * 0.125 = (0.5) * 2^(-2) - * - * Credit to TFLite reference implementation. - */ -std::pair GetFixedPointMultiplierShift(double double_multiplier) { - int32_t significand, exponent; - if (double_multiplier == 0.) { - significand = 0; - exponent = 0; - return std::make_pair(significand, exponent); - } - // Get the significand and exponent. - double significand_d = std::frexp(double_multiplier, &exponent); - - // Convert the double significand to int significand, i.e., convert into a - // integer where the decimal point is between bit 31 and 30. This is done by - // multiplying the double value with 2^31 and then casting to int. - significand_d = std::round(significand_d * (1ll << 31)); - auto significand_int64 = static_cast(significand_d); - CHECK_LE(significand_int64, (1ll << 31)); - if (significand_int64 == (1ll << 31)) { - significand_int64 /= 2; - ++exponent; - } - CHECK_LE(significand_int64, std::numeric_limits::max()); - significand = static_cast(significand_int64); - return std::make_pair(significand, exponent); -} /* * \brief Lower requantize to a sequence of ops. @@ -93,93 +50,41 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie * and shift. This is useful, if the target device does not support/have * very expensive floating point computations. * - * Original compuation is scale_fp32 * quantized_tensor. To convert into - * integer computation, the multiplication with fp32 scalar can be - * replaced by multiplication with an int value and then right shifting - * the result. This approximates the floating point computation with a - * fixed point computation. - * * The whole computation this can be broken down into following steps * 1) Calculate the integer multiplier and integer shift. * 2) Subtract the input integer zero point. - * 3) Multiply the fixed point multiplier with quantized tensor. - * 4) Round the result. - * 5) Right shift the result. - * 6) Add the output zero point. - * 7) Cast to the out_dtype. + * 3) Perform fixed point multiplication. + * 4) Add the output zero point. + * 5) Cast to the out_dtype. */ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { double double_multiplier = param->input_scale / param->output_scale; - // Choose high precision datatype to be int64. This is for avoiding overflow - // in multiplication of two int32 values. DataType hp_dtype = Int(64); - // 1) Calculating the integer multiplier and integer shift - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); - int left_shift = shift > 0 ? shift : 0; - int right_shift = shift > 0 ? 0 : -shift; - - // 2) Subtract the input_zero_point auto tensor = Cast(input_tensor, hp_dtype); + // 1) Subtract the input_zero_point if (param->input_zero_point != 0) { auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point); tensor = Subtract(tensor, input_zp); } - // If the input and output scales are same, we can skip the fixed point multiplication. + // 2) If the input and output scales are same, we can skip the fixed point multiplication. auto scaled_int64_t = tensor; if (param->input_scale != param->output_scale) { - // 3) Multiply the integer multiplier - if (left_shift != 0) { - tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift)); - } - // Perform the multiplication in higher precision. - // The scalar is a fixed point value of int32 where the decimal point is - // between bits 31 and 30. After multiplying with input_tensor, the result is - // in int64 where the decimal point is sitting between bits 31 and 30 (from - // the right, rightmost bit is bit 0). The computation is performed in higher - // precision to avoid overflow in multiplying two int32 values. - Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier); - auto multiplied_t = Multiply(tensor, scalar); - - // 4) Find the rounding scalar. This depends on where the final decimal point - // sits. As we will be right shifting the multiplied_t, we need to first - // calculate the total_right_shift. - int total_right_shift = right_shift + 31; - int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); - - tensor = multiplied_t; - Expr round_scalar; - if (param->rounding == "UPWARD") { - round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); - } else if (param->rounding == "TONEAREST") { - auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); - auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); - auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); - auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); - - auto zero = MakeConstantScalar(hp_dtype, 0); - auto zero_t = Full(zero, input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); - } - // Add the rounding scalar. - tensor = Add(tensor, round_scalar); - - // 5) Simply right shift the result to get the final output. - scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape, + param->rounding); } - // 6) Add the output zero point. + // 3) Add the output zero point. auto shifted_int64_t = scaled_int64_t; if (param->output_zero_point != 0) { auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); shifted_int64_t = Add(output_zp, scaled_int64_t); } - // 7) Clip to the out_dtype min/max. + // 4) Clip to the out_dtype min/max. auto q_min = GetQmin(out_dtype); auto q_max = GetQmax(out_dtype); auto clipped_t = Clip(shifted_int64_t, q_min, q_max); diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc new file mode 100644 index 000000000000..d9e4506043c7 --- /dev/null +++ b/src/relay/qnn/util.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/qnn/util.cc + * \brief Utility functions for QNN. + */ + +#include "util.h" +#include "../pass/pattern_util.h" + +namespace tvm { +namespace relay { +namespace qnn { + +/* + * \brief Convert FP32 representation into fixed point representation. + * \param double_multplier The input FP32 number. + * \return The pair of multiplier and shift for fixed point representation. + * \note Converts a floating point number so that it can be represented by + * integers. The representation is + * float_number = (significand) * 2^(exponent) + * + * The significand is a number between 0.5 and 1. This is represented by + * an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit + * from the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + * + * Credit to TFLite reference implementation. + */ +std::pair GetFixedPointMultiplierShift( + double double_multiplier) { + int32_t significand, exponent; + if (double_multiplier == 0.) { + significand = 0; + exponent = 0; + return std::make_pair(significand, exponent); + } + + // Get the significand and exponent. + double significand_d = std::frexp(double_multiplier, &exponent); + + // Convert the double significand to int significand, i.e., convert into a + // integer where the decimal point is between bit 31 and 30. This is done by + // multiplying the double value with 2^31 and then casting to int. + significand_d = std::round(significand_d * (1ll << 31)); + auto significand_int64 = static_cast(significand_d); + CHECK_LE(significand_int64, (1ll << 31)); + if (significand_int64 == (1ll << 31)) { + significand_int64 /= 2; + ++exponent; + } + CHECK_LE(significand_int64, std::numeric_limits::max()); + significand = static_cast(significand_int64); + return std::make_pair(significand, exponent); +} + +Expr FixedPointMuliply(Expr tensor, double multiplier, + const Array& input_shape, const std::string& rounding) { + // Choose high precision datatype to be int64. This is for avoiding overflow + // in multiplication of two int32 values. + DataType hp_dtype = Int(64); + + // 1) Calculating the integer multiplier and integer shift + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = + GetFixedPointMultiplierShift(multiplier); + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + + // 2) Multiply the integer multiplier + if (left_shift != 0) { + tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift)); + } + + // 3) Perform the multiplication in higher precision. + // The scalar is a fixed point value of int32 where the decimal point is + // between bits 31 and 30. After multiplying with input_tensor, the result + // is in int64 where the decimal point is sitting between bits 31 and 30 + // (from the right, rightmost bit is bit 0). The computation is performed in + // higher precision to avoid overflow in multiplying two int32 values. + Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier); + tensor = Multiply(tensor, scalar); + + // 4) Find the rounding scalar. This depends on where the final decimal + // point sits. As we will be right shifting the multiplied_t, we need to + // first calculate the total_right_shift. + int total_right_shift = right_shift + 31; + int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); + + Expr round_scalar; + if (rounding == "UPWARD") { + round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); + } else if (rounding == "TONEAREST") { + auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); + auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); + auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); + auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); + + auto zero_t = Zeros(input_shape, hp_dtype); + round_scalar = + Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); + + // 5) Simply right shift the result to get the final output. + tensor = + RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + + return tensor; +} + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 94823317af46..c26183705b89 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -92,6 +93,32 @@ static inline int64_t get_const_int(const tvm::Expr& x) { return value_ptr[0]; } +/* + * \brief Fixed point multiplication between integer tensor with floating point + scalar. + * \param tensor The quantized input tensor of dtype int64. + * \param multiplier The scalar multiplier. + * \param input_shape Shape of the input tensor. + * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value + is midway between" "two representable values. + * \return The sequence of Relay ops for fixed point multiplication. + + * \note Original compuation is scale_fp32 * quantized_tensor. To convert into + * integer computation, the multiplication with fp32 scalar can be + * replaced by multiplication with an int value and then right shifting + * the result. This approximates the floating point computation with a + * fixed point computation. + * + * Computation of fixed point multiplication is consist of following + steps: + * 1) Multiply the fixed point multiplier with quantized tensor. + * 2) Round the result. + * 3) Right shift the result + */ +Expr FixedPointMuliply(Expr tensor, double multiplier, + const Array& input_shape, + const std::string& rounding); + } // namespace qnn } // namespace relay } // namespace tvm