From 5213fa55e0e4135399cd1a83eb5d373418063498 Mon Sep 17 00:00:00 2001 From: Elinx Hsi Date: Tue, 12 Oct 2021 22:24:43 +0800 Subject: [PATCH] Support quantised SQRT operator in TFLite --- python/tvm/relay/frontend/tflite.py | 2 - tests/python/frontend/tflite/test_forward.py | 53 +++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 97382ffe6fcd..5da6fd2bf386 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1191,8 +1191,6 @@ def convert_cos(self, op): def convert_sqrt(self, op): """Convert TFLite SQRT""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFlite quantized SQRT operator is not supported yet.") return self._convert_unary_elemwise(_op.sqrt, op) def convert_rsqrt(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index f8a603c87800..8bedb23d9155 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1850,16 +1850,6 @@ def _test_tan(data): return _test_unary_elemwise(math_ops.tan, data) -####################################################################### -# Sqrt -# ---- - - -def _test_sqrt(data): - """One iteration of sqrt""" - return _test_unary_elemwise(math_ops.sqrt, data) - - ####################################################################### # Square # ------ @@ -1882,7 +1872,7 @@ def _test_elu(data): def _test_forward_unary_elemwise(test_op): # functions that need positive input - if test_op.__name__ in {"_test_log", "_test_sqrt"}: + if test_op.__name__ in {"_test_log"}: test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))) else: test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32)) @@ -1893,7 +1883,6 @@ def test_all_unary_elemwise(): _test_forward_unary_elemwise(_test_exp) _test_forward_unary_elemwise(_test_log) _test_forward_unary_elemwise(_test_sin) - _test_forward_unary_elemwise(_test_sqrt) _test_forward_unary_elemwise(_test_square) # ceil and cos come with TFLite 1.14.0.post1 fbs schema if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): @@ -3360,6 +3349,45 @@ def test_forward_rsqrt(): _test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) +####################################################################### +# SQRT +# ---- + + +def _test_sqrt(data, quantized=False): + """One iteration of SQRT""" + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=1, max=6, name="inq_0" + ) + input_range = {"inq_0": (1, 6)} + out = math_ops.sqrt(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out") + compare_tflite_with_tvm( + data, + "inq_0:0", + [inq_data], + [out], + quantized=True, + input_range=input_range, + experimental_new_converter=True, + ) + else: + out = math_ops.sqrt(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) + + +def test_forward_sqrt(): + """SQRT""" + _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False) + _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False) + _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True) + _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) + + ####################################################################### # NEG # ---- @@ -4742,6 +4770,7 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_rsqrt() test_forward_neg() test_forward_abs() + test_forward_sqrt() test_forward_relu() test_forward_relu6() test_forward_leaky_relu()