From 9ae8819fe40f2d5f3304daff97765d7063fc42d6 Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Fri, 16 Apr 2021 14:14:54 +0100 Subject: [PATCH] Add support for the quantized TANH operator to relay TFLite frontend Change-Id: I70df765e1562fa586ed0ffd0e07b8858f7fbb831 --- python/tvm/relay/frontend/tflite.py | 11 +++++++-- tests/python/frontend/tflite/test_forward.py | 25 +++++++++++++++----- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6b14a6f58e60..6a52cd21d921 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -769,11 +769,18 @@ def convert_tanh(self, op): """Convert TFLite TANH""" input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" - input_tensor = input_tensors[0] in_expr = self.get_expr(input_tensor.tensor_idx) - out = _op.tanh(in_expr) + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + if input_tensor.qnn_params: + in_expr = self.dequantize(in_expr, input_tensor) + out = _op.tanh(in_expr) + if output_tensor.qnn_params: + out = self.quantize(out, output_tensor) return out def convert_range(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 43ffc400491a..09d273d44dc9 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3255,17 +3255,30 @@ def test_forward_log_softmax(): # ---- -def _test_tanh(data): +def _test_tanh(data, quantized=False): """ One iteration of TANH """ with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = math_ops.tanh(in_data) - compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out]) + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=-3, max=3, name="inq_0" + ) + input_range = {"inq_0": (-3, 3)} + out = math_ops.tanh(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-1, max=1, name="out") + compare_tflite_with_tvm( + data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range + ) + else: + out = math_ops.tanh(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) def test_forward_tanh(): - """ TANH """ - _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + """TANH""" + _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=False) + _test_tanh(np.arange(0, 256, 30, dtype=np.uint8), quantized=True) #######################################################################