41 changes: 41 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -349,3 +349,44 @@ def dense(data,
input_zero_point,
kernel_zero_point,
out_dtype)


def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
output_zero_point):
"""Quantized multiplication with numpy-style broadcasting.

Parameters
----------
lhs : relay.Expr
The left hand side quantized input data.

rhs : relay.Expr
The right hand side quantized input data.

lhs_scale : float
The scale of the lhs quantized expr.

lhs_zero_point : int
The zero point of the lhs quantized expr.

rhs_scale : float
The scale of the rhs quantized expr.

rhs_zero_point : int
The zero point of the rhs quantized expr.

output_scale : float
The scale of the output quantized expr.

output_zero_point : int
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.mul(lhs, rhs,
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)
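
For illustration, a minimal usage sketch of the new API; the shapes and quantization parameters are taken from the tests later in this PR, and anything else is arbitrary:

    import tvm
    from tvm import relay

    # Multiply two quantized uint8 tensors that share quantization params.
    x = relay.var("x", shape=(1, 4), dtype="uint8")
    y = relay.var("y", shape=(1, 4), dtype="uint8")
    z = relay.qnn.op.mul(lhs=x, rhs=y,
                         lhs_scale=0.125, lhs_zero_point=0,
                         rhs_scale=0.125, rhs_zero_point=0,
                         output_scale=0.125, output_zero_point=0)

    # QNN ops must be canonicalized to core Relay ops before compilation.
    func = relay.Function([x, y], z)
    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)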
105 changes: 105 additions & 0 deletions src/relay/qnn/op/mul.cc
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/qnn/op/mul.cc
* \brief QNN mul operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../util.h"
#include "op_common.h"

namespace tvm {
namespace relay {
namespace qnn {

/*
* \brief Canonicalizes the QNN mul op.
* \param attrs The QNN mul attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for mul op.
*/
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
Contributor:

I think that there is a mistake in the lowering. Can you please provide a description of how the lowering instructions are decided?

Contributor Author:

Same as the addition operator; the only difference is that the requantized tensors are multiplied to produce the output. Let me know which part of the instructions you see an issue with and I'll take a closer look. Thanks!

const Array<tvm::relay::Type>& arg_types) {
// Get the attrs.
CHECK_EQ(new_args.size(), 2);
auto& lhs = new_args[0];
auto& rhs = new_args[1];
const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>();
CHECK(binary_op_attrs != nullptr);
auto lhs_scale = binary_op_attrs->lhs_scale;
auto lhs_zero_point = binary_op_attrs->lhs_zero_point;
auto rhs_scale = binary_op_attrs->rhs_scale;
auto rhs_zero_point = binary_op_attrs->rhs_zero_point;
auto output_scale = binary_op_attrs->output_scale;
auto output_zero_point = binary_op_attrs->output_zero_point;

// Get the input dtype and shape.
CHECK_EQ(arg_types.size(), 3);
auto tensor_type = arg_types[0].as<TensorTypeNode>();
auto input_dtype = tensor_type->dtype;
auto input_shape = tensor_type->shape;

// Requantize LHS if necessary.
auto requantized_lhs = lhs;
if (lhs_scale != output_scale || lhs_zero_point != output_zero_point) {
requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
output_zero_point, Int(32));
} else {
requantized_lhs = Cast(requantized_lhs, Int(32));
}

// Requantize RHS if necessary.
auto requantized_rhs = rhs;
if (rhs_scale != output_scale || rhs_zero_point != output_zero_point) {
requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
output_zero_point, Int(32));
} else {
requantized_rhs = Cast(requantized_rhs, Int(32));
}

auto output = Multiply(requantized_lhs, requantized_rhs);

// Subtract zero point.
if (output_zero_point != 0) {
auto output_zp = MakeConstantScalar(Int(32), output_zero_point);
output = Subtract(output, output_zp);
}

// Go back to lower precision.
auto q_min = GetQmin(input_dtype);
auto q_max = GetQmax(input_dtype);
output = Clip(output, q_min, q_max);
return Cast(output, input_dtype);
}

// QNN Multiplication operator.
QNN_REGISTER_BINARY_OP("mul")
.describe("Elementwise mul with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);

} // namespace qnn
} // namespace relay
} // namespace tvm
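
For readers following the review thread above, here is a rough numpy sketch of what QnnMulCanonicalize emits, assuming uint8 tensors. The requantize helper is a hypothetical, simplified stand-in for relay.qnn.op.requantize; this is an illustrative reading of the C++ above, not part of the patch:

    import numpy as np

    def requantize(q, in_scale, in_zp, out_scale, out_zp):
        # Simplified stand-in for relay.qnn.op.requantize: re-express the
        # integer tensor q from (in_scale, in_zp) in (out_scale, out_zp).
        real = in_scale * (q.astype(np.int32) - in_zp)
        return (np.round(real / out_scale) + out_zp).astype(np.int32)

    def qnn_mul_canonicalized(lhs, rhs, lhs_s, lhs_zp, rhs_s, rhs_zp,
                              out_s, out_zp, qmin=0, qmax=255):
        # Requantize each input to the output params, or just widen to
        # int32 when the params already match (mirrors the C++ if/else).
        lhs32 = (lhs.astype(np.int32) if (lhs_s, lhs_zp) == (out_s, out_zp)
                 else requantize(lhs, lhs_s, lhs_zp, out_s, out_zp))
        rhs32 = (rhs.astype(np.int32) if (rhs_s, rhs_zp) == (out_s, out_zp)
                 else requantize(rhs, rhs_s, rhs_zp, out_s, out_zp))
        out = lhs32 * rhs32                # multiply the requantized tensors
        if out_zp != 0:
            out = out - out_zp             # subtract output zero point
        return np.clip(out, qmin, qmax).astype(np.uint8)

One plausible reading of the reviewer's concern: a requantized value represents out_scale * (q - out_zp), so the integer product of two such values is in units of out_scale squared, and neither this sketch nor the C++ it mirrors rescales the product back to out_scale.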
196 changes: 196 additions & 0 deletions tests/python/relay/test_qnn_mul.py
@@ -0,0 +1,196 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
import topi.testing

def test_tflite_same_io_qnn_params():
data_dtype = 'uint8'

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.00784314,
lhs_zero_point=127,
rhs_scale=0.00784314,
rhs_zero_point=127,
output_scale=0.00784314,
output_zero_point=127)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_datas = [np.array((1, 153, 165, 178)).reshape((1,4)),
np.array((25, 1, 178, 216)).reshape((1,4)),
np.array((25, 153, 1, 165)).reshape((1,4))]
y_datas = [np.array((204, 178, 1, 140)).reshape((1,4)),
np.array((204, 178, 191, 1)).reshape((1,4)),
np.array((204, 178, 1, 191)).reshape((1,4))]
golden_outputs = [np.array((77, 255, 38, 255)).reshape((1, 4)),
np.array((255, 51, 255, 89)).reshape((1,4)),
np.array((255, 255, 0, 255)).reshape((1,4))]

for i in range(0, 3):
x_data = x_datas[i]
y_data = y_datas[i]
golden_output = golden_outputs[i]

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


def test_tflite_different_io_qnn_params():
data_dtype = 'uint8'

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.0156863,
lhs_zero_point=127,
rhs_scale=0.0117647,
rhs_zero_point=85,
output_scale=0.0235294,
output_zero_point=128)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_datas = [np.array((76, 140, 153, 172)).reshape((1,4)),
np.array((133, 140, 146, 153)).reshape((1,4)),
np.array((76, 140, 172, 146)).reshape((1,4))]
y_datas = [np.array((136, 119, 128, 17)).reshape((1,4)),
np.array((136, 119, 111, 94)).reshape((1,4)),
np.array((136, 119, 17, 128)).reshape((1,4))]
golden_outputs = [np.array((255, 255, 255, 255)).reshape((1, 4)),
np.array((255, 255, 255, 255)).reshape((1,4)),
np.array((255, 255, 255, 255)).reshape((1,4))]

for i in range(0, 3):
x_data = x_datas[i]
y_data = y_datas[i]
golden_output = golden_outputs[i]

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


def test_saturation():
# Same params
data_dtype = 'uint8'
x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.125,
lhs_zero_point=0,
rhs_scale=0.125,
rhs_zero_point=0,
output_scale=0.125,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 1, 1, 0)).reshape((1,4))
y_data = np.array((255, 255, 128, 0)).reshape((1,4))
golden_output = np.array((255, 255, 128, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

# Same params, different scale
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.125,
lhs_zero_point=0,
rhs_scale=0.125,
rhs_zero_point=0,
output_scale=0.25,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 1, 1, 0)).reshape((1,4))
y_data = np.array((255, 255, 127, 0)).reshape((1,4))
golden_output = np.array((255, 128, 64, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

# All params different
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.5,
lhs_zero_point=0,
rhs_scale=0.25,
rhs_zero_point=0,
output_scale=0.125,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 0, 1, 0)).reshape((1,4))
y_data = np.array((0, 128, 64, 0)).reshape((1,4))
golden_output = np.array((0, 0, 255, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


if __name__ == '__main__':
test_tflite_same_io_qnn_params()
test_tflite_different_io_qnn_params()
test_saturation()
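
As a quick sanity check, the first test_saturation golden output follows directly from the lowering sketched after mul.cc above: the zero points are 0 and all three scales match, so each input is merely widened to int32 before the multiply and clip.

    import numpy as np

    x = np.array([255, 1, 1, 0], dtype=np.int32)
    y = np.array([255, 255, 128, 0], dtype=np.int32)
    out = np.clip(x * y, 0, 255)   # products 65025, 255, 128, 0 -> clipped
    assert (out == [255, 255, 128, 0]).all()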