diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 757fdac32b81..6950ecceee05 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.contrib.adaptive_avg_pool2d
 
 
+**Level 11: Dialect Operators**
+
+This level supports dialect operators.
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.qnn.op.requantize
+
+
 Level 1 Definitions
 -------------------
 .. autofunction:: tvm.relay.log
@@ -340,3 +350,8 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.nn.batch_matmul
 .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
 .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
+
+
+Level 11 Definitions
+--------------------
+.. autofunction:: tvm.relay.qnn.op.requantize
diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
new file mode 100644
index 000000000000..e99602813229
--- /dev/null
+++ b/include/tvm/relay/qnn/attrs.h
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/attrs.h
+ * \brief Auxiliary attributes for qnn operators.
+ */
+#ifndef TVM_RELAY_QNN_ATTRS_H_
+#define TVM_RELAY_QNN_ATTRS_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*! \brief Attributes for the requantize operator */
+struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
+  double input_scale;
+  int32_t input_zero_point;
+  double output_scale;
+  int32_t output_zero_point;
+  std::string rounding;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
+    TVM_ATTR_FIELD(input_scale)
+        .describe("The scale of the input tensor.");
+    TVM_ATTR_FIELD(input_zero_point)
+        .describe("The zero point of the input tensor.");
+    TVM_ATTR_FIELD(output_scale)
+        .describe("The scale of the output tensor.");
+    TVM_ATTR_FIELD(output_zero_point)
+        .describe("The zero point of the output tensor.");
+    TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
+        .describe("Defines the rounding direction when the value is midway between "
+                  "two representable values. There are two supported modes - UPWARD "
+                  "or TONEAREST. Both modes behave exactly the same except at the "
+                  "midpoints between two representable values. At the midpoint, "
+                  "UPWARD rounds towards positive infinity (for example, -1.5 is "
+                  "rounded to -1). TONEAREST is the standard rounding, where the "
+                  "value is rounded away from zero at midpoints (for example, -1.5 "
+                  "rounds to -2). More context can be found in the glibc manual: "
+                  "https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_QNN_ATTRS_H_
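Side note (illustrative, not part of the patch): the two rounding modes only
diverge at exact midpoints. A minimal NumPy sketch of the distinction the
attribute describes:

    import numpy as np

    def round_upward(x):
        # Midpoints go towards positive infinity: -1.5 -> -1, 1.5 -> 2.
        return np.floor(x + 0.5)

    def round_tonearest(x):
        # Midpoints go away from zero: -1.5 -> -2, 1.5 -> 2.
        return np.sign(x) * np.floor(np.abs(x) + 0.5)

    x = np.array([-1.5, -0.5, 0.5, 1.5])
    print(round_upward(x))     # [-1.  0.  1.  2.]
    print(round_tonearest(x))  # [-2. -1.  1.  2.]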
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index da14c80b33b4..01baa00c9f7e 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -53,6 +53,9 @@
 from . import backend
 from . import quantize
 
+# Dialects
+from . import qnn
+
 from .scope_builder import ScopeBuilder
 
 # Span
diff --git a/python/tvm/relay/qnn/__init__.py b/python/tvm/relay/qnn/__init__.py
new file mode 100644
index 000000000000..a472109add39
--- /dev/null
+++ b/python/tvm/relay/qnn/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""QNN dialect operators and IR passes."""
+from __future__ import absolute_import as _abs
+from . import op
diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py
new file mode 100644
index 000000000000..e9adfa783f93
--- /dev/null
+++ b/python/tvm/relay/qnn/op/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""QNN dialect related operators."""
+from __future__ import absolute_import as _abs
+from .qnn import *
diff --git a/python/tvm/relay/qnn/op/_make.py b/python/tvm/relay/qnn/op/_make.py
new file mode 100644
index 000000000000..07b3dd154760
--- /dev/null
+++ b/python/tvm/relay/qnn/op/_make.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+from ...._ffi.function import _init_api
+
+_init_api("relay.qnn.op._make", __name__)
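Side note (a sketch of the FFI plumbing, assuming the C++ side registers its
constructor under the same prefix, as requantize.cc does below): at import
time, _init_api looks up the global packed functions whose names start with
"relay.qnn.op._make" and attaches them to this module.

    from tvm.relay.qnn.op import _make

    # The C++ MakeRequantize entry registered as
    # "relay.qnn.op._make.requantize" becomes callable here:
    print(_make.requantize)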
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
new file mode 100644
index 000000000000..1717bc42fe94
--- /dev/null
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""QNN dialect operators."""
+
+from __future__ import absolute_import as _abs
+from . import _make
+
+def requantize(data,
+               input_scale,
+               input_zero_point,
+               output_scale,
+               output_zero_point,
+               rounding="TONEAREST",
+               out_dtype="int8"):
+    r"""Requantize operator.
+
+    The requantize operator converts one quantized tensor representation to
+    another quantized tensor representation. For the output tensor, we are
+    provided with the output scale and zero point. The computation is as
+    follows
+
+    Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    input_scale: float
+        The quantization scale for the input tensor.
+
+    input_zero_point: int
+        The zero point of the input tensor.
+
+    output_scale: float
+        The quantization scale for the output tensor.
+
+    output_zero_point: int
+        The zero point of the output tensor.
+
+    rounding : string, optional
+        Defines the rounding direction when the value is midway between two
+        representable values.
+
+    out_dtype : str, optional
+        Specifies the output data type.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.requantize(data,
+                            input_scale,
+                            input_zero_point,
+                            output_scale,
+                            output_zero_point,
+                            rounding,
+                            out_dtype)
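A minimal usage sketch of the API above (expected values follow from the
formula in the docstring; a 2x downscale of int32 values to int8):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="int32")
    y = relay.qnn.op.requantize(x,
                                input_scale=0.5,
                                input_zero_point=0,
                                output_scale=1.0,
                                output_zero_point=0,
                                out_dtype="int8")
    # Q_out = 0 + (0.5 / 1.0) * (Q_in - 0), so [2, 3, -2, -3] maps to
    # [1, 2, -1, -2] with TONEAREST rounding (1.5 rounds away from zero).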
+ """ + + return _make.requantize(data, + input_scale, + input_zero_point, + output_scale, + output_zero_point, + rounding, + out_dtype) diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 7dcfd5cb4b7f..3ccfff0c3463 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, b } +static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { + static const Op& op = Op::Get("where"); + return CallNode::make(op, {condition, x, y}); +} + +static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("greater_equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr Full(Expr fill_value, + Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("full"); + return CallNode::make(op, {fill_value}, Attrs(attrs), {}); +} + Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc new file mode 100644 index 000000000000..04f7e80d5c64 --- /dev/null +++ b/src/relay/qnn/op/requantize.cc @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file requantize.cc + * \brief QNN requantize operator. + */ + +#include +#include +#include +#include "../../pass/pattern_util.h" +#include "../util.h" + +namespace tvm { +namespace relay { +namespace qnn { + +TVM_REGISTER_NODE_TYPE(RequantizeAttrs); + +// Lowering of qnn.requantize op + +/* + * \brief Convert FP32 representation into fixed point representation. + * \param double_multplier The input FP32 number. + * \return The pair of multiplier and shift for fixed point representation. + * \note Converts a floating point number so that it can be represented by + * integers. The representation is + * float_number = (significand) * 2^(exponent) + * + * The significand is a number between 0.5 and 1. This is represented by + * an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit + * from the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + * + * Credit to TFLite reference implementation. + */ +std::pair GetFixedPointMultiplierShift(double double_multiplier) { + int32_t significand, exponent; + if (double_multiplier == 0.) { + significand = 0; + exponent = 0; + return std::make_pair(significand, exponent); + } + + // Get the significand and exponent. 
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
new file mode 100644
index 000000000000..04f7e80d5c64
--- /dev/null
+++ b/src/relay/qnn/op/requantize.cc
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file requantize.cc
+ * \brief QNN requantize operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
+
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multiplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ *       integers. The representation is
+ *
+ *           float_number = (significand) * 2^(exponent)
+ *
+ *       The significand is a number between 0.5 and 1. This is represented by
+ *       an integer number. For example, if it is int32, then the decimal
+ *       point exists between bit 31 and bit 30 from the LSB (i.e., between
+ *       the first and second bit from the left).
+ *
+ *       Some examples are
+ *           0.25  = (0.5) * 2^(-1)
+ *           0.125 = (0.5) * 2^(-2)
+ *
+ *       Credit to the TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
+  int32_t significand, exponent;
+  if (double_multiplier == 0.) {
+    significand = 0;
+    exponent = 0;
+    return std::make_pair(significand, exponent);
+  }
+
+  // Get the significand and exponent.
+  double significand_d = std::frexp(double_multiplier, &exponent);
+
+  // Convert the double significand to an int significand, i.e., convert into
+  // an integer where the decimal point is between bits 31 and 30. This is
+  // done by multiplying the double value with 2^31 and then casting to int.
+  significand_d = std::round(significand_d * (1ll << 31));
+  auto significand_int64 = static_cast<int64_t>(significand_d);
+  CHECK_LE(significand_int64, (1ll << 31));
+  if (significand_int64 == (1ll << 31)) {
+    significand_int64 /= 2;
+    ++exponent;
+  }
+  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
+  significand = static_cast<int32_t>(significand_int64);
+  return std::make_pair(significand, exponent);
+}
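The same decomposition can be sanity-checked from Python (a sketch using the
standard library; math.frexp mirrors std::frexp):

    import math

    def get_fixed_point_multiplier_shift(double_multiplier):
        if double_multiplier == 0.0:
            return 0, 0
        significand_d, exponent = math.frexp(double_multiplier)
        significand = int(round(significand_d * (1 << 31)))
        if significand == (1 << 31):
            significand //= 2
            exponent += 1
        return significand, exponent

    # 1/16 = 0.5 * 2^(-3): half-scale significand (2^30), shift of -3.
    print(get_fixed_point_multiplier_shift(1.0 / 16.0))  # (1073741824, -3)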
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to the requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note Requantization using only integer computation. Here, the computation
+ *       is converted to a fixed point computation by computing an output
+ *       multiplier and shift. This is useful if the target device does not
+ *       support floating point computations, or supports them only at a very
+ *       high cost.
+ *
+ *       The original computation is scale_fp32 * quantized_tensor. To convert
+ *       this into an integer computation, the multiplication with the fp32
+ *       scalar can be replaced by multiplication with an int value and then
+ *       right shifting the result. This approximates the floating point
+ *       computation with a fixed point computation.
+ *
+ *       The whole computation can be broken down into the following steps
+ *       1) Calculate the integer multiplier and integer shift.
+ *       2) Subtract the input integer zero point.
+ *       3) Multiply the fixed point multiplier with the quantized tensor.
+ *       4) Round the result.
+ *       5) Right shift the result.
+ *       6) Add the output zero point.
+ *       7) Cast to the out_dtype.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape) {
+  double double_multiplier = param->input_scale / param->output_scale;
+
+  // Choose int64 as the high precision datatype. This avoids overflow in
+  // the multiplication of two int32 values.
+  DataType hp_dtype = Int(64);
+
+  // 1) Calculate the integer multiplier and integer shift.
+  int32_t fixed_point_multiplier, shift;
+  std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+
+  // 2) Subtract the input_zero_point.
+  auto tensor = Cast(input_tensor, hp_dtype);
+  if (param->input_zero_point != 0) {
+    auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
+    tensor = Subtract(tensor, input_zp);
+  }
+
+  // 3) Multiply by the fixed point multiplier.
+  if (left_shift != 0) {
+    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+  }
+  // Perform the multiplication in higher precision. The scalar is a fixed
+  // point value of int32 where the decimal point is between bits 31 and 30.
+  // After multiplying with the input tensor, the result is in int64 where the
+  // decimal point sits between bits 31 and 30 (counting from the right, the
+  // rightmost bit being bit 0). The computation is performed in higher
+  // precision to avoid overflow in multiplying two int32 values.
+  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+  auto multiplied_t = Multiply(tensor, scalar);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal
+  // point sits. As we will be right shifting the multiplied_t, we need to
+  // first calculate the total_right_shift.
+  int total_right_shift = right_shift + 31;
+  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+
+  tensor = multiplied_t;
+  Expr round_scalar;
+  if (param->rounding == "UPWARD") {
+    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+  } else if (param->rounding == "TONEAREST") {
+    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+    auto zero = MakeConstantScalar(hp_dtype, 0);
+    auto zero_t = Full(zero, input_shape, hp_dtype);
+    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  }
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+
+  // 6) Add the output zero point.
+  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+
+  // 7) Clip to the out_dtype min/max and cast.
+  auto q_min = GetQmin(param->out_dtype);
+  auto q_max = GetQmax(param->out_dtype);
+  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  return Cast(clipped_t, param->out_dtype);
+}
+
+/*
+ * \brief Forward rewrite the requantize op.
+ * \param attrs The op attributes.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of the call arguments.
+ * \return The sequence of Relay ops for the requantize op.
+ * \note Lowering of the requantize operation. The requantize operator
+ *       converts one quantized tensor to another quantized tensor. For the
+ *       output tensor, we are provided with the output scale and zero point.
+ *       The computation looks like this
+ *
+ *           Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+ */
+Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
+                        const Array<Type>& arg_types) {
+  CHECK_EQ(new_args.size(), 1);
+  auto& quantized_data = new_args[0];
+  const auto* param = attrs.as<RequantizeAttrs>();
+  CHECK(param != nullptr);
+
+  // Find the input shape.
+  CHECK_EQ(arg_types.size(), 1);
+  auto input_dtype = arg_types[0];
+  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
+  CHECK(input_tensor_type != nullptr) << "Type information missing."
+                                      << " Please run the infer_type pass.";
+  Array<IndexExpr> input_shape = input_tensor_type->shape;
+
+  // Check rounding validity.
+  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
+      << "QNN requantize supports two rounding modes - UPWARD and "
+      << "TONEAREST";
+  return RequantizeLower(quantized_data, param, input_shape);
+}
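The integer-only pipeline above can be mirrored in NumPy to sanity-check the
lowering (an illustrative reference for UPWARD rounding with int8 output; the
shift > 0 branch, i.e., upscaling, is omitted for brevity):

    import math
    import numpy as np

    def requantize_ref(q_in, in_scale, in_zp, out_scale, out_zp):
        significand_d, shift = math.frexp(in_scale / out_scale)
        m = int(round(significand_d * (1 << 31)))    # fixed point multiplier
        right_shift = -shift if shift < 0 else 0
        total = right_shift + 31
        t = (q_in.astype(np.int64) - in_zp) * m      # steps 2 and 3
        t = t + (1 << (total - 1))                   # step 4, UPWARD rounding
        t = t >> total                               # step 5
        t = t + out_zp                               # step 6
        return np.clip(t, -128, 127).astype(np.int8) # step 7, int8 bounds

    data = np.arange(0, 32, 1).astype(np.int32)
    print(requantize_ref(data, 1.0, 0, 16.0, 0))  # eight 0s, sixteen 1s, eight 2s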
+
+/*
+ * \brief Infer shape function of the Requantize op.
+ * \param types The types of input args.
+ * \param num_inputs The number of inputs.
+ * \param attrs The op attributes.
+ * \param reporter The type reporter that sets the dtype and shapes.
+ * \return True if the infer shape succeeded.
+ */
+bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto in_dtype = data->dtype;
+  CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
+      << "Input type should be an integer but was " << in_dtype;
+
+  const Array<IndexExpr> oshape = data->shape;
+  // Assign the output type.
+  const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
+  auto out_dtype = param->out_dtype;
+  CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
+      << "Output type should be an integer but was " << out_dtype;
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+// Positional relay function to create the qnn requantize operator,
+// used by the frontend FFI.
+Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
+                    int32_t output_zero_point, std::string rounding, DataType out_dtype) {
+  auto attrs = make_node<RequantizeAttrs>();
+  attrs->input_scale = input_scale;
+  attrs->input_zero_point = input_zero_point;
+  attrs->output_scale = output_scale;
+  attrs->output_zero_point = output_zero_point;
+  attrs->rounding = std::move(rounding);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("qnn.requantize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.requantize")
+.describe(R"code(Requantize operator.
+The requantize operator converts one quantized tensor to another quantized
+tensor. For the output tensor, we are provided with the output scale and zero
+point. The computation looks like this
+
+Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.RequantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The quantized input tensor.")
+.set_support_level(11)
+.add_type_rel("Requantize", RequantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.requantize")
+.set_body_typed(MakeRequantize);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
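Since the lowering hangs off FTVMLegalize, the op disappears after the
Legalize pass. A quick way to inspect the generated op sequence (sketch):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8,), dtype="int32")
    y = relay.qnn.op.requantize(x, input_scale=1.0, input_zero_point=0,
                                output_scale=16.0, output_zero_point=0,
                                out_dtype="int8")
    mod = relay.Module.from_expr(relay.Function([x], y))
    # Prints the lowered cast/multiply/where/right_shift/clip sequence.
    print(relay.transform.Legalize()(mod))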
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
new file mode 100644
index 000000000000..1ada7ecd070e
--- /dev/null
+++ b/src/relay/qnn/util.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/util.h
+ * \brief Utility methods needed by quantized ops that can be shared.
+ */
+
+#ifndef TVM_RELAY_QNN_UTIL_H_
+#define TVM_RELAY_QNN_UTIL_H_
+
+#include <tvm/expr.h>
+#include <tvm/relay/expr.h>
+#include <limits>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+static inline int32_t GetQmin(const DataType& dtype) {
+  CHECK_LE(dtype.bits(), 32)
+      << "QNN ops support int32 or lower precision";
+  if (dtype.is_int()) {
+    auto* min_value = as_const_int(dtype.min());
+    CHECK(min_value != nullptr);
+    return static_cast<int32_t>(min_value[0]);
+  } else if (dtype.is_uint()) {
+    auto* min_value = as_const_uint(dtype.min());
+    CHECK(min_value != nullptr);
+    return static_cast<int32_t>(min_value[0]);
+  } else {
+    LOG(FATAL) << "Type not supported " << dtype;
+    return -1;  // To hide the warning.
+  }
+}
+
+static inline int32_t GetQmax(const DataType& dtype) {
+  CHECK_LE(dtype.bits(), 32)
+      << "QNN ops support int32 or lower precision";
+  if (dtype.is_int()) {
+    auto* max_value = as_const_int(dtype.max());
+    CHECK(max_value != nullptr);
+    return static_cast<int32_t>(max_value[0]);
+  } else if (dtype.is_uint()) {
+    auto* max_value = as_const_uint(dtype.max());
+    CHECK(max_value != nullptr);
+    return static_cast<int32_t>(max_value[0]);
+  } else {
+    LOG(FATAL) << "Type not supported " << dtype;
+    return -1;  // To hide the warning.
+  }
+}
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_QNN_UTIL_H_
diff --git a/tests/python/relay/test_qnn_requantize.py b/tests/python/relay/test_qnn_requantize.py
new file mode 100644
index 000000000000..cd478fb5ba22
--- /dev/null
+++ b/tests/python/relay/test_qnn_requantize.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
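For reference, GetQmin/GetQmax compute the same bounds that NumPy exposes via
iinfo (an illustrative check of the expected values):

    import numpy as np

    for dtype in ("int8", "uint8", "int32"):
        info = np.iinfo(dtype)
        print(dtype, info.min, info.max)  # e.g. int8 -128 127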
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.testing import create_workload
+from tvm.contrib import graph_runtime
+
+roundings = ["UPWARD", "TONEAREST"]
+
+def run_infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def test_requantize():
+    def verify(mod, goldens):
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            golden_data, golden_output = goldens
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input("quantized_data", golden_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            np.testing.assert_equal(res, golden_output)
+
+    def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
+                input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+        quantized_data = relay.var("quantized_data", shape=data_shape,
+                                   dtype=data_dtype)
+        mod = relay.qnn.op.requantize(
+            quantized_data,
+            input_scale=input_scale,
+            input_zero_point=input_zero_point,
+            output_scale=output_scale,
+            output_zero_point=output_zero_point,
+            rounding=rounding,
+            out_dtype=out_dtype)
+
+        mod = relay.Function(relay.analysis.free_vars(mod), mod)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.transform.Legalize()(mod)
+        return mod
+
+    def same_scale_test():
+        # Same scales, everything within range.
+        golden_data = np.arange(-100, 100, 1).astype('int32')
+        golden_output = golden_data
+
+        for rounding in roundings:
+            mod = get_mod(data_shape=(200, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=0.5,
+                          output_scale=0.5,
+                          rounding=rounding)
+            verify(mod, (golden_data, golden_output))
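Why the identity expectation holds (an illustrative check, not part of the
test): with equal scales the multiplier is exactly 1.0 (significand 2^30,
shift 1 from frexp), so the lowered graph reduces to the zero-point
adjustments plus the final clip, which is a no-op for data inside int8 range.

    import numpy as np

    golden_data = np.arange(-100, 100, 1).astype('int32')
    np.testing.assert_equal(golden_data,
                            np.clip(golden_data, -128, 127).astype('int8'))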
+    def downscale_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values.
+            # 8 corresponds to 0.5, resulting in 1.
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values.
+            # -8 corresponds to -0.5. For UPWARD, this is 0.
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+            else:
+                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try a different scale.
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=1,
+                          output_scale=4,
+                          rounding=rounding)
+
+            # Try positive values.
+            # 2 corresponds to 0.5, resulting in 1.
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values.
+            # -2 corresponds to -0.5. For UPWARD, this is 0.
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                          [3, 4, 4, 4, 4, 4, 4, 4, 1])
+            else:
+                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                          [2, 4, 4, 4, 4, 4, 4, 4, 2])
+            verify(mod, (golden_data, golden_output))
+
+            # Try uint8 out_dtype.
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='uint8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values.
+            # 8 corresponds to 0.5, resulting in 1.
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try uint8 in_dtype and uint8 out_dtype.
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='uint8',
+                          out_dtype='uint8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values.
+            # 8 corresponds to 0.5, resulting in 1.
+            golden_data = np.arange(0, 32, 1).astype('uint8')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+    def upscale_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=2,
+                          output_scale=1,
+                          rounding=rounding)
+
+            # Try positive values. Output values should be doubled, as the
+            # scale ratio is 2; no midpoints arise, so rounding is irrelevant.
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.multiply(2, golden_data)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values. These should also be doubled.
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            golden_output = np.multiply(2, golden_data)
+            verify(mod, (golden_data, golden_output))
+
+    def saturation_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(16, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=0.5,
+                          output_scale=0.5,
+                          rounding=rounding)
+            golden_data = np.arange(0, 16, 1).astype('int32')
+            golden_data = np.add(120, golden_data)
+            output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+                               127, 127, 127, 127, 127, 127, 127, 127])
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative numbers.
+            golden_data = np.arange(0, -16, -1).astype('int32')
+            golden_data = np.add(-120, golden_data)
+            output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+                               -128, -128, -128, -128, -128, -128, -128, -128])
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
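The saturation goldens follow directly from the int8 bounds: with equal
scales the values pass through unchanged and only the final clip matters
(an illustrative one-liner):

    import numpy as np

    golden_data = 120 + np.arange(0, 16, 1).astype('int32')
    print(np.clip(golden_data, -128, 127))  # 120..127, then 127 repeated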
+    def zero_point_test():
+        # Output zero point.
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          output_zero_point=1,
+                          rounding=rounding)
+
+            # Try positive values.
+            # 8 corresponds to 0.5, resulting in 1.
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            golden_output = np.add(1, golden_output)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values.
+            # -40 corresponds to -2.5. For UPWARD, this rounds to -2;
+            # for TONEAREST, to -3.
+            golden_data = np.arange(-32, -64, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+            else:
+                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+            golden_output = np.add(1, golden_output)
+            verify(mod, (golden_data, golden_output))
+
+        # Input zero point.
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          input_zero_point=16,
+                          rounding=rounding)
+
+            # Try positive values.
+            golden_data = np.arange(32, 64, 1).astype('int32')
+            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+            golden_output = np.subtract(golden_output, 1)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values.
+            golden_data = np.arange(-32, -64, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+            else:
+                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+            golden_output = np.subtract(golden_output, 1)
+            verify(mod, (golden_data, golden_output))
+
+    same_scale_test()
+    downscale_test()
+    upscale_test()
+    saturation_test()
+    zero_point_test()
+
+if __name__ == "__main__":
+    test_requantize()
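For reference, the test harness above boils down to this end-to-end recipe
(a sketch using the same era of APIs as the test file):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    x = relay.var("quantized_data", shape=(32,), dtype="int32")
    y = relay.qnn.op.requantize(x, input_scale=1.0, input_zero_point=0,
                                output_scale=16.0, output_zero_point=0,
                                rounding="TONEAREST", out_dtype="int8")
    mod = relay.Module.from_expr(relay.Function([x], y))
    mod = relay.transform.Legalize()(mod)

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, "llvm", params=None)
    rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
    rt_mod.set_input("quantized_data", np.arange(0, 32, 1).astype("int32"))
    rt_mod.run()
    print(rt_mod.get_output(0).asnumpy())  # eight 0s, sixteen 1s, eight 2s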