diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
new file mode 100644
index 000000000000..9645b3cf587d
--- /dev/null
+++ b/include/tvm/relay/qnn/attrs.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/attrs.h
+ * \brief Auxiliary attributes for quantized nn operators.
+ */
+#ifndef TVM_RELAY_QNN_ATTRS_H_
+#define TVM_RELAY_QNN_ATTRS_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
+  int32_t output_zero_point;
+  double output_scale;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
+    TVM_ATTR_FIELD(out_dtype)
+      .describe("Output data type, can be one of [int8 or uint8].");
+
+    TVM_ATTR_FIELD(output_zero_point)
+      .describe("The zero_point for the activation of this op.");
+
+    TVM_ATTR_FIELD(output_scale)
+      .describe("The scale for the activation of this op.");
+  }
+};
+
+struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
+  int32_t input_zero_point;
+  double input_scale;
+
+  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
+    TVM_ATTR_FIELD(input_zero_point)
+      .describe("The zero_point for the input tensor of this op.");
+
+    TVM_ATTR_FIELD(input_scale)
+      .describe("The scale for the input tensor of this op.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_QNN_ATTRS_H_
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index dfac85bb1ed2..be78d8bdc353 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -49,6 +49,9 @@
 from . import backend
 from . import quantize
 
+# Dialects
+from . import qnn
+
 from .scope_builder import ScopeBuilder
 
 # Span
diff --git a/python/tvm/relay/qnn/__init__.py b/python/tvm/relay/qnn/__init__.py
new file mode 100644
index 000000000000..236b094a6988
--- /dev/null
+++ b/python/tvm/relay/qnn/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from . import op +from . import ir_pass diff --git a/python/tvm/relay/qnn/_qnn.py b/python/tvm/relay/qnn/_qnn.py new file mode 100644 index 000000000000..77ecc325ae18 --- /dev/null +++ b/python/tvm/relay/qnn/_qnn.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Internal module for quantization.""" + +from __future__ import absolute_import +from tvm._ffi.function import _init_api + +_init_api("relay._qnn", __name__) diff --git a/python/tvm/relay/qnn/ir_pass.py b/python/tvm/relay/qnn/ir_pass.py new file mode 100644 index 000000000000..ea272dec429a --- /dev/null +++ b/python/tvm/relay/qnn/ir_pass.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Automatic quantization toolkit.""" +from __future__ import absolute_import + +from . import _qnn + +def rewrite(expr): + """ + Rewrites the high-level quantized ops into low-level exisiting Relay ops. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + expr : tvm.relay.Expr + The output expression. + """ + return _qnn.rewrite(expr) diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py new file mode 100644 index 000000000000..f1c896489fd3 --- /dev/null +++ b/python/tvm/relay/qnn/op/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py
new file mode 100644
index 000000000000..f1c896489fd3
--- /dev/null
+++ b/python/tvm/relay/qnn/op/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""QNN dialect operators."""
+
+from __future__ import absolute_import as _abs
+from .qnn import *
diff --git a/python/tvm/relay/qnn/op/_make.py b/python/tvm/relay/qnn/op/_make.py
new file mode 100644
index 000000000000..82d5e5a9cdc3
--- /dev/null
+++ b/python/tvm/relay/qnn/op/_make.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Constructor APIs"""
+
+from ...._ffi.function import _init_api
+
+_init_api("relay.op.qnn._make", __name__)
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
new file mode 100644
index 000000000000..290fb4912a17
--- /dev/null
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""QNN dialect operators."""
+
+from __future__ import absolute_import as _abs
+from . import _make
+
+def quantize(input_data, output_zero_point, output_scale, out_dtype='int8'):
+    r"""Quantize op.
+    This operator takes float32 as input and produces quantized int8 or uint8 as output.
+    The input tensor can be of any shape. The output shape is the same as the input shape.
+
+    .. math::
+        \mbox{out}[x] = \mbox{clamp}(\mbox{round}(\mbox{input\_data}[x] / \mbox{output\_scale})
+            + \mbox{output\_zero\_point}, \mbox{out\_dtype::min}, \mbox{out\_dtype::max})
+
+    Parameters
+    ----------
+    input_data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    output_zero_point : int
+        The output zero_point.
+    output_scale : float
+        The output scale.
+    out_dtype : str, optional
+        The data type of the output tensor. Can be [int8, uint8].
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.quantize(input_data, output_zero_point, output_scale, out_dtype)
+
+def dequantize(input_data, input_zero_point, input_scale):
+    r"""Dequantize op.
+    This operator takes quantized int8 and uint8 as input and produces
+    dequantized float32 as output. The output shape is the same as the input shape.
+    The input tensor can be of any shape.
+
+    Parameters
+    ----------
+    input_data : tvm.relay.Expr
+        The input tensor to be dequantized. Can be of type [int8, uint8].
+    input_zero_point : int
+        The input zero_point.
+    input_scale : float
+        The input scale.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.dequantize(input_data, input_zero_point, input_scale)
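A worked example of the arithmetic these two ops implement, using the same scale and zero point as the unit tests at the end of this patch: with output_scale = 0.5 and output_zero_point = 127, quantizing -63.5 gives clamp(round(-63.5 / 0.5) + 127, 0, 255) = 0, and dequantizing 0 gives (0 - 127) * 0.5 = -63.5. A NumPy sketch of the same computation, for cross-checking only (it is not the Relay lowering itself):

    import numpy as np

    def quantize_ref(x, scale, zero_point, dtype):
        # clamp(round(x / scale) + zero_point, dtype_min, dtype_max)
        info = np.iinfo(dtype)
        q = np.round(x / scale).astype(np.int32) + zero_point
        return np.clip(q, info.min, info.max).astype(dtype)

    def dequantize_ref(q, scale, zero_point):
        # (q - zero_point) * scale
        return (q.astype(np.int32) - zero_point).astype(np.float32) * np.float32(scale)

    x = np.array([-63.5, 64.0], dtype=np.float32)
    q = quantize_ref(x, scale=0.5, zero_point=127, dtype='uint8')   # -> [0, 255]
    y = dequantize_ref(q, scale=0.5, zero_point=127)                # -> [-63.5, 64.0]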
+ """ + return _make.quantize(input_data, output_zero_point, output_scale, out_dtype) + +def dequantize(input_data, input_zero_point, input_scale): + r""" Dequantize op + This operator takes quantized int8 and unit8 as input and produces + dequantized float32 as output. The output shape is the same as input shape. The input + tensor can be of any shape. + Parameters + ---------- + input_data : tvm.relay.Expr + The input tensor to be dequantized. Can be of type [int8, uint8]. + input_zero_point : + The output zero_point. + input_scale: + The output scale. + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.dequantize(input_data, input_zero_point, input_scale) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 45bb62e66853..8da4e7953566 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -19,4 +19,5 @@ from __future__ import absolute_import as _abs from .quantize import * +from .rewrite import * from ._annotate import register_annotate_function diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5c303905968e..906e3193729e 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -34,7 +34,7 @@ #include #include #include - +#include namespace tvm { namespace relay { @@ -373,6 +373,26 @@ inline Expr Copy(Expr data) { } +inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { + static const Op& op = Op::Get("where"); + return CallNode::make(op, {condition, x, y}); +} + +inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("greater_equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +inline Expr Full(Expr fill_value, + Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("full"); + return CallNode::make(op, {fill_value}, Attrs(attrs), {}); +} + Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc new file mode 100644 index 000000000000..cfaff3f23755 --- /dev/null +++ b/src/relay/qnn/op/dequantize.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/dequantize.cc + * \brief Dequantize operator that converts from quantized domain to + * unquantized domain. 
+ */
+
+#include
+#include
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
+
+bool DequantizeRel(const Array<Type>& types,
+                   int num_inputs,
+                   const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto input_dtype = data->dtype;
+  CHECK(IsValidOpInputType(QuantizeOpType::Dequantize, input_dtype))
+    << "Input type should be one of the quantized types [uint8, int8] but was " << input_dtype;
+  const Array<IndexExpr> oshape = data->shape;
+  // assign output type
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));
+  return true;
+}
+
+Expr MakeDequantize(Expr data,
+                    int32_t input_zero_point,
+                    double input_scale) {
+  auto attrs = make_node<DequantizeAttrs>();
+  attrs->input_scale = input_scale;
+  attrs->input_zero_point = input_zero_point;
+  static const Op& op = Op::Get("qnn.dequantize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.dequantize")
+.describe(R"code(Dequantizes the input and produces float32 output.
+
+The input is always quantized (int8, uint8) and is converted to float32 given
+the input scale and zero_point.
+
+- **data**: Quantized tensor of any shape to dequantize. The input data must be of int8 or uint8 type.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.DequantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to dequantize.")
+.set_support_level(10)
+.add_type_rel("Dequantize", DequantizeRel);
+
+TVM_REGISTER_API("relay.op.qnn._make.dequantize")
+.set_body_typed(MakeDequantize);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/qnn/op/quantize_op.cc b/src/relay/qnn/op/quantize_op.cc
new file mode 100644
index 000000000000..b3be62742aae
--- /dev/null
+++ b/src/relay/qnn/op/quantize_op.cc
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/quantize_op.cc
+ * \brief Quantize operator that converts from the unquantized domain
+ * representation to the quantized domain representation.
+ */
+
+#include
+#include
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(QuantizeAttrs);
+
+bool QuantizeRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto input_dtype = data->dtype;
+  CHECK(IsValidOpInputType(QuantizeOpType::Quantize, input_dtype))
+    << "Input type should be float32 but was " << input_dtype;
+  const auto* param = attrs.as<QuantizeAttrs>();
+  const Array<IndexExpr> oshape = data->shape;
+  const DataType out_dtype = param->out_dtype;
+  CHECK(IsValidOpOutputType(QuantizeOpType::Quantize, out_dtype))
+    << "Output type should be one of [int8, uint8] but was " << out_dtype;
+  // assign output type
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+Expr MakeQuantize(Expr data,
+                  int32_t output_zero_point,
+                  double output_scale,
+                  DataType out_dtype) {
+  auto attrs = make_node<QuantizeAttrs>();
+  attrs->output_scale = output_scale;
+  attrs->output_zero_point = output_zero_point;
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("qnn.quantize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.quantize")
+.describe(R"code(Quantizes the input and produces quantized output.
+
+The input is a float32 tensor. Given the output scale and zero point, this op
+converts it to quantized output in int8 or uint8 format:
+
+    out = clamp(round(data / output_scale) + output_zero_point, out_dtype::min, out_dtype::max)
+
+- **data**: Tensor of any shape to quantize. The input data must be of float32 type.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.QuantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to quantize.")
+.set_support_level(10)
+.add_type_rel("Quantize", QuantizeRel);
+
+TVM_REGISTER_API("relay.op.qnn._make.quantize")
+.set_body_typed(MakeQuantize);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/qnn/pass/quantize_rewrite.cc b/src/relay/qnn/pass/quantize_rewrite.cc
new file mode 100644
index 000000000000..1da4824772ea
--- /dev/null
+++ b/src/relay/qnn/pass/quantize_rewrite.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/pass/quantize_rewrite.cc
+ * \brief Lower quantized ops to existing Relay ops.
+ */
+
+#include
+#include
+#include
+#include
+#include "../util.h"
+#include "../../pass/pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+Expr QuantizeForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  Expr data = new_args[0];
+  const auto* attrs = ref_call->attrs.as<QuantizeAttrs>();
+  const auto out_dtype = attrs->out_dtype;
+  const auto* new_tensor = data.operator->()->checked_type().as<TensorTypeNode>();
+  CHECK(new_tensor) << "Expected TensorTypeNode but was " << data.operator->()->checked_type();
+  const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
+  const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
+  const int32_t min_val = GetQmin(out_dtype);
+  const int32_t max_val = GetQmax(out_dtype);
+  auto scale_data = Cast(Round(Divide(data, scale)), Int(32));
+  // Clamp the zero-point-shifted value, i.e. std::min(std::max(unclamped, min_val), max_val).
+  auto add_zero_point = Add(scale_data, output_zero_point);
+  auto clamped_output = Clip(add_zero_point, min_val, max_val);
+  auto clamp_out_dtype = Cast(clamped_output, out_dtype);
+  return clamp_out_dtype;
+}
+
+RELAY_REGISTER_OP("qnn.quantize")
+.set_attr<FForwardRewrite>("FQuantizeForwardRewrite", QuantizeForwardRewrite);
+
+Expr DequantizeForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
+                              const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  Expr data = new_args[0];
+  const auto* attrs = ref_call->attrs.as<DequantizeAttrs>();
+  const auto* new_tensor = data.operator->()->checked_type().as<TensorTypeNode>();
+  CHECK(new_tensor) << "Expected TensorTypeNode but was " << data.operator->()->checked_type();
+  const auto input_zero_point = MakeConstantScalar(Int(32), attrs->input_zero_point);
+  const auto input_scale = MakeConstantScalar(Float(32), attrs->input_scale);
+  // Shift by the zero point in int32, then scale back to float32.
+  auto shift = Subtract(Cast(data, Int(32)), input_zero_point);
+  auto scaled_output = Multiply(Cast(shift, Float(32)), input_scale);
+  return scaled_output;
+}
+
+RELAY_REGISTER_OP("qnn.dequantize")
+.set_attr<FForwardRewrite>("FQuantizeForwardRewrite", DequantizeForwardRewrite);
+
+TVM_REGISTER_API("relay._qnn.rewrite").set_body_typed([](const Expr& e) {
+  Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr);
+  return ret;
+});
+
+}  // namespace relay
+}  // namespace tvm
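For readers more comfortable on the Python side, the graph this pass produces for qnn.quantize is equivalent to the following sketch built from base Relay ops (illustrative only; the pass constructs the same structure in C++, and the clamp bounds come from GetQmin/GetQmax for the chosen out_dtype):

    from tvm import relay

    def lowered_quantize(data, output_scale, output_zero_point, out_dtype="uint8"):
        # round(data / scale) -> int32, add the zero point, clamp to the
        # out_dtype range (here [0, 255] for uint8), then cast to out_dtype.
        scale = relay.const(output_scale, "float32")
        zero_point = relay.const(output_zero_point, "int32")
        scaled = relay.cast(relay.round(data / scale), "int32")
        clamped = relay.clip(scaled + zero_point, a_min=0, a_max=255)
        return relay.cast(clamped, out_dtype)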
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
new file mode 100644
index 000000000000..c96227c3667a
--- /dev/null
+++ b/src/relay/qnn/util.h
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/util.h
+ * \brief Utility methods needed by quantized ops that can be shared.
+ */
+
+#ifndef TVM_RELAY_QNN_UTIL_H_
+#define TVM_RELAY_QNN_UTIL_H_
+
+#include
+#include
+#include <limits>
+
+namespace tvm {
+namespace relay {
+
+inline bool IsInt8(const DataType& dtype) {
+  return dtype == Int(8);
+}
+
+inline bool IsUint8(const DataType& dtype) {
+  return dtype == UInt(8);
+}
+
+inline bool IsInt16(const DataType& dtype) {
+  return dtype == Int(16);
+}
+
+inline bool IsUint16(const DataType& dtype) {
+  return dtype == UInt(16);
+}
+
+inline bool IsInt32(const DataType& dtype) {
+  return dtype == Int(32);
+}
+
+inline bool IsUint32(const DataType& dtype) {
+  return dtype == UInt(32);
+}
+
+inline bool IsFloat32(const DataType& dtype) {
+  return dtype == Float(32);
+}
+
+inline bool IsQuantizedType(const DataType& dtype) {
+  return IsInt8(dtype) || IsUint8(dtype)
+      || IsInt16(dtype) || IsUint16(dtype);
+}
+
+enum class QuantizeOpType : uint8_t {
+  Quantize,
+  Dequantize,
+  Requantize
+};
+
+inline bool IsValidOpInputType(const QuantizeOpType& op_type,
+                               const DataType& in_dtype) {
+  switch (op_type) {
+    case QuantizeOpType::Quantize:
+      return IsFloat32(in_dtype);
+    case QuantizeOpType::Dequantize:
+      return IsQuantizedType(in_dtype);
+    default:
+      return false;
+  }
+}
+
+inline bool IsValidOpOutputType(const QuantizeOpType& op_type,
+                                const DataType& out_dtype) {
+  switch (op_type) {
+    case QuantizeOpType::Quantize:
+      return IsQuantizedType(out_dtype);
+    case QuantizeOpType::Dequantize:
+      return IsFloat32(out_dtype);
+    default:
+      return false;
+  }
+}
+
+inline const int32_t GetQmin(const DataType& dtype) {
+  if (IsInt8(dtype)) {
+    return std::numeric_limits<int8_t>::min();
+  } else if (IsUint8(dtype)) {
+    return std::numeric_limits<uint8_t>::min();
+  } else if (IsInt16(dtype)) {
+    return std::numeric_limits<int16_t>::min();
+  } else if (IsUint16(dtype)) {
+    return std::numeric_limits<uint16_t>::min();
+  } else if (IsInt32(dtype)) {
+    return std::numeric_limits<int32_t>::min();
+  } else if (IsUint32(dtype)) {
+    return std::numeric_limits<uint32_t>::min();
+  }
+  LOG(FATAL) << "Type not supported\n";
+  return -1;
+}
+
+inline const int32_t GetQmax(const DataType& dtype) {
+  if (IsInt8(dtype)) {
+    return std::numeric_limits<int8_t>::max();
+  } else if (IsUint8(dtype)) {
+    return std::numeric_limits<uint8_t>::max();
+  } else if (IsInt16(dtype)) {
+    return std::numeric_limits<int16_t>::max();
+  } else if (IsUint16(dtype)) {
+    return std::numeric_limits<uint16_t>::max();
+  } else if (IsInt32(dtype)) {
+    return std::numeric_limits<int32_t>::max();
+  } else if (IsUint32(dtype)) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+  LOG(FATAL) << "Type not supported\n";
+  return -1;
+}
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_QNN_UTIL_H_
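For reference, the clamp ranges GetQmin/GetQmax return match the standard integer limits; a quick NumPy cross-check of the values (sketch only):

    import numpy as np

    # e.g. int8 -> (-128, 127), uint8 -> (0, 255), matching GetQmin/GetQmax.
    for dtype in ("int8", "uint8", "int16", "uint16", "int32"):
        info = np.iinfo(dtype)
        print(dtype, info.min, info.max)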
diff --git a/tests/python/unittest/test_quantized_ops.py b/tests/python/unittest/test_quantized_ops.py
new file mode 100644
index 000000000000..c489ab2ba3aa
--- /dev/null
+++ b/tests/python/unittest/test_quantized_ops.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.contrib import graph_runtime
+
+# TODOs for janimesh before submitting this patch.
+# TODO - Add tests for int8 input/weight dtype.
+# TODO - opt_level=0 fails, mostly due to fusion.
+# TODO - opt_level=3 fails, likely because of the kernel layout for int8
+#        compute. Work with Rankyung to confirm, and handle it in a
+#        separate patch.
+
+def run_infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+def test_quantize_op():
+
+    def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        output_zero_point = quant_args['out_zero_point']
+        output_scale = quant_args['out_scale']
+        quantized_output = relay.qnn.op.quantize(input_data, output_zero_point=output_zero_point,
+                                                 output_scale=output_scale, out_dtype=out_dtype)
+        func = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+        func = run_infer_type(func)
+        func = relay.qnn.ir_pass.rewrite(func)
+        func = run_infer_type(func)
+        graph, lib, params = relay.build(func, "llvm", params=None)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input(input_data=in_data)
+        mod.run()
+        res = mod.get_output(0).asnumpy()
+        np.testing.assert_equal(res, verify_output_data)
+        assert res.dtype == out_dtype
+
+    def test_float32_to_uint8():
+        data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        quant_args = {"out_zero_point": 127, "out_scale": 0.5}
+        quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8',
+                             in_data=data, verify_output_data=output)
+
+    def test_float32_to_int8():
+        data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        quant_args = {"out_zero_point": -1, "out_scale": 0.5}
+        quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8',
+                             in_data=data, verify_output_data=output)
+
+    test_float32_to_uint8()
+    test_float32_to_int8()
+
+def test_dequantize_op():
+
+    def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        input_zero_point = quant_args['in_zero_point']
+        input_scale = quant_args['in_scale']
+        dequantized_output = relay.qnn.op.dequantize(input_data, input_zero_point=input_zero_point,
+                                                     input_scale=input_scale)
+        func = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
+        func = run_infer_type(func)
+        func = relay.qnn.ir_pass.rewrite(func)
+        func = run_infer_type(func)
+        graph, lib, params = relay.build(func, "llvm", params=None)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input(input_data=in_data)
+        mod.run()
+        res = mod.get_output(0).asnumpy()
+        np.testing.assert_allclose(res, verify_output_data)
+        assert res.dtype == np.float32
+
+    def test_uint8_to_float32():
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"in_zero_point": 127, "in_scale": 0.5}
+        dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
+                               verify_output_data=output)
+
+    def test_int8_to_float32():
+        data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"in_zero_point": -1, "in_scale": 0.5}
+        dequantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
+                               verify_output_data=output)
+
+    test_uint8_to_float32()
+    test_int8_to_float32()
+
+if __name__ == "__main__":
+    test_quantize_op()
+    test_dequantize_op()