From 6ec8aa402689feb967a4030509981dee4f569267 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 15 Oct 2019 21:41:13 +0000 Subject: [PATCH] [QNN] Change default rouning to UPWARD. --- include/tvm/relay/qnn/attrs.h | 2 +- python/tvm/relay/qnn/op/qnn.py | 2 +- src/relay/qnn/util.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 83b55b04222a..e5f4ba94e12e 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -49,7 +49,7 @@ struct RequantizeAttrs : public tvm::AttrsNode { .describe("The scale of the output tensor."); TVM_ATTR_FIELD(output_zero_point) .describe("The zero point of the output tensor."); - TVM_ATTR_FIELD(rounding).set_default("TONEAREST") + TVM_ATTR_FIELD(rounding).set_default("UPWARD") .describe("Defines the rounding direction when the value is midway between" "two representable values. There are two supported modes - UPWARD" "or TONEAREST. Both modes behave exactly same except at the" diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index ed443abb5293..c8ebfc00a21b 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -27,7 +27,7 @@ def requantize(data, input_zero_point, output_scale, output_zero_point, - rounding="TONEAREST", + rounding="UPWARD", out_dtype="int8"): r"""Requantized operator. diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index c26183705b89..f94860d28cf9 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -76,7 +76,7 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, static inline Expr Requantize(const Expr& data, const Array& input_shape, double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, const DataType& out_dtype, - const std::string& rounding = "TONEAREST") { + const std::string& rounding = "UPWARD") { auto attrs = make_node(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point);