diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 532318b804da..043da8704221 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2432,11 +2432,6 @@ def convert_reverse_sequence(self, op): except ImportError: raise ImportError("The tflite package must be installed") - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - "TFLite does not support quantized REVERSE_SEQUENCE operator yet." - ) - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index f60166702454..8780ced60efa 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4272,28 +4272,43 @@ def test_forward_spacetodepth(): ####################################################################### # ReverseSequence # --------------- - - -def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis): +def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis, quantized=False): """One iteration of reverse_sequence operation with given data and attributes""" - data = np.random.uniform(0, 100, size=shape).astype(dtype) with tf.Graph().as_default(): - in_data = array_ops.placeholder(dtype=dtype, name="input", shape=shape) - out = tf.reverse_sequence( - in_data, seq_lengths=seq_lengths, batch_axis=batch_axis, seq_axis=seq_axis - ) - - compare_tflite_with_tvm(data, "input", [in_data], [out]) + in_data = array_ops.placeholder(dtype="float32", name="in_0", shape=shape) + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=0, max=10, name="inq_0" + ) + input_range = {"inq_0": (-10, 10)} + out = tf.reverse_sequence( + inq_data, seq_lengths=seq_lengths, batch_axis=batch_axis, seq_axis=seq_axis + ) + out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=6, name="out") + compare_tflite_with_tvm( + data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range + ) + else: + out = tf.reverse_sequence( + in_data, seq_lengths=seq_lengths, batch_axis=batch_axis, seq_axis=seq_axis + ) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) def test_forward_reverse_sequence(): + """Tests the reverse_sequence function with different input shapes, data types.""" if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): _test_reverse_sequence([4, 3], "float32", [3, 2, 1], 1, 0) _test_reverse_sequence([4, 3], "float32", [3, 2, 1, 3], 0, 1) _test_reverse_sequence([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1) _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3], 0, 2) _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3, 1, 4], 3, 2) + _test_reverse_sequence([4, 3], "uint8", [3, 2, 1], 1, 0, quantized=True) + _test_reverse_sequence([4, 3], "uint8", [3, 2, 1, 3], 0, 1, quantized=True) + _test_reverse_sequence([2, 3, 3, 3], "uint8", [2, 3, 2], 2, 1, quantized=True) + _test_reverse_sequence([2, 4, 6, 4, 5], "uint8", [5, 3], 0, 2, quantized=True) + _test_reverse_sequence([2, 4, 6, 4, 5], "uint8", [5, 3, 1, 4], 3, 2, quantized=True) #######################################################################