From bb54d0bb095a78e0d56e030040462c996c3550b3 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Thu, 5 Sep 2019 15:59:45 -0700 Subject: [PATCH 1/6] Qnn Dense layer. --- include/tvm/relay/qnn/attrs.h | 22 +++ python/tvm/relay/qnn/op/qnn.py | 41 +++++- python/tvm/relay/qnn/transform.py | 4 +- src/relay/op/nn/convolution.h | 1 + src/relay/op/nn/nn.h | 1 + src/relay/pass/pattern_util.h | 11 ++ src/relay/qnn/op/convolution.cc | 1 - src/relay/qnn/op/dense.cc | 120 ++++++++++++++++ tests/python/relay/test_qnn_conv2d.py | 1 - tests/python/relay/test_qnn_dense.py | 199 ++++++++++++++++++++++++++ 10 files changed, 396 insertions(+), 5 deletions(-) create mode 100644 src/relay/qnn/op/dense.cc create mode 100644 tests/python/relay/test_qnn_dense.py diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 1cd4c191be51..c79dfe563a3a 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -217,6 +217,28 @@ struct QnnBinaryOpAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for qnn dense operator */ +struct QnnDenseAttrs : public tvm::AttrsNode { + IndexExpr units; + DataType out_dtype; + // Quantization related attributes. + int32_t input_zero_point; + int32_t kernel_zero_point; + + TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") { + TVM_ATTR_FIELD(units) + .describe("Number of hidden units of the dense transformation."); + + TVM_ATTR_FIELD(out_dtype) + .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); + TVM_ATTR_FIELD(kernel_zero_point) + .describe("The zero point of the kernel tensor."); + } +}; + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 570448f60738..c9b43ca28b2d 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -265,7 +265,13 @@ def conv2d(data, data_layout, kernel_layout, out_layout, out_dtype) -def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, +def add(lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, output_zero_point): """Quantized addition with numpy-style broadcasting. @@ -305,3 +311,36 @@ def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_s lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point) + + +def quantized_dense(data, + weight, + input_zero_point, + kernel_zero_point, + units=None, + out_dtype="int32"): + """Dense operator. + Applies a linear transformation + .. math:: + `Y = X * W` + Parameters + ---------- + data : tvm.relay.Expr + The quantied input data to the operator. + weight : tvm.relay.Expr + The quantized weight expressions. + units : int, optional + Number of hidden units of the dense transformation. + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + Returns + ------- + result : tvm.relay.Expr + The computed result. 
+ """ + return _make.dense(data, + weight, + units, + input_zero_point, + kernel_zero_point, + out_dtype) diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py index 22e8f7f64fff..a76bdaf6310f 100644 --- a/python/tvm/relay/qnn/transform.py +++ b/python/tvm/relay/qnn/transform.py @@ -22,7 +22,7 @@ def CanonicalizeOps(): """Converts/Lowers an expression containing QNN ops to an expression containing only core - (non-Dialect) Relay ops. Each QNN op is lowered to a sequence of exisiting Relay ops. This is a + (non-Dialect) Relay ops. Each QNN op is lowered to a sequence of existing Relay ops. This is a target-independent pass. One can register the lowering/transformation function for this op using FTVMQnnCanonicalize attr_name for FTVMLegalize op attribute. An example of this transformation is below @@ -40,7 +40,7 @@ def CanonicalizeOps(): output_zero_point=0, out_dtype='int8') - # We want to utilize all the existing Relay infrastucture. So, instead of supporting this + # We want to utilize all the existing Relay infrastructure. So, instead of supporting this # QNN requantize op, we convert it into a sequence of existing Relay operators. mod = relay.Module.from_expr(qnn_expr) mod = relay.qnn.transform.CanonicalizeOps()(mod) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index fb5844749117..c962abc6b756 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -27,6 +27,7 @@ #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 2c65d2526437..54ba3bf939ef 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -26,6 +26,7 @@ #define TVM_RELAY_OP_NN_NN_H_ #include +#include namespace tvm { namespace relay { diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 682e2e337b17..54b141914294 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -434,6 +434,17 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } +static inline Expr Dense(Expr data, + Expr weight, + IndexExpr units, + DataType out_dtype) { + auto attrs = make_node(); + attrs->units = units; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("nn.dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclude) { auto attrs = make_node(); attrs->axis = std::move(axis); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 6e1d13ec9ed9..b837bef18687 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -23,7 +23,6 @@ * \brief Property def of qnn convolution operator. */ #include -#include #include #include #include diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc new file mode 100644 index 000000000000..09a7f79ba86a --- /dev/null +++ b/src/relay/qnn/op/dense.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/qnn/op/dense.cc + * \brief Property def of qnn dense operator. + */ + +#include +#include +#include +#include +#include +#include "../../op/nn/nn.h" +#include "../../pass/pattern_util.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.op.qnn.dense +TVM_REGISTER_NODE_TYPE(QnnDenseAttrs); + +bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + const auto* param = attrs.as(); + CHECK(data->dtype == Int(8) || data->dtype == UInt(8)) + << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; + CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8)) + << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype; + CHECK(data->dtype == weight->dtype) << "Weight and kernel dtypes do not match"; + CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32)) + << "Expected quantized dense type(int32, int16) for output but was " << param->out_dtype; + CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + return DenseRel(types, num_inputs, attrs, reporter); +} + +// Positional relay function to create quantized dense operator used by frontend FFI. +Expr MakeQuantizedDense(Expr data, + Expr weight, + IndexExpr units, + int32_t input_zero_point, + int32_t kernel_zero_point, + DataType out_dtype) { + auto attrs = make_node(); + attrs->units = units; + attrs->out_dtype = out_dtype; + attrs->input_zero_point = input_zero_point; + attrs->kernel_zero_point = kernel_zero_point; + static const Op& op = Op::Get("qnn.dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +Expr QnnDenseCanonicalize (const Attrs& attrs, + const Array& new_args, + const Array& arg_types) { + CHECK_EQ(new_args.size(), 2); + Expr quantized_data = new_args[0]; + Expr quantized_kernel = new_args[1]; + const auto* qnn_dense_attrs = attrs.as(); + //TODO: need to benchmark the performance of this lowering. + Expr quantized_data_int32 = Cast(quantized_data, Int(32)); + if(qnn_dense_attrs->input_zero_point != 0) { + quantized_data_int32 = Subtract(quantized_data_int32, + MakeConstantScalar(Int(32), + qnn_dense_attrs->input_zero_point)); + } + Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32)); + if(qnn_dense_attrs->kernel_zero_point != 0) { + quantized_kernel_int32 = Subtract(quantized_kernel_int32, + MakeConstantScalar(Int(32), + qnn_dense_attrs->kernel_zero_point)); + } + Expr int32_dense = Dense(quantized_data_int32, + quantized_kernel_int32, + qnn_dense_attrs->units, + qnn_dense_attrs->out_dtype); + return int32_dense; +} + +RELAY_REGISTER_OP("qnn.dense") +.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. +- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, unit8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. 
+)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.qnn.QnnDenseAttrs") +.set_num_inputs(2) +.add_argument("data", "quantized nD Tensor", "Input data.") +.add_argument("weight", "quantized 2D Tensor", "Weight matrix.") +.set_support_level(11) +.add_type_rel("QDense", DenseRel) +.set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); + +TVM_REGISTER_API("relay.qnn.op._make.dense") +.set_body_typed(MakeQuantizedDense); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_qnn_conv2d.py b/tests/python/relay/test_qnn_conv2d.py index 99cb482e2610..dd4ad8d491fa 100644 --- a/tests/python/relay/test_qnn_conv2d.py +++ b/tests/python/relay/test_qnn_conv2d.py @@ -19,7 +19,6 @@ import numpy as np from tvm import relay from tvm.relay import transform -from tvm.relay.testing import create_workload from tvm.relay.testing import run_infer_type from tvm.contrib import graph_runtime diff --git a/tests/python/relay/test_qnn_dense.py b/tests/python/relay/test_qnn_dense.py new file mode 100644 index 000000000000..7d447758afda --- /dev/null +++ b/tests/python/relay/test_qnn_dense.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import numpy as np +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_infer_type +from tvm.contrib import graph_runtime + +import tvm +import numpy as np +from tvm import relay +from tvm.contrib import graph_runtime + + +def test_quantized_dense(): + + def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + 'input_scale': input_scale, + 'output_scale': output_scale, + 'output_zero_point': output_zero_point, + 'out_dtype': out_dtype + } + return config + + def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, + kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): + if requantize is not None: + assert bias is not None + config = { + 'quantized_data': quantized_data, + 'quantized_kernel': quantized_kernel, + 'dtype': dtype, + 'input_shape': input_shape, + 'kernel_shape': kernel_shape, + 'input_zero_point': input_zero_point, + 'kernel_zero_point': kernel_zero_point, + 'units': units, + 'output': output, + 'out_dtype': out_dtype, + 'bias': bias, + 'requantize': requantize + } + return config + + def make_uint_configuration(use_bias=False, requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = 127, 127 + in_dtype = 'uint8' + out_dtype = 'int32' if not requantize_output else 'uint8' + units = 3 + quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107, + 129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None + + if requantize_output: + assert use_bias + output = np.array([151, 152, 153, 185, 186, 187]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + def make_int_configuration(use_bias=False, requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = -1, -1 + in_dtype = 'int8' + out_dtype = 'int32' if not requantize_output else 'int8' + units = 3 + quantized_data_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21, + 1, 3, 5, 7, 9, 11, 13, -17, 17, -21]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None + + if requantize_output: + assert use_bias + 
output = np.array([23, 24, 25, 57, 58, 59]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + def test_quantized_dense(test_configuration): + in_dtype = test_configuration['dtype'] + out_dtype = test_configuration['out_dtype'] + quantized_data_name = "quantized_data" + quantized_kernel_name = "quantized_kernel" + expected_out_dtype = test_configuration['out_dtype'] + bias_name = 'bias' + quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'], + dtype=in_dtype) + quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], + dtype=in_dtype) + mod = relay.qnn.op.quantized_dense( + quantized_data, + quantized_kernel, + test_configuration['input_zero_point'], + test_configuration['kernel_zero_point'], + test_configuration['units']) + if test_configuration[bias_name] is not None: + bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype) + mod = relay.nn.bias_add(mod, bias) + if test_configuration['requantize'] is not None: + requantize_config = test_configuration['requantize'] + mod = relay.qnn.op.requantize( + mod, + input_scale=requantize_config['input_scale'], + input_zero_point=0, + output_scale=requantize_config['output_scale'], + output_zero_point=requantize_config['output_zero_point'], + out_dtype=requantize_config['out_dtype']) + expected_out_dtype = requantize_config['out_dtype'] + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = relay.Module.from_expr(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + with relay.build_config(opt_level=2): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input(quantized_data_name,test_configuration[quantized_data_name]) + mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name]) + if test_configuration[bias_name] is not None: + mod.set_input(bias_name, test_configuration[bias_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, test_configuration['output']) + assert res.dtype == expected_out_dtype + + def test_configurations(): + test_prams = [{'use_bias': False}, {'use_bias': True}, {'use_bias': True, 'requantize_output': True}, ] + tests = [test_quantized_dense] + configurations = [] + for test_param in test_prams: + configurations.append(make_uint_configuration(**test_param)) + configurations.append(make_int_configuration(**test_param)) + for configuration in configurations: + for test in tests: + test(configuration) + + test_configurations() + +if __name__ == "__main__": + test_quantized_dense() From cf4383b7b7e66722e471883772463438f0cd41a1 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Fri, 6 Sep 2019 12:59:25 -0700 Subject: [PATCH 2/6] Reformatting code. 
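For context, the qnn.dense canonicalization introduced in the previous patch lowers the quantized operands by casting them to int32, subtracting the respective zero points, and then calling the regular nn.dense. A minimal NumPy sketch of that lowering, using the int8 values from the test file (the helper name below is illustrative and not part of the patch):

    import numpy as np

    def qnn_dense_reference(data, weight, input_zero_point, kernel_zero_point):
        # Cast to int32, shift by the zero points, then Y = X * W^T as in nn.dense.
        shifted_data = data.astype('int32') - input_zero_point
        shifted_weight = weight.astype('int32') - kernel_zero_point
        return shifted_data.dot(shifted_weight.T)

    data = np.array([[1, 3, 5, 7, 9, 11, 13, 15, -19, -21],
                     [1, 3, 5, 7, 9, 11, 13, -17, 17, -21]], dtype='int8')
    weight = np.tile(np.arange(1, 20, 2, dtype='int8'), (3, 1))
    print(qnn_dense_reference(data, weight, -1, -1))
    # [[ 92  92  92]
    #  [228 228 228]]
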
--- include/tvm/relay/qnn/attrs.h | 14 -------------- python/tvm/relay/qnn/op/qnn.py | 10 +++++----- src/relay/op/nn/nn.h | 1 - src/relay/qnn/op/dense.cc | 11 ++++++----- 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index c79dfe563a3a..83b55b04222a 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -74,10 +74,8 @@ struct QuantizeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { TVM_ATTR_FIELD(out_dtype) .describe("Output data type, can be one of [int8 or uint8]."); - TVM_ATTR_FIELD(output_zero_point) .describe("The zero_point for the activation of this op."); - TVM_ATTR_FIELD(output_scale) .describe("The scale for the activation of this op."); } @@ -91,7 +89,6 @@ struct DequantizeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") { TVM_ATTR_FIELD(input_zero_point) .describe("The zero_point for the input tensor of this op."); - TVM_ATTR_FIELD(input_scale) .describe("The scale for the input tensor of this op."); } @@ -108,16 +105,12 @@ struct QnnConcatenateAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") { TVM_ATTR_FIELD(input_scales) .describe("The list of scales of input quantized tensors."); - TVM_ATTR_FIELD(input_zero_points) .describe("The list of zero points of input quantized tensors."); - TVM_ATTR_FIELD(output_zero_point) .describe("The zero_point for the output tensor."); - TVM_ATTR_FIELD(output_scale) .describe("The scale for the output tensor."); - TVM_ATTR_FIELD(axis) .describe("The axis at which the input arrays are concatenated." "Should lie in range `[-ndim, ndim)`.") @@ -199,19 +192,14 @@ struct QnnBinaryOpAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(QnnBinaryOpAttrs, "relay.attrs.QnnBinaryOpAttrs") { TVM_ATTR_FIELD(lhs_zero_point) .describe("The zero_point for the lhs input tensor of this op."); - TVM_ATTR_FIELD(lhs_scale) .describe("The scale for the lhs input tensor of this op."); - TVM_ATTR_FIELD(rhs_zero_point) .describe("The zero_point for the rhs input tensor of this op."); - TVM_ATTR_FIELD(rhs_scale) .describe("The scale for the rhs input tensor of this op."); - TVM_ATTR_FIELD(output_zero_point) .describe("The zero_point for the activation of this op."); - TVM_ATTR_FIELD(output_scale) .describe("The scale for the activation of this op."); } @@ -228,10 +216,8 @@ struct QnnDenseAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") { TVM_ATTR_FIELD(units) .describe("Number of hidden units of the dense transformation."); - TVM_ATTR_FIELD(out_dtype) .describe("Output data type, set to explicit type under mixed precision setting"); - TVM_ATTR_FIELD(input_zero_point) .describe("The zero point of the input tensor."); TVM_ATTR_FIELD(kernel_zero_point) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index c9b43ca28b2d..8719918a2789 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -96,7 +96,7 @@ def quantize(data, The output zero_point. output_scale : float The output scale. - input_dtype : str, optional + out_dtype : str, optional The data type of the input tensor. Can be [int8, uint8] Returns ------- @@ -319,14 +319,14 @@ def quantized_dense(data, kernel_zero_point, units=None, out_dtype="int32"): - """Dense operator. - Applies a linear transformation + """Qnn Dense operator. 
+ Applies a quantized linear transformation .. math:: `Y = X * W` - Parameters + Parameters ---------- data : tvm.relay.Expr - The quantied input data to the operator. + The quantized input data to the operator. weight : tvm.relay.Expr The quantized weight expressions. units : int, optional diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 54ba3bf939ef..2c65d2526437 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -26,7 +26,6 @@ #define TVM_RELAY_OP_NN_NN_H_ #include -#include namespace tvm { namespace relay { diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 09a7f79ba86a..3dcba1efee00 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include "../../op/nn/nn.h" #include "../../pass/pattern_util.h" @@ -38,7 +37,9 @@ namespace qnn { // relay.op.qnn.dense TVM_REGISTER_NODE_TYPE(QnnDenseAttrs); -bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, +bool QnnDenseRel(const Array& types, + int num_inputs, + const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -71,9 +72,9 @@ Expr MakeQuantizedDense(Expr data, return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } -Expr QnnDenseCanonicalize (const Attrs& attrs, - const Array& new_args, - const Array& arg_types) { +Expr QnnDenseCanonicalize(const Attrs& attrs, + const Array& new_args, + const Array& arg_types) { CHECK_EQ(new_args.size(), 2); Expr quantized_data = new_args[0]; Expr quantized_kernel = new_args[1]; From 10ac31302b819cf95e40d2b16288b85dcaed77a3 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Fri, 6 Sep 2019 14:34:36 -0700 Subject: [PATCH 3/6] Reformatting code and making the test case more readable. --- python/tvm/relay/qnn/op/qnn.py | 1 + src/relay/pass/pattern_util.h | 6 +- src/relay/qnn/op/convolution.cc | 14 +- src/relay/qnn/op/dense.cc | 24 +- tests/python/relay/test_qnn_dense.py | 350 ++++++++++++++------------- 5 files changed, 206 insertions(+), 189 deletions(-) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 8719918a2789..8ff29d2f6154 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -323,6 +323,7 @@ def quantized_dense(data, Applies a quantized linear transformation .. math:: `Y = X * W` + Parameters ---------- data : tvm.relay.Expr diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 54b141914294..e7ee9f1ff83f 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -435,9 +435,9 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, } static inline Expr Dense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { + Expr weight, + IndexExpr units, + DataType out_dtype) { auto attrs = make_node(); attrs->units = units; attrs->out_dtype = out_dtype; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index b837bef18687..82538a28b565 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -157,7 +157,7 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) { * \param data The input expr. * \param weight The weight expr. * \param param The qnn conv2d attributes. - * \return The sequence of Relay operatos for term1. + * \return The sequence of Relay operators for term1. * \note The term1 is * Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s) * This is just conv2d on int tensors. 
@@ -177,12 +177,12 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2 * \param param The qnn conv2d attributes. * \param kernel_h The height of kernel. * \param kernel_w The width of kernel. - * \return The sequence of Relay operatos for term2. + * \return The sequence of Relay operators for term2. * \note The term2 looks like this * * Sigma(c,r,s) zp_w * QA(n, c, h + r, w + s) * - * Second term is not directly represetable by one Relay operator. + * Second term is not directly representable by one Relay operator. * However, deeper analysis shows that we can reduce r,s using avg_pool2d, * followed by a reduce on the C axis. Using avg_pool2d also gives an * opportunity to reuse alter_op_layout infrastructure. @@ -292,7 +292,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAtt * \param in_channels The number of input channels. * \param kernel_h The height of kernel. * \param kernel_w The width of kernel. - * \return The sequence of Relay operatos for term4. + * \return The sequence of Relay operators for term4. * \note The term4 looks like this * * Sigma(c,r,s) zp_a * zp_w @@ -352,7 +352,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * where QA is quantized tensor, scale_a and zp_A are quantizations * params. * - * Quantized convlution convolves two quantized tensors and returns a + * Quantized convolution will convolve two quantized tensors and returns a * quantized tensor of default dtype of int32, with scale equaling to the * product of scales of input tensors, and a zero point of zero. * @@ -378,7 +378,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * zero point. This might leave some performance opportunity at the * table. Can be avoided by modifying conv2d API to accept the * pad_const_value. - * 2) Second term is not directly represetable by one Relay operator. + * 2) Second term is not directly representable by one Relay operator. * However, deeper analysis shows that we can reduce r,s using * avg_pool2d, followed by a reduce on the C axis. Using avg_pool2d also * gives an opportunity to reuse alter_op_layout infrastructure. @@ -387,7 +387,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * the conv is dilated. We fallback also in case of depthwise conv. * * The whole process can be broken down into following steps - * * Assertion checks for exisiting support, fallback if necessary + * * Assertion checks for existing support, fallback if necessary * * Pad the input. * * Get Term1. * * Get Term2. 
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 3dcba1efee00..bfb7fabf0d5b 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -44,14 +44,15 @@ bool QnnDenseRel(const Array& types, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); + if(data == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); + CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr."; CHECK(data->dtype == Int(8) || data->dtype == UInt(8)) << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8)) << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype; - CHECK(data->dtype == weight->dtype) << "Weight and kernel dtypes do not match"; - CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32)) - << "Expected quantized dense type(int32, int16) for output but was " << param->out_dtype; + CHECK(param->out_dtype == Int(32)) + << "Expected quantized dense type(int32) for output but was " << param->out_dtype; CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; return DenseRel(types, num_inputs, attrs, reporter); } @@ -64,7 +65,7 @@ Expr MakeQuantizedDense(Expr data, int32_t kernel_zero_point, DataType out_dtype) { auto attrs = make_node(); - attrs->units = units; + attrs->units = std::move(units); attrs->out_dtype = out_dtype; attrs->input_zero_point = input_zero_point; attrs->kernel_zero_point = kernel_zero_point; @@ -72,6 +73,16 @@ Expr MakeQuantizedDense(Expr data, return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } +/** + * \brief Lowers Qnn convolution in terms of core operators in relay. + * Mathematically it is equals to - + * Dense((quantized_input - input_zero_point;int32), (quantized_kernel - kernel_zero_point; int32)) + * + * \param attrs QnnDenseAttrs for Qnn Dense layer. + * \param new_args The new mutated args to the call node. + * \param arg_types The data types of input and output. + * \reutrn The sequence of Relay ops for qnn cov2d op. + */ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, const Array& arg_types) { @@ -79,18 +90,17 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, Expr quantized_data = new_args[0]; Expr quantized_kernel = new_args[1]; const auto* qnn_dense_attrs = attrs.as(); - //TODO: need to benchmark the performance of this lowering. 
Expr quantized_data_int32 = Cast(quantized_data, Int(32)); if(qnn_dense_attrs->input_zero_point != 0) { quantized_data_int32 = Subtract(quantized_data_int32, MakeConstantScalar(Int(32), - qnn_dense_attrs->input_zero_point)); + qnn_dense_attrs->input_zero_point)); } Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32)); if(qnn_dense_attrs->kernel_zero_point != 0) { quantized_kernel_int32 = Subtract(quantized_kernel_int32, MakeConstantScalar(Int(32), - qnn_dense_attrs->kernel_zero_point)); + qnn_dense_attrs->kernel_zero_point)); } Expr int32_dense = Dense(quantized_data_int32, quantized_kernel_int32, diff --git a/tests/python/relay/test_qnn_dense.py b/tests/python/relay/test_qnn_dense.py index 7d447758afda..233bf46587b1 100644 --- a/tests/python/relay/test_qnn_dense.py +++ b/tests/python/relay/test_qnn_dense.py @@ -18,182 +18,188 @@ import tvm import numpy as np from tvm import relay -from tvm.relay import transform -from tvm.relay.testing import run_infer_type from tvm.contrib import graph_runtime -import tvm -import numpy as np -from tvm import relay -from tvm.contrib import graph_runtime +def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + 'input_scale': input_scale, + 'output_scale': output_scale, + 'output_zero_point': output_zero_point, + 'out_dtype': out_dtype + } + return config + + +def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, + kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): + if requantize is not None: + assert bias is not None + config = { + 'quantized_data': quantized_data, + 'quantized_kernel': quantized_kernel, + 'dtype': dtype, + 'input_shape': input_shape, + 'kernel_shape': kernel_shape, + 'input_zero_point': input_zero_point, + 'kernel_zero_point': kernel_zero_point, + 'units': units, + 'output': output, + 'out_dtype': out_dtype, + 'bias': bias, + 'requantize': requantize + } + return config + + +def make_uint_configuration(use_bias=False, requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = 127, 127 + in_dtype = 'uint8' + out_dtype = 'int32' if not requantize_output else 'uint8' + units = 3 + quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107, + 129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, + 129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None + + if requantize_output: + assert use_bias + output = np.array([151, 152, 153, 185, 186, 187]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + +def make_int_configuration(use_bias=False, 
requantize_output=False): + input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) + input_zero_point, kernel_zero_point = -1, -1 + in_dtype = 'int8' + out_dtype = 'int32' if not requantize_output else 'int8' + units = 3 + quantized_data_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21, + 1, 3, 5, 7, 9, 11, 13, -17, 17, -21]) \ + .astype(in_dtype) \ + .reshape(input_shape) + quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \ + .astype(in_dtype) \ + .reshape(kernel_shape) + bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None + requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None -def test_quantized_dense(): - - def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): - config = { - 'input_scale': input_scale, - 'output_scale': output_scale, - 'output_zero_point': output_zero_point, - 'out_dtype': out_dtype - } - return config - - def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, - kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): - if requantize is not None: - assert bias is not None - config = { - 'quantized_data': quantized_data, - 'quantized_kernel': quantized_kernel, - 'dtype': dtype, - 'input_shape': input_shape, - 'kernel_shape': kernel_shape, - 'input_zero_point': input_zero_point, - 'kernel_zero_point': kernel_zero_point, - 'units': units, - 'output': output, - 'out_dtype': out_dtype, - 'bias': bias, - 'requantize': requantize - } - return config - - def make_uint_configuration(use_bias=False, requantize_output=False): - input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) - input_zero_point, kernel_zero_point = 127, 127 - in_dtype = 'uint8' - out_dtype = 'int32' if not requantize_output else 'uint8' - units = 3 - quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107, - 129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \ - .astype(in_dtype) \ - .reshape(input_shape) - quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147, - 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, - 129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \ - .astype(in_dtype) \ - .reshape(kernel_shape) - bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None - - if requantize_output: - assert use_bias - output = np.array([151, 152, 153, 185, 186, 187]) - elif use_bias: - output = np.array([96, 100, 104, 232, 236, 240 ]) - else: - output = np.array([92, 92, 92, 228, 228, 228 ]) - output = output.astype(out_dtype).reshape(output_shape) - return make_test_configuration(quantized_data=quantized_data_np, - quantized_kernel=quantized_kernel_np, - dtype=in_dtype, - input_shape=input_shape, - kernel_shape=kernel_shape, - input_zero_point=input_zero_point, - kernel_zero_point=kernel_zero_point, - units=units, - output=output, - bias=bias, - requantize=requant_params) - - def make_int_configuration(use_bias=False, requantize_output=False): - input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) - input_zero_point, kernel_zero_point = -1, -1 - in_dtype = 'int8' - out_dtype = 'int32' if not requantize_output else 'int8' - units = 3 - quantized_data_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21, - 1, 3, 5, 7, 9, 11, 
13, -17, 17, -21]) \ - .astype(in_dtype) \ - .reshape(input_shape) - quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, - 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, - 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \ - .astype(in_dtype) \ - .reshape(kernel_shape) - bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None - - if requantize_output: - assert use_bias - output = np.array([23, 24, 25, 57, 58, 59]) - elif use_bias: - output = np.array([96, 100, 104, 232, 236, 240 ]) - else: - output = np.array([92, 92, 92, 228, 228, 228 ]) - output = output.astype(out_dtype).reshape(output_shape) - return make_test_configuration(quantized_data=quantized_data_np, - quantized_kernel=quantized_kernel_np, - dtype=in_dtype, - input_shape=input_shape, - kernel_shape=kernel_shape, - input_zero_point=input_zero_point, - kernel_zero_point=kernel_zero_point, - units=units, - output=output, - bias=bias, - requantize=requant_params) - - def test_quantized_dense(test_configuration): - in_dtype = test_configuration['dtype'] - out_dtype = test_configuration['out_dtype'] - quantized_data_name = "quantized_data" - quantized_kernel_name = "quantized_kernel" - expected_out_dtype = test_configuration['out_dtype'] - bias_name = 'bias' - quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'], - dtype=in_dtype) - quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], - dtype=in_dtype) - mod = relay.qnn.op.quantized_dense( - quantized_data, - quantized_kernel, - test_configuration['input_zero_point'], - test_configuration['kernel_zero_point'], - test_configuration['units']) + if requantize_output: + assert use_bias + output = np.array([23, 24, 25, 57, 58, 59]) + elif use_bias: + output = np.array([96, 100, 104, 232, 236, 240 ]) + else: + output = np.array([92, 92, 92, 228, 228, 228 ]) + output = output.astype(out_dtype).reshape(output_shape) + return make_test_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + +def test_quantized_dense(test_configuration): + in_dtype = test_configuration['dtype'] + out_dtype = test_configuration['out_dtype'] + quantized_data_name = "quantized_data" + quantized_kernel_name = "quantized_kernel" + expected_out_dtype = test_configuration['out_dtype'] + bias_name = 'bias' + quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'], + dtype=in_dtype) + quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], + dtype=in_dtype) + mod = relay.qnn.op.quantized_dense( + quantized_data, + quantized_kernel, + test_configuration['input_zero_point'], + test_configuration['kernel_zero_point'], + test_configuration['units']) + if test_configuration[bias_name] is not None: + bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype) + mod = relay.nn.bias_add(mod, bias) + if test_configuration['requantize'] is not None: + requantize_config = test_configuration['requantize'] + mod = relay.qnn.op.requantize( + mod, + input_scale=requantize_config['input_scale'], + input_zero_point=0, + output_scale=requantize_config['output_scale'], + 
output_zero_point=requantize_config['output_zero_point'], + out_dtype=requantize_config['out_dtype']) + expected_out_dtype = requantize_config['out_dtype'] + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = relay.Module.from_expr(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + with relay.build_config(opt_level=2): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input(quantized_data_name,test_configuration[quantized_data_name]) + mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name]) if test_configuration[bias_name] is not None: - bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype) - mod = relay.nn.bias_add(mod, bias) - if test_configuration['requantize'] is not None: - requantize_config = test_configuration['requantize'] - mod = relay.qnn.op.requantize( - mod, - input_scale=requantize_config['input_scale'], - input_zero_point=0, - output_scale=requantize_config['output_scale'], - output_zero_point=requantize_config['output_zero_point'], - out_dtype=requantize_config['out_dtype']) - expected_out_dtype = requantize_config['out_dtype'] - - mod = relay.Function(relay.analysis.free_vars(mod), mod) - mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) - with relay.build_config(opt_level=2): - graph, lib, params = relay.build(mod, "llvm", params=None) - mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) - mod.set_input(quantized_data_name,test_configuration[quantized_data_name]) - mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name]) - if test_configuration[bias_name] is not None: - mod.set_input(bias_name, test_configuration[bias_name]) - mod.set_input(**params) - mod.run() - res = mod.get_output(0).asnumpy() - np.testing.assert_equal(res, test_configuration['output']) - assert res.dtype == expected_out_dtype - - def test_configurations(): - test_prams = [{'use_bias': False}, {'use_bias': True}, {'use_bias': True, 'requantize_output': True}, ] - tests = [test_quantized_dense] - configurations = [] - for test_param in test_prams: - configurations.append(make_uint_configuration(**test_param)) - configurations.append(make_int_configuration(**test_param)) - for configuration in configurations: - for test in tests: - test(configuration) - - test_configurations() + mod.set_input(bias_name, test_configuration[bias_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, test_configuration['output']) + assert res.dtype == expected_out_dtype + + +def test_qnn_dense_without_bias(): + uint32_output_without_bias_paramas = make_uint_configuration(use_bias=False) + int32_output_without_bias_params = make_int_configuration(use_bias=False) + test_quantized_dense(uint32_output_without_bias_paramas) + test_quantized_dense(int32_output_without_bias_params) + + +def test_qnn_dense_with_bias(): + uint32_output_with_bias_params = make_uint_configuration(use_bias=True) + int32_output_with_bias_params = make_int_configuration(use_bias=True) + test_quantized_dense(uint32_output_with_bias_params) + test_quantized_dense(int32_output_with_bias_params) + + +def test_qnn_dense_with_requantized_output(): + uint8_requantized_output_with_bias_params = make_uint_configuration(use_bias=True, requantize_output=True) + int8_requantized_output_with_bias_params = make_int_configuration(use_bias=True, requantize_output=True) + 
test_quantized_dense(uint8_requantized_output_with_bias_params) + test_quantized_dense(int8_requantized_output_with_bias_params) + if __name__ == "__main__": - test_quantized_dense() + test_qnn_dense_without_bias() + test_qnn_dense_with_bias() + test_qnn_dense_with_requantized_output() From 829a1515949fa8c731968e0e31fa36c5f49b2155 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Mon, 9 Sep 2019 10:39:29 -0700 Subject: [PATCH 4/6] Fixing lint issues. --- src/relay/op/nn/convolution.h | 2 +- src/relay/qnn/op/dense.cc | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index c962abc6b756..803eae3d0cfe 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -25,9 +25,9 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ +#include #include #include -#include namespace tvm { namespace relay { diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index bfb7fabf0d5b..5473d139e215 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -44,7 +44,7 @@ bool QnnDenseRel(const Array& types, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); - if(data == nullptr || weight == nullptr) return false; + if (data == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr."; CHECK(data->dtype == Int(8) || data->dtype == UInt(8)) @@ -91,13 +91,13 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, Expr quantized_kernel = new_args[1]; const auto* qnn_dense_attrs = attrs.as(); Expr quantized_data_int32 = Cast(quantized_data, Int(32)); - if(qnn_dense_attrs->input_zero_point != 0) { + if (qnn_dense_attrs->input_zero_point != 0) { quantized_data_int32 = Subtract(quantized_data_int32, MakeConstantScalar(Int(32), qnn_dense_attrs->input_zero_point)); } Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32)); - if(qnn_dense_attrs->kernel_zero_point != 0) { + if (qnn_dense_attrs->kernel_zero_point != 0) { quantized_kernel_int32 = Subtract(quantized_kernel_int32, MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point)); From 8931c5fd57390300414cc6b3b5a85e71d45d126e Mon Sep 17 00:00:00 2001 From: shoubhik Date: Thu, 12 Sep 2019 09:38:51 -0700 Subject: [PATCH 5/6] Fixing test method names to pass the nose related configurations. 
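The expected values in the requantized test cases follow from rescaling the int32 dense-plus-bias result by input_scale / output_scale and re-centering on the output zero point. A small sketch of that arithmetic, reusing the biased outputs from the configurations above (the helper name is illustrative):

    import numpy as np

    def requantize_reference(acc, input_scale, output_scale, output_zero_point):
        # Rescale the int32 accumulator and shift it to the output zero point.
        return np.round(acc * input_scale / output_scale).astype('int32') + output_zero_point

    biased_dense_output = np.array([96, 100, 104, 232, 236, 240])
    print(requantize_reference(biased_dense_output, 0.25, 1.0, 127))  # uint8 case: [151 152 153 185 186 187]
    print(requantize_reference(biased_dense_output, 0.25, 1.0, -1))   # int8 case:  [23 24 25 57 58 59]
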
--- tests/python/relay/test_qnn_dense.py | 104 ++++++++++++++++----------- 1 file changed, 62 insertions(+), 42 deletions(-) diff --git a/tests/python/relay/test_qnn_dense.py b/tests/python/relay/test_qnn_dense.py index 233bf46587b1..2d14593b331f 100644 --- a/tests/python/relay/test_qnn_dense.py +++ b/tests/python/relay/test_qnn_dense.py @@ -31,8 +31,18 @@ def make_requantize_params(input_scale, output_scale, output_zero_point, out_dty return config -def make_test_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, - kernel_zero_point, units, output, out_dtype='int32', bias=None, requantize=None): +def make_configuration(quantized_data, + quantized_kernel, + dtype, + input_shape, + kernel_shape, + input_zero_point, + kernel_zero_point, + units, + output, + out_dtype='int32', + bias=None, + requantize=None): if requantize is not None: assert bias is not None config = { @@ -78,17 +88,17 @@ def make_uint_configuration(use_bias=False, requantize_output=False): else: output = np.array([92, 92, 92, 228, 228, 228 ]) output = output.astype(out_dtype).reshape(output_shape) - return make_test_configuration(quantized_data=quantized_data_np, - quantized_kernel=quantized_kernel_np, - dtype=in_dtype, - input_shape=input_shape, - kernel_shape=kernel_shape, - input_zero_point=input_zero_point, - kernel_zero_point=kernel_zero_point, - units=units, - output=output, - bias=bias, - requantize=requant_params) + return make_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) def make_int_configuration(use_bias=False, requantize_output=False): @@ -117,29 +127,31 @@ def make_int_configuration(use_bias=False, requantize_output=False): else: output = np.array([92, 92, 92, 228, 228, 228 ]) output = output.astype(out_dtype).reshape(output_shape) - return make_test_configuration(quantized_data=quantized_data_np, - quantized_kernel=quantized_kernel_np, - dtype=in_dtype, - input_shape=input_shape, - kernel_shape=kernel_shape, - input_zero_point=input_zero_point, - kernel_zero_point=kernel_zero_point, - units=units, - output=output, - bias=bias, - requantize=requant_params) - - -def test_quantized_dense(test_configuration): + return make_configuration(quantized_data=quantized_data_np, + quantized_kernel=quantized_kernel_np, + dtype=in_dtype, + input_shape=input_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + units=units, + output=output, + bias=bias, + requantize=requant_params) + + +def qnn_dense_driver(test_configuration): in_dtype = test_configuration['dtype'] out_dtype = test_configuration['out_dtype'] quantized_data_name = "quantized_data" quantized_kernel_name = "quantized_kernel" expected_out_dtype = test_configuration['out_dtype'] bias_name = 'bias' - quantized_data = relay.var(quantized_data_name, shape=test_configuration['input_shape'], + quantized_data = relay.var(quantized_data_name, + shape=test_configuration['input_shape'], dtype=in_dtype) - quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], + quantized_kernel = relay.var(quantized_kernel_name, + shape=test_configuration['kernel_shape'], dtype=in_dtype) mod = relay.qnn.op.quantized_dense( quantized_data, @@ -148,7 +160,9 @@ def 
test_quantized_dense(test_configuration): test_configuration['kernel_zero_point'], test_configuration['units']) if test_configuration[bias_name] is not None: - bias = relay.var(bias_name, shape=test_configuration['bias'].shape, dtype=out_dtype) + bias = relay.var(bias_name, + shape=test_configuration['bias'].shape, + dtype=out_dtype) mod = relay.nn.bias_add(mod, bias) if test_configuration['requantize'] is not None: requantize_config = test_configuration['requantize'] @@ -179,24 +193,30 @@ def test_quantized_dense(test_configuration): def test_qnn_dense_without_bias(): - uint32_output_without_bias_paramas = make_uint_configuration(use_bias=False) - int32_output_without_bias_params = make_int_configuration(use_bias=False) - test_quantized_dense(uint32_output_without_bias_paramas) - test_quantized_dense(int32_output_without_bias_params) + uint32_output_without_bias_paramas = \ + make_uint_configuration(use_bias=False) + int32_output_without_bias_params = \ + make_int_configuration(use_bias=False) + qnn_dense_driver(uint32_output_without_bias_paramas) + qnn_dense_driver(int32_output_without_bias_params) def test_qnn_dense_with_bias(): - uint32_output_with_bias_params = make_uint_configuration(use_bias=True) - int32_output_with_bias_params = make_int_configuration(use_bias=True) - test_quantized_dense(uint32_output_with_bias_params) - test_quantized_dense(int32_output_with_bias_params) + uint32_output_with_bias_params = \ + make_uint_configuration(use_bias=True) + int32_output_with_bias_params = \ + make_int_configuration(use_bias=True) + qnn_dense_driver(uint32_output_with_bias_params) + qnn_dense_driver(int32_output_with_bias_params) def test_qnn_dense_with_requantized_output(): - uint8_requantized_output_with_bias_params = make_uint_configuration(use_bias=True, requantize_output=True) - int8_requantized_output_with_bias_params = make_int_configuration(use_bias=True, requantize_output=True) - test_quantized_dense(uint8_requantized_output_with_bias_params) - test_quantized_dense(int8_requantized_output_with_bias_params) + uint8_requantized_output_with_bias_params = \ + make_uint_configuration(use_bias=True, requantize_output=True) + int8_requantized_output_with_bias_params = \ + make_int_configuration(use_bias=True, requantize_output=True) + qnn_dense_driver(uint8_requantized_output_with_bias_params) + qnn_dense_driver(int8_requantized_output_with_bias_params) if __name__ == "__main__": From cadda6c0ce331e0beba729a9e990ecd46000ed8b Mon Sep 17 00:00:00 2001 From: shoubhik Date: Fri, 13 Sep 2019 10:55:03 -0700 Subject: [PATCH 6/6] Aligning the code for code style. --- python/tvm/relay/qnn/op/qnn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 8ff29d2f6154..878a3a72b01e 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -321,24 +321,28 @@ def quantized_dense(data, out_dtype="int32"): """Qnn Dense operator. Applies a quantized linear transformation + .. math:: + `Y = X * W` Parameters ---------- data : tvm.relay.Expr The quantized input data to the operator. - weight : tvm.relay.Expr + weight : tvm.relay.Expr The quantized weight expressions. - units : int, optional + units : int, optional Number of hidden units of the dense transformation. - out_dtype : str, optional + out_dtype : str, optional Specifies the output data type for mixed precision dense can be int32 or int16. - Returns + + Returns ------- result : tvm.relay.Expr The computed result. 
""" + return _make.dense(data, weight, units,