diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 61c9b36e1ffd..1361fd361473 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -198,6 +198,16 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.contrib.adaptive_avg_pool2d +**Level 11: QNN Dialect Operators** + +This level supports quantized operators present in the QNN dialect. + +.. autosummary:: + :nosignatures: + + tvm.relay.qnn.op.requantize + + Level 1 Definitions ------------------- .. autofunction:: tvm.relay.log @@ -332,3 +342,8 @@ Level 10 Definitions .. autofunction:: tvm.relay.nn.batch_matmul .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d + + +Level 11 Definitions +-------------------- +.. autofunction:: tvm.relay.qnn.op.requantize diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h new file mode 100644 index 000000000000..d11daf42858f --- /dev/null +++ b/include/tvm/relay/qnn/attrs.h @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/qnn/attrs.h + * \brief Auxiliary attributes for qnn operators. + */ + +#ifndef TVM_RELAY_QNN_ATTRS_H_ +#define TVM_RELAY_QNN_ATTRS_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace qnn { + +/*! \brief Attribute for requantize operator */ +struct RequantizeAttrs : public tvm::AttrsNode { + double input_scale; + int32_t input_zero_point; + double output_scale; + int32_t output_zero_point; + std::string rounding; + DataType out_dtype; + + TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { + TVM_ATTR_FIELD(input_scale) + .describe("The scale of the input tensor."); + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); + TVM_ATTR_FIELD(output_scale) + .describe("The scale of the output tensor."); + TVM_ATTR_FIELD(output_zero_point) + .describe("The zero point of the output tensor."); + TVM_ATTR_FIELD(rounding).set_default("AWAY_FROM_ZERO") + .describe("Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - UPWARD" + "or AWAY_FROM_ZERO. Both modes behave exactly same except at the" + "midpoints between the two representable values. At the midpoint," + "UPWARD rounds towards positive infinity (for example -1.5 will be" + "rounded to -1). AWAY_FROM_ZERO is the standard rounding where the" + "value is rounded away from zero at midpoints (for example, -1.5" + "rounds to -2). More context can be found at following gblic manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html." 
+ "FE_UPWARD corresponds to UPWARD here and FE_TONEAREST corresponds" + "to AWAY_FROM_ZERO rounding mode."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + +/*! \brief Attributes for quantized dense operator */ +struct QDenseAttrs : public tvm::AttrsNode { + IndexExpr units; + DataType out_dtype; + // Quantization related attributes. + int32_t input_zero_point; + int32_t kernel_zero_point; + + TVM_DECLARE_ATTRS(QDenseAttrs, "relay.attrs.QDenseAttrs") { + TVM_ATTR_FIELD(units) + .describe("Number of hidden units of the dense transformation."); + + TVM_ATTR_FIELD(out_dtype) + .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); + TVM_ATTR_FIELD(kernel_zero_point) + .describe("The zero point of the kernel tensor."); + } +}; + +} // namespace qnn +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_QNN_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index e94ef411d29d..0c45961b2708 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -51,6 +51,9 @@ from . import backend from . import quantize +# Dialects +from . import qnn + from .scope_builder import ScopeBuilder # Span diff --git a/python/tvm/relay/qnn/__init__.py b/python/tvm/relay/qnn/__init__.py new file mode 100644 index 000000000000..fa888d7ce7dd --- /dev/null +++ b/python/tvm/relay/qnn/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""QNN dialect operators and IR passes.""" +from __future__ import absolute_import as _abs +from . import op +from . import transform diff --git a/python/tvm/relay/qnn/_transform.py b/python/tvm/relay/qnn/_transform.py new file mode 100644 index 000000000000..e2ff6f9ed652 --- /dev/null +++ b/python/tvm/relay/qnn/_transform.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#pylint: disable=unused-argument +"""Internal module for quantization.""" +from __future__ import absolute_import +from tvm._ffi.function import _init_api + +_init_api("relay.qnn._transform", __name__) diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py new file mode 100644 index 000000000000..e9adfa783f93 --- /dev/null +++ b/python/tvm/relay/qnn/op/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .qnn import * diff --git a/python/tvm/relay/qnn/op/_make.py b/python/tvm/relay/qnn/op/_make.py new file mode 100644 index 000000000000..07b3dd154760 --- /dev/null +++ b/python/tvm/relay/qnn/op/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +from ...._ffi.function import _init_api + +_init_api("relay.qnn.op._make", __name__) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py new file mode 100644 index 000000000000..9b500d71fa4e --- /dev/null +++ b/python/tvm/relay/qnn/op/qnn.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
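+# Note on FFI wiring (as set up in this patch): the `_make` module imported in
+# this file is populated at import time by `_init_api("relay.qnn.op._make", __name__)`,
+# which binds the C++ functions registered via
+# TVM_REGISTER_API("relay.qnn.op._make.requantize") and
+# TVM_REGISTER_API("relay.qnn.op._make.dense"). The resulting call chain is
+#   relay.qnn.op.requantize(...) -> _make.requantize(...) -> MakeRequantize(...) (C++)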
+#pylint: disable=invalid-name +"""QNN dialect operators.""" + +from __future__ import absolute_import as _abs +from . import _make + +def requantize(data, + input_scale, + input_zero_point, + output_scale, + output_zero_point, + rounding="AWAY_FROM_ZERO", + out_dtype="int8"): + r"""Requantized operator. + + The requantize operator converts one quantized tensor representation to + another quantized tensor representation. For the output tensor, we are + provided with output scale and zero point. The computation is as follows + + Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) + + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + input_scale: float + The quantization scale for the input tensor. + + input_zero_point: int + The zero point of the input tensor. + + output_scale: float + The quantization scale for the output tensor. + + output_zero_point: int + The zero point of the output tensor. + + rounding : string, optional + Defines the rounding direction when the value is midway between two + representable values. + + out_dtype : str, optional + Specifies the output data type. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + return _make.requantize(data, + input_scale, + input_zero_point, + output_scale, + output_zero_point, + rounding, + out_dtype) + +def quantized_dense(data, weight, input_zero_point, kernel_zero_point, units=None, out_dtype="int32"): + """Dense operator. + Applies a linear transformation + + .. math:: + + `Y = X * W` + + Parameters + ---------- + data : tvm.relay.Expr + The quantied input data to the operator. + + weight : tvm.relay.Expr + The quantized weight expressions. + + units : int, optional + Number of hidden units of the dense transformation. + + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.dense(data, weight, units, input_zero_point, kernel_zero_point, out_dtype) diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py new file mode 100644 index 000000000000..6ca456b4fb81 --- /dev/null +++ b/python/tvm/relay/qnn/transform.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +"""QNN Dialect transformation passes.""" +from __future__ import absolute_import + +from . import _transform + +def QnnLower(): + """ + Rewrites the high-level quantized ops into low-level exisiting Relay ops. + + Returns + ------- + Pass : tvm.relay.transform.Pass + The optmized pas. 
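+
+    Examples
+    --------
+    A minimal usage sketch, mirroring the tests added in this patch
+    (``func`` stands for any Relay function containing qnn ops)::
+
+        mod = relay.Module.from_expr(func)
+        mod = relay.qnn.transform.QnnLower()(mod)  # qnn.* ops rewritten into existing Relay ops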
+ """ + return _transform.QnnLower() diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index c0f36bfa2915..8a35f699e65a 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -32,7 +32,7 @@ #include #include #include -#include "../type_relations.h" +#include "nn.h" #include "../../pass/alter_op_layout.h" #include "../op_common.h" @@ -102,45 +102,6 @@ RELAY_REGISTER_OP("nn.bias_add") // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); - -bool DenseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - const DenseAttrs* param = attrs.as(); - CHECK(param != nullptr); - - CHECK(static_cast(data->shape.size()) != 0); - - Array oshape = data->shape; - if (param->units.defined()) { - Array dshape = data->shape; - // validate the weight shape is proper if defined - // Assign weight type - Array wshape({param->units, dshape[dshape.size() - 1]}); - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); - oshape.Set((oshape.size() - 1), param->units); - } else { - if (weight == nullptr) return false; - Array wshape = weight->shape; - oshape.Set((oshape.size() - 1), wshape[0]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); - return true; -} - - // Positional relay function to create dense operator used by frontend FFI. Expr MakeDense(Expr data, Expr weight, @@ -171,7 +132,7 @@ RELAY_REGISTER_OP("nn.dense") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") .set_support_level(1) -.add_type_rel("Dense", DenseRel); +.add_type_rel("Dense", DenseRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h new file mode 100644 index 000000000000..41a702cd2959 --- /dev/null +++ b/src/relay/op/nn/nn.h @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file nn.h + * \brief Property def of nn operators that need to be shared by quantized and unquantized ops. 
+ */ + +#ifndef TVM_NN_H +#define TVM_NN_H + +#include +#include +#include "../type_relations.h" +#include "../../qnn/util.h" +#include + +namespace tvm { +namespace relay { + +// relay.nn.dense +enum DenseType { + kUnquantizedDense, + kQuantizedDense +}; + +template +inline bool DenseRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + const auto* param = attrs.as(); + CHECK(param != nullptr); + + CHECK(static_cast(data->shape.size()) != 0); + if(mode == DenseType::kQuantizedDense) { + CHECK(IsValidOpInputType(qnn::QuantizeOpType::QuantizedDense, data->dtype)) + << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; + CHECK(IsValidOpInputType(qnn::QuantizeOpType::QuantizedDense, weight->dtype)) + << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype; + CHECK(data->dtype == weight->dtype) << "Weight and kernel dtypes do not match"; + CHECK(IsValidOpOutputType(qnn::QuantizeOpType::QuantizedDense, param->out_dtype)) + << "Expected quantized dense type(int32, int16) for output but was " << param->out_dtype; + } + Array oshape = data->shape; + if (param->units.defined()) { + Array dshape = data->shape; + // validate the weight shape is proper if defined + // Assign weight type + Array wshape({param->units, dshape[dshape.size() - 1]}); + reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + oshape.Set((oshape.size() - 1), param->units); + } else { + if (weight == nullptr) return false; + Array wshape = weight->shape; + oshape.Set((oshape.size() - 1), wshape[0]); + } + + DataType out_dtype = param->out_dtype; + if(mode == DenseType::kUnquantizedDense) { + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + } + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +} // namespace relay +} // namespace tvm + +#endif //TVM_NN_H diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5c303905968e..b99c213f0f80 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -33,6 +33,7 @@ #include #include #include +#include #include @@ -373,6 +374,37 @@ inline Expr Copy(Expr data) { } +static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { + static const Op& op = Op::Get("where"); + return CallNode::make(op, {condition, x, y}); +} + +static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("greater_equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr Full(Expr fill_value, + Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("full"); + return CallNode::make(op, {fill_value}, Attrs(attrs), {}); +} + +inline Expr Dense(Expr data, + Expr weight, + IndexExpr units, + DataType out_dtype) { + auto attrs = make_node(); + attrs->units = units; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("nn.dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/qnn/op/nn/nn.cc b/src/relay/qnn/op/nn/nn.cc new file mode 100644 index 000000000000..a41ec8621ef6 --- 
/dev/null +++ b/src/relay/qnn/op/nn/nn.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file nn.cc + * \brief Property def of qauntized nn operators. + */ + +#include +#include "../../../op/nn/nn.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.qnn.dense +TVM_REGISTER_NODE_TYPE(QDenseAttrs); + +// Positional relay function to create quantized dense operator used by frontend FFI. +Expr MakeQuantizedDense(Expr data, + Expr weight, + IndexExpr units, + int32_t input_zero_point, + int32_t kernel_zero_point, + DataType out_dtype) { + auto attrs = make_node(); + attrs->units = units; + attrs->out_dtype = out_dtype; + attrs->input_zero_point = input_zero_point; + attrs->kernel_zero_point = kernel_zero_point; + static const Op& op = Op::Get("qnn.dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.qnn.op._make.dense") +.set_body_typed(MakeQuantizedDense); + +RELAY_REGISTER_OP("qnn.dense") + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + +- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, unit8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.QDenseAttrs") +.set_num_inputs(2) +.add_argument("data", "quantized nD Tensor", "Input data.") +.add_argument("weight", "quantized 2D Tensor", "Weight matrix.") +.set_support_level(10) +.add_type_rel("QDense", DenseRel); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc new file mode 100644 index 000000000000..843ae9ede92d --- /dev/null +++ b/src/relay/qnn/op/requantize.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file requantize.cc + * \brief QNN requantize operator. 
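+ *
+ * The op implements
+ *   Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
+ * As a hand-worked sketch (values chosen for illustration, not taken from the
+ * tests): with input_scale = 0.5, output_scale = 0.25 and both zero points 0,
+ * an input value Q_input = 6 requantizes to Q_output = 0 + (0.5 / 0.25) * (6 - 0) = 12.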
+ */ + +#include +#include +#include +#include "../util.h" + +namespace tvm { +namespace relay { +namespace qnn { + +TVM_REGISTER_NODE_TYPE(RequantizeAttrs); + +/* + * \brief Infer shape function of Requantize op. + * \param types The types of input args. + * \param num_inputs The number of inputs. + * \param attrs The op attributes. + * \param reporter The type reporter that sets the dtype and shapes. + * \return True if the infer shape succeeded. + */ +bool RequantizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + const auto input_dtype = data->dtype; + CHECK(IsValidOpInputType(QuantizeOpType::Requantize, input_dtype)) + << "Input type should be an integer but was " << input_dtype; + + const Array oshape = data->shape; + // assign output type + const RequantizeAttrs* param = attrs.as(); + CHECK(IsValidOpOutputType(QuantizeOpType::Requantize, param->out_dtype)) + << "Output type should be an integer but was " << param->out_dtype; + reporter->Assign(types[1], TensorTypeNode::make(oshape, param->out_dtype)); + return true; +} + +// Positional relay function to create qnn requantize operator +// used by frontend FFI. +Expr MakeRequantize(Expr data, + double input_scale, + int32_t input_zero_point, + double output_scale, + int32_t output_zero_point, + std::string rounding, + DataType out_dtype) { + auto attrs = make_node(); + attrs->input_scale = std::move(input_scale); + attrs->input_zero_point = std::move(input_zero_point); + attrs->output_scale = std::move(output_scale); + attrs->output_zero_point = std::move(output_zero_point); + attrs->rounding = std::move(rounding); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("qnn.requantize"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.requantize") +.describe(R"code(Requantize operator. + +The requantize operator converts one quantized tensor to another quantized +tensor. For the output tensor, we are provided with output scale and zero +point. The computation looks like this + +Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.RequantizeAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The quantized input tensor.") +.set_support_level(11) +.add_type_rel("Requantize", RequantizeRel); + +TVM_REGISTER_API("relay.qnn.op._make.requantize") +.set_body_typed(MakeRequantize); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/pass/qnn_lower.cc b/src/relay/qnn/pass/qnn_lower.cc new file mode 100644 index 000000000000..1ec8dbbf3ecd --- /dev/null +++ b/src/relay/qnn/pass/qnn_lower.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file qnn_lower.cc + * \brief Lower qnn ops to a sequence of existing Relay ops. + */ + +#include +#include +#include +#include +#include "../util.h" +#include "../../pass/pattern_util.h" + +namespace tvm { +namespace relay { +namespace qnn { +/*! + * \brief namespace of qnn lower pass. + * + * Use namespace to reduce potential naming conflict. + */ +namespace qnn_lower { + +using runtime::TypedPackedFunc; + +// Lowering of qnn.requantize op + +/* + * \brief Convert FP32 representation into fixed point representation. + * \param double_multplier The input FP32 number. + * \param idtype The input datatype. + * \return The pair of multiplier and shift for fixed point representation. + * \note Converts a floating point number so that it can be represented by + * integers. The representation is + * float_number = (significand) * 2^(exponent) + * + * The significand is a number between 0.5 and 1. This is represented by + * an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit + * from the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + * + * Credit to TFLite reference implementation. + */ +std::pair GetFixedPointMultiplierShift(double double_multiplier, + const DataType& idtype) { + int significand, exponent; + if (double_multiplier == 0.) { + significand = 0; + exponent = 0; + return std::pair(significand, exponent); + } + int idtype_bits = idtype.bits(); + + // Get the significand (significand) and exponent (exponent) + double significand_d = std::frexp(double_multiplier, &exponent); + + // Convert the double significand to int significand. + significand_d = std::round(significand_d * (1ll << (idtype_bits - 1))); + auto significand_int64 = static_cast(significand_d); + CHECK_LE(significand_int64, (1ll << (idtype_bits - 1))); + if (significand_int64 == (1ll << (idtype_bits - 1))) { + significand_int64 /= 2; + ++exponent; + } + CHECK_LE(significand_int64, std::numeric_limits::max()); + significand = static_cast(significand_int64); + return std::pair(significand, exponent); +} + +/* + + * \brief Lower requantize to a sequence of ops. + * \param input_tensor The input tensor to requantize op. + * \param param The requantize op attrs. + * \param idtype The dtype of the input tensor. + * \param out_shape The output shape of the requantize op. + * \return The sequence of existing Relay ops. + * \note Requantization using only integer computation. Here, the computation is + * converted to a fixed point computation by computing output multiplier + * and shift. This is useful, if the target device does not support/have + * very expensive floating point computations. + * + * Original compuation is scale_fp32 * quantized_tensor. To convert into + * integer computation, the multiplication with fp32 scalar can be + * replaced by multiplication with an int value and then right shifting + * the result. This approximates the floating point computation with a + * fixed point computation. + * + * The whole computation this can be broken down into following steps + * 1) Calculate the integer multiplier and integer shift. + * 2) Subtract the input integer point. + * 3) Multiply the integer fixed point multiplier with quantized tensor. + * 4) Round the result. + * 5) Right shift the result. + * 6) Add the output_zero_point. 
+ * 7) Cast to the out_dtype. + */ +Expr RequantizeLower(const Expr& input_tensor, + const RequantizeAttrs* param, const DataType& idtype, + const Array& out_shape) { + + double double_multiplier = param->input_scale/param->output_scale; + + // The multiplication will be performed in higher precision. Find the dtype. + int idtype_bits = idtype.bits(); + DataType up_idtype = Int(2 * idtype_bits); + + // 1) Calculating the integer multiplier and integer shift + std::pair fixed_point_params = + GetFixedPointMultiplierShift(double_multiplier, idtype); + int fixed_point_multiplier = fixed_point_params.first; + int shift = fixed_point_params.second; + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + + // 2) Subtract the input_zero_point + auto tensor = Cast(input_tensor, up_idtype); + if (param->input_zero_point != 0) { + auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point); + tensor = Subtract(tensor, input_zp); + } + + // 3) Multiply the integer multiplier + if (left_shift != 0) { + tensor = Multiply(tensor, MakeConstantScalar(up_idtype, 1 << left_shift)); + } + // Perform the multiplication in higher precision. + // If idtype is Int(32), the scalar is a fixed point value of int32 where the + // decimal point is between bits 31 and 30. After multiplying with + // input_tensor, the result is in int64 where the decimal point is sitting + // between bits 31 and 30 (from the right, rightmost bit is bit 0). The + // computation is performed in higher precision to avoid overflow in + // multiplying two int32 values. + Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier); + auto multiplied_t = Multiply(tensor, scalar); + + // 4) Find the rounding scalar. This depends on where the final decimal point + // sits. As we will be right shifting the multiplied_t, we need to first + // calculate the total_right_shift. + int total_right_shift = right_shift + idtype_bits - 1; + + tensor = multiplied_t; + Expr round_scalar; + if (param->rounding == "UPWARD") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + round_scalar = pos_rounder; + } else if (param->rounding == "AWAY_FROM_ZERO") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + auto neg_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)) - 1); + auto pos_rounder_t = Full(pos_rounder, out_shape, up_idtype); + auto neg_rounder_t = Full(neg_rounder, out_shape, up_idtype); + + auto zero = MakeConstantScalar(up_idtype, 0); + auto zero_t = Full(zero, out_shape, up_idtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, + neg_rounder_t); + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); + + // 5) Simply right shift the result to get the final output. + auto scaled_int64_t = RightShift(tensor, + MakeConstantScalar(up_idtype, total_right_shift)); + + // 6) Add the output zero point. + auto output_zp = MakeConstantScalar(up_idtype, param->output_zero_point); + auto shifted_int64_t = Add(output_zp, scaled_int64_t); + + // 7) Clip to the out_dtype min/max. + // Find the right clip min/maxes. While clipping, it is necessary that + // clip_min and clip_max are within the dtype range of the input tensor to the + // clip operator. For example, if the input to clip operator is int8, but the + // out_dtype is uint8, we will get incorrect results, if we set max as 255. 
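+  // As a concrete instance of the case described above: for an int8 input
+  // requantized to uint8, the bounds below work out to
+  //   q_min = max(GetQmin(uint8), GetQmin(int8)) = max(0, -128)  = 0
+  //   q_max = min(GetQmax(uint8), GetQmax(int8)) = min(255, 127) = 127
+  // so the clip stays within the range representable by the int8 input.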
+ auto q_min = std::max(GetQmin(param->out_dtype), GetQmin(idtype)); + auto q_max = std::min(GetQmax(param->out_dtype), GetQmax(idtype)); + auto clipped_t = Clip(shifted_int64_t, q_min, q_max); + auto requantized_output = Cast(clipped_t, param->out_dtype); + return requantized_output; +} + +/* + * \brief Forward rewrite the requantize op. + * \param ref_call The original call that will be lowered. + * \param new_args The new mutated args to the call node. + * \param ctx The node context. + * \return The sequence of Relay ops for requantize op. + * \note Lowering of the requantize operation. The requantize operator converts + * one quantized tensor to another quantized tensor. For the output + * tensor, we are provided with output scale and zero point. The + * computation looks like this + * + * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) + */ +Expr RequantizeForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + Expr quantized_data = new_args[0]; + const auto* param = ref_call->attrs.as(); + CHECK(param != nullptr); + + // Find output shape. + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + out_shape = output_tt->shape; + + // Find input dtype. + auto ref_input_t = ref_call->args[0]->checked_type(); + auto input_tt = ref_input_t.as(); + CHECK(input_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + const auto input_dtype = input_tt->dtype; + + // Check rounding validity. + CHECK(param->rounding == "UPWARD" || param->rounding == "AWAY_FROM_ZERO") + << "QNN requantize supports two rounding modes - UPWARD and " + << "AWAY_FROM_ZERO"; + return RequantizeLower(quantized_data, param, input_dtype, out_shape); +} + +RELAY_REGISTER_OP("qnn.requantize") +.set_attr("FQnnForwardRewrite", RequantizeForwardRewrite); + +Expr QuantizedDenseForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 2); + Expr quantized_data = new_args[0]; + Expr quantized_kernel = new_args[1]; + const auto* param = ref_call->attrs.as(); + + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + //TODO: need to benchmark the performance of this lowering. 
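+  // The lowering below casts both operands to int32, folds the zero points
+  // into the operands, and then reuses the existing nn.dense op so that the
+  // accumulation happens in int32. Note that the zero points are added here,
+  // which matches the sign convention used by the test configurations in this
+  // patch (e.g. input_zero_point = -127 for the uint8 test case).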
+ Expr quantized_data_int32 = Cast(quantized_data, Int(32)); + if(param->input_zero_point != 0) { + quantized_data_int32 = Add(quantized_data_int32, MakeConstantScalar(Int(32), + param->input_zero_point)); + } + Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32)); + if(param->kernel_zero_point != 0) { + quantized_kernel_int32 = Add(quantized_kernel_int32, MakeConstantScalar(Int(32), + param->kernel_zero_point)); + } + Expr int32_dense = Dense(quantized_data_int32, + quantized_kernel_int32, + param->units, + param->out_dtype); + return int32_dense; +} + +RELAY_REGISTER_OP("qnn.dense") +.set_attr("FQnnForwardRewrite", QuantizedDenseForwardRewrite); + + +TVM_REGISTER_API("relay._qnn.qnn_lower") +.set_body_typed([](const Expr& e) { + Expr ret = ForwardRewrite(e, "FQnnForwardRewrite", nullptr, nullptr); + return ret; +}); + +Expr QnnLower(const Expr& expr) { + return ForwardRewrite(expr, "FQnnForwardRewrite", nullptr, nullptr); +} +} // namespace qnn_lower + +namespace transform { +using namespace tvm::relay::transform; +Pass QnnLower() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::qnn::qnn_lower::QnnLower(f)); + }; + return CreateFunctionPass(pass_func, 0, "QnnLower", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay.qnn._transform.QnnLower") +.set_body_typed(QnnLower); +} // namespace transform + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h new file mode 100644 index 000000000000..65988fb096c1 --- /dev/null +++ b/src/relay/qnn/util.h @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relay/qnn/util.h + * \brief Utility methods needs for quantized ops that can be shared + */ + +#ifndef TVM_RELAY_QNN_UTIL_H_ +#define TVM_RELAY_QNN_UTIL_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace qnn { + +inline bool IsInt8(const DataType& dtype) { + return dtype == Int(8); +} + +static inline bool IsQNNDataType(const DataType& dtype) { + return dtype == Int(8) || dtype == UInt(8) + || dtype == Int(16) || dtype == UInt(16); +} + +inline bool IsUint8(const DataType& dtype) { + return dtype == UInt(8); +} + +inline bool IsInt16(const DataType& dtype) { + return dtype == Int(16); +} + +inline bool IsUint16(const DataType& dtype) { + return dtype == UInt(16); +} + +inline bool IsInt32(const DataType& dtype) { + return dtype == Int(32); +} + +inline bool IsUint32(const DataType& dtype) { + return dtype == UInt(32); +} + +inline bool IsFloat32(const DataType& dtype) { + return dtype == Float(32); +} + +inline bool IsQuantizedType(const DataType& dtype) { + return IsInt8(dtype) || IsUint8(dtype) + || IsInt16(dtype) || IsUint16(dtype); +} + +enum class QuantizeOpType : uint8_t { + Quantize, + Dequantize, + Requantize, + QuantizedDense +}; + +static inline bool IsValidOpInputType(const QuantizeOpType& op_type, + const DataType& in_dtype) { + switch (op_type) { + case QuantizeOpType::Quantize: + return IsFloat32(in_dtype); + case QuantizeOpType::Dequantize: + case QuantizeOpType::QuantizedDense: + return IsQuantizedType(in_dtype); + case QuantizeOpType ::Requantize: + return in_dtype.is_int() || in_dtype.is_uint(); + default: + return false; + } +} + +static inline bool IsValidOpOutputType(const QuantizeOpType& op_type, + const DataType& out_dtype) { + switch (op_type) { + case QuantizeOpType::Quantize: + return IsQNNDataType(out_dtype); + case QuantizeOpType::Dequantize: + return out_dtype == Float(32); + case QuantizeOpType::Requantize: + return out_dtype.is_int() || out_dtype.is_uint(); + case QuantizeOpType::QuantizedDense: + return IsInt32(out_dtype) || IsInt16(out_dtype); + default: + return false; + } +} + +static inline const int32_t GetQmin(const DataType& dtype) { + CHECK_LE(dtype.bits(), 32) + << "QNN ops support uint32/int32 or lower precision"; + if (dtype.is_int()) { + auto* min_value = as_const_int(dtype.min()); + CHECK(min_value != nullptr); + return static_cast(min_value[0]); + } else if (dtype.is_uint()) { + auto* min_value = as_const_uint(dtype.min()); + CHECK(min_value != nullptr); + return static_cast(min_value[0]); + } + LOG(FATAL) << "Type not supported " << dtype; + return -1; +} + +static inline const int32_t GetQmax(const DataType& dtype) { + CHECK_LE(dtype.bits(), 32) + << "QNN ops support uint32/int32 or lower precision"; + if (dtype.is_int()) { + auto* max_value = as_const_int(dtype.max()); + CHECK(max_value != nullptr); + return static_cast(max_value[0]); + } else if (dtype.is_uint()) { + auto* max_value = as_const_uint(dtype.max()); + CHECK(max_value != nullptr); + return static_cast(max_value[0]); + } + LOG(FATAL) << "Type not supported " << dtype; + return -1; +} + +} // namespace qnn +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_QNN_UTIL_H_ diff --git a/tests/python/relay/test_qnn_quantized_fully_connected.py b/tests/python/relay/test_qnn_quantized_fully_connected.py new file mode 100644 index 000000000000..e06a9b13e1a4 --- /dev/null +++ b/tests/python/relay/test_qnn_quantized_fully_connected.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.contrib import graph_runtime + +def test_quantized_dense(): + + def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + 'input_scale': input_scale, + 'output_scale': output_scale, + 'output_zero_point': output_zero_point, + 'out_dtype': out_dtype + } + return config + + def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, + kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): + if requantize is not None: + assert bias is not None + config = { + 'quantized_data': quantized_data, + 'quantized_kernel': quantized_kernel, + 'dtype': dtype, + 'input_shape': input_shape, + 'kernel_shape': kernel_shape, + 'input_zero_point': input_zero_point, + 'kernel_zero_point': kernel_zero_point, + 'units': units, + 'output': output, + 'out_dtype': out_dtype, + 'bias': bias, + 'requantize': requantize + } + return config + + def make_uint_configuration(use_bias=False, requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = -127, -127 + in_dtype = 'uint8' + out_dtype = 'int32' if not requantize_output else 'uint8' + units = 3 + quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107, + 129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None + + if requantize_output: + assert use_bias + output = np.array([151, 152, 153, 185, 186, 187]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + def make_int_configuration(use_bias=False, requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = 1, 1 + in_dtype = 'int8' + out_dtype = 'int32' if not requantize_output else 'int8' + units = 3 + quantized_data_np = np.array([1, 
3, 5, 7, 9, 11, 13, 15, -19, -21, + 1, 3, 5, 7, 9, 11, 13, -17, 17, -21]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None + + if requantize_output: + assert use_bias + output = np.array([23, 24, 25, 57, 58, 59]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + def test_quantized_convolution(test_configuration): + in_dtype = test_configuration['dtype'] + out_dtype = test_configuration['out_dtype'] + quantized_data_name = "quantized_data" + quantized_kernel_name = "quantized_kernel" + expected_out_dtype = test_configuration['out_dtype'] + bias_name = 'bias' + quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'], + dtype=in_dtype) + quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], + dtype=in_dtype) + mod = relay.qnn.op.quantized_dense( + quantized_data, + quantized_kernel, + test_configuration['input_zero_point'], + test_configuration['kernel_zero_point'], + test_configuration['units']) + if test_configuration[bias_name] is not None: + bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype) + mod = relay.nn.bias_add(mod, bias) + if test_configuration['requantize'] is not None: + requantize_config = test_configuration['requantize'] + mod = relay.qnn.op.requantize( + mod, + input_scale=requantize_config['input_scale'], + input_zero_point=0, + output_scale=requantize_config['output_scale'], + output_zero_point=requantize_config['output_zero_point'], + out_dtype=requantize_config['out_dtype']) + expected_out_dtype = requantize_config['out_dtype'] + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = relay.Module.from_expr(mod) + mod = relay.qnn.transform.QnnLower()(mod) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input(quantized_data_name,test_configuration[quantized_data_name]) + mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name]) + if test_configuration[bias_name] is not None: + mod.set_input(bias_name, test_configuration[bias_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, test_configuration['output']) + assert res.dtype == expected_out_dtype + + def test_configurations(): + test_prams = [{'use_bias': False}, {'use_bias': True}, {'use_bias': True, 'requantize_output': True}, ] + tests = [test_quantized_convolution] + configurations = [] + for test_param in test_prams: + configurations.append(make_uint_configuration(**test_param)) + configurations.append(make_int_configuration(**test_param)) + for configuration in configurations: + 
for test in tests: + test(configuration) + + test_configurations() + +if __name__ == "__main__": + test_quantized_dense() diff --git a/tests/python/relay/test_qnn_requantize.py b/tests/python/relay/test_qnn_requantize.py new file mode 100644 index 000000000000..fa77f52e7139 --- /dev/null +++ b/tests/python/relay/test_qnn_requantize.py @@ -0,0 +1,426 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.relay.testing import create_workload +from tvm.contrib import graph_runtime + +roundings = ["UPWARD", "AWAY_FROM_ZERO"] + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = relay.transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_requantize(): + def verify(mod, goldens): + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, "llvm", params=None) + golden_data, golden_output = goldens + rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + rt_mod.set_input("quantized_data",golden_data) + rt_mod.set_input(**params) + rt_mod.run() + res = rt_mod.get_output(0).asnumpy() + np.testing.assert_equal(res, golden_output) + + def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale, + input_zero_point=0, output_zero_point=0, rounding="AWAY_FROM_ZERO"): + quantized_data = relay.var("quantized_data", shape=data_shape, + dtype=data_dtype) + mod = relay.qnn.op.requantize( + quantized_data, + input_scale=input_scale, + input_zero_point=input_zero_point, + output_scale=output_scale, + output_zero_point=output_zero_point, + rounding=rounding, + out_dtype=out_dtype) + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = relay.Module.from_expr(mod) + mod = relay.qnn.transform.QnnLower()(mod) + return mod + + def same_scale_test(): + # Have same scales, everything within range + golden_data = np.arange(-100, 100, 1).astype('int32') + golden_output = golden_data + + for rounding in roundings: + mod = get_mod(data_shape=(200, ), + data_dtype='int32', + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding) + verify(mod, (golden_data, golden_output)) + + def downscale_test(): + for rounding in roundings: + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype='int8', + input_scale=1, + output_scale=16, + rounding=rounding) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + input_scale=1, + output_scale=4, + rounding=rounding) + + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [3, 4, 4, 4, 4, 4, 4, 4, 1]) + else: + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(mod, (golden_data, golden_output)) + + # Try uint8 out_dtype + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype='uint8', + input_scale=1, + output_scale=16, + rounding=rounding) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try uint8 in_dtyope and uint8 out_dtype + mod = get_mod(data_shape=(32, ), + data_dtype='uint8', + out_dtype='uint8', + input_scale=1, + output_scale=16, + rounding=rounding) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + def upscale_test(): + for rounding in roundings: + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + input_scale=2, + output_scale=1, + rounding=rounding) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) + + def saturation_test(): + for rounding in roundings: + mod = get_mod(data_shape=(16, ), + data_dtype='int32', + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding) + golden_data = np.arange(0, 16, 1).astype('int32') + golden_data = np.add(120, golden_data) + output = np.array([120, 121, 122, 123, 124, 125, 126, 127, + 127, 127, 127, 127, 127, 127, 127, 127]) + golden_output = output + verify(mod, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype('int32') + golden_data = np.add(-120, golden_data) + output = np.array([-120, -121, -122, -123, -124, -125, -126, -127, + -128, -128, -128, -128, -128, -128, -128, -128]) + golden_output = output + verify(mod, (golden_data, golden_output)) + + def zero_point_test(): + # Output zero point + for rounding in roundings: + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype='int8', + input_scale=1, + output_scale=16, + output_zero_point=1, + rounding=rounding) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(-32, -64, -1).astype('int32') + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) + + # Input zero point + for rounding in roundings: + mod = get_mod(data_shape=(32, ), + data_dtype='int32', + out_dtype='int8', + input_scale=1, + output_scale=16, + input_zero_point=16, + rounding=rounding) + + # Try positive values + golden_data = np.arange(32, 64, 1).astype('int32') + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(-32, -64, -1).astype('int32') + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) + + same_scale_test() + downscale_test() + upscale_test() + saturation_test() + zero_point_test() + +def test_quantized_dense(): + + def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + 'input_scale': input_scale, + 'output_scale': output_scale, + 'output_zero_point': output_zero_point, + 'out_dtype': out_dtype + } + return config + + def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, + kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): + if requantize is not None: + assert bias is not None + config = { + 'quantized_data': quantized_data, + 'quantized_kernel': quantized_kernel, + 'dtype': dtype, + 'input_shape': input_shape, + 'kernel_shape': kernel_shape, + 'input_zero_point': input_zero_point, + 'kernel_zero_point': kernel_zero_point, + 'units': units, + 'output': output, + 'out_dtype': out_dtype, + 'bias': bias, + 'requantize': requantize + } + return config 
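+    # A hand-worked sketch of where the expected outputs in the configurations
+    # below come from (int8 case; zero points are folded in by addition, as in
+    # the QnnLower pass): the first data row [1, 3, ..., -21] plus zero point 1
+    # and a kernel row [1, 3, ..., 19] plus zero point 1 give a dot product of
+    # 92; adding the bias [4, 8, 12] gives [96, 100, 104]; scaling by
+    # input_scale / output_scale = 0.25 and adding output_zero_point = -1 gives
+    # [23, 24, 25].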
+
+    def make_uint_configuration(use_bias=False, requantize_output=False):
+        input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
+        input_zero_point, kernel_zero_point = -127, -127
+        in_dtype = 'uint8'
+        out_dtype = 'int32' if not requantize_output else 'uint8'
+        units = 3
+        quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107,
+                                      129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \
+            .astype(in_dtype) \
+            .reshape(input_shape)
+        quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
+                                        129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
+                                        129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \
+            .astype(in_dtype) \
+            .reshape(kernel_shape)
+        bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
+        requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None
+
+        if requantize_output:
+            assert use_bias
+            output = np.array([151, 152, 153, 185, 186, 187])
+        elif use_bias:
+            output = np.array([96, 100, 104, 232, 236, 240])
+        else:
+            output = np.array([92, 92, 92, 228, 228, 228])
+        output = output.astype(out_dtype).reshape(output_shape)
+        return make_test_configuration(quantized_data=quantized_data_np,
+                                       quantized_kernel=quantized_kernel_np,
+                                       dtype=in_dtype,
+                                       input_shape=input_shape,
+                                       kernel_shape=kernel_shape,
+                                       input_zero_point=input_zero_point,
+                                       kernel_zero_point=kernel_zero_point,
+                                       units=units,
+                                       output=output,
+                                       bias=bias,
+                                       requantize=requant_params)
+
+    def make_int_configuration(use_bias=False, requantize_output=False):
+        input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
+        input_zero_point, kernel_zero_point = 1, 1
+        in_dtype = 'int8'
+        out_dtype = 'int32' if not requantize_output else 'int8'
+        units = 3
+        quantized_data_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21,
+                                      1, 3, 5, 7, 9, 11, 13, -17, 17, -21]) \
+            .astype(in_dtype) \
+            .reshape(input_shape)
+        quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19,
+                                        1, 3, 5, 7, 9, 11, 13, 15, 17, 19,
+                                        1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \
+            .astype(in_dtype) \
+            .reshape(kernel_shape)
+        bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
+        requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None
+
+        if requantize_output:
+            assert use_bias
+            output = np.array([23, 24, 25, 57, 58, 59])
+        elif use_bias:
+            output = np.array([96, 100, 104, 232, 236, 240])
+        else:
+            output = np.array([92, 92, 92, 228, 228, 228])
+        output = output.astype(out_dtype).reshape(output_shape)
+        return make_test_configuration(quantized_data=quantized_data_np,
+                                       quantized_kernel=quantized_kernel_np,
+                                       dtype=in_dtype,
+                                       input_shape=input_shape,
+                                       kernel_shape=kernel_shape,
+                                       input_zero_point=input_zero_point,
+                                       kernel_zero_point=kernel_zero_point,
+                                       units=units,
+                                       output=output,
+                                       bias=bias,
+                                       requantize=requant_params)
+
+    def verify_quantized_dense(test_configuration):
+        in_dtype = test_configuration['dtype']
+        out_dtype = test_configuration['out_dtype']
+        quantized_data_name = "quantized_data"
+        quantized_kernel_name = "quantized_kernel"
+        expected_out_dtype = test_configuration['out_dtype']
+        bias_name = 'bias'
+        quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'],
+                                   dtype=in_dtype)
+        quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'],
+                                     dtype=in_dtype)
+        mod = relay.qnn.op.quantized_dense(
+            quantized_data,
+            quantized_kernel,
+            test_configuration['input_zero_point'],
+            test_configuration['kernel_zero_point'],
+            test_configuration['units'])
+        if test_configuration[bias_name] is not None:
+            bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype)
+            mod = relay.nn.bias_add(mod, bias)
+        if test_configuration['requantize'] is not None:
+            requantize_config = test_configuration['requantize']
+            mod = relay.qnn.op.requantize(
+                mod,
+                input_scale=requantize_config['input_scale'],
+                # The int32 accumulator coming out of dense/bias_add carries no offset,
+                # so the requantize input zero point is 0.
+                input_zero_point=0,
+                output_scale=requantize_config['output_scale'],
+                output_zero_point=requantize_config['output_zero_point'],
+                out_dtype=requantize_config['out_dtype'])
+            expected_out_dtype = requantize_config['out_dtype']
+        mod = relay.Function(relay.analysis.free_vars(mod), mod)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.qnn.transform.QnnLower()(mod)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input(quantized_data_name, test_configuration[quantized_data_name])
+        mod.set_input(quantized_kernel_name, test_configuration[quantized_kernel_name])
+        if test_configuration[bias_name] is not None:
+            mod.set_input(bias_name, test_configuration[bias_name])
+        mod.set_input(**params)
+        mod.run()
+        res = mod.get_output(0).asnumpy()
+        np.testing.assert_equal(res, test_configuration['output'])
+        assert res.dtype == expected_out_dtype
+
+    def test_configurations():
+        test_params = [{'use_bias': False},
+                       {'use_bias': True},
+                       {'use_bias': True, 'requantize_output': True}]
+        tests = [verify_quantized_dense]
+        configurations = []
+        for test_param in test_params:
+            configurations.append(make_uint_configuration(**test_param))
+            configurations.append(make_int_configuration(**test_param))
+        for configuration in configurations:
+            for test in tests:
+                test(configuration)
+
+    test_configurations()
+
+if __name__ == "__main__":
+    test_requantize()
+    test_quantized_dense()