diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 8d18cc2962ae..6275b163fca5 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1220,7 +1220,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
+    def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False, dequantize=False):
         """Generic method to Convert TFLite elemwise"""
         try:
             from tflite.AddOptions import AddOptions
@@ -1254,8 +1254,13 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
 
         # If quantized, extracts qnn params and call QNN add operator.
         if not ignore_qnn_params and lhs_tensor.qnn_params:
-            assert rhs_tensor.qnn_params, "Both tensors should be quantized."
-            assert output_tensor.qnn_params, "Output tensor should be quantized."
+            if not dequantize:
+                assert rhs_tensor.qnn_params, "Both tensors should be quantized."
+                assert output_tensor.qnn_params, "Output tensor should be quantized."
+            else:
+                lhs_expr = self.dequantize(lhs_expr, lhs_tensor)
+                rhs_expr = self.dequantize(rhs_expr, rhs_tensor)
+
             out = relay_op(
                 lhs=lhs_expr,
                 rhs=rhs_expr,
@@ -1269,6 +1274,9 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
         else:
             out = relay_op(lhs_expr, rhs_expr)
 
+        if dequantize and output_tensor.qnn_params:
+            out = self.quantize(out, output_tensor)
+
         # Options (fused_activation_function)
         options = None
         if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
@@ -1370,11 +1378,7 @@ def convert_greater(self, op):
     def convert_squared_difference(self, op):
         """Convert TFLite SQUARED DIFFERENCE"""
         # Check if the input tensor is quantized, call QNN op
-        if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                "TFlite quantized squared difference operator is not supported yet."
-            )
-        difference = self._convert_elemwise(_op.subtract, op)
+        difference = self._convert_elemwise(_op.subtract, op, dequantize=True)
         # _convert_elemwise has guaranteed only have one output tensor
         exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
         out = _op.power(difference, relay.const(2, exp_type))
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 80cdcf327f4b..be454ec44ef7 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -270,6 +270,7 @@ def compare_tflite_with_tvm(
     input_range=None,
     mode="graph_executor",
     experimental_new_converter=False,
+    experimental_new_quantizer=False,
     fp16_quantized=False,
 ):
     """Generic function to generate and compare TFLite and TVM output"""
@@ -286,6 +287,7 @@ def compare_tflite_with_tvm(
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
         converter.experimental_new_converter = experimental_new_converter
+        converter.experimental_new_quantizer = experimental_new_quantizer
         if quantized:
             converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
             input_arrays = converter.get_input_arrays()
@@ -2076,6 +2078,7 @@ def _test_elemwise(
     quantized=False,
     qnn_op=None,
     same_qnn_params=False,
+    experimental_new_quantizer=True,
 ):
     """One iteration of elemwise"""
 
@@ -2135,6 +2138,7 @@ def __test_elemwise(in_data):
                 quantized=True,
                 input_range=input_range,
                 experimental_new_converter=same_qnn_params,
+                experimental_new_quantizer=experimental_new_quantizer,
             )
         else:
             out = math_op(
@@ -2312,9 +2316,17 @@ def _test_not_equal(data):
 # ------------------
 
 
-def _test_squared_difference(data):
+def _test_squared_difference(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """One iteration of squared difference"""
-    return _test_elemwise(math_ops.squared_difference, data)
+    return _test_elemwise(
+        math_ops.squared_difference,
+        data,
+        fused_activation_function,
+        quantized,
+        qnn_op,
+        same_qnn_params=True,
+        experimental_new_quantizer=False,
+    )
 
 
 #######################################################################
@@ -2378,6 +2390,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
         _test_mul: (-5e3, 5e3),
         _test_maximum: (-112, 111),
         _test_minimum: (-128, 127),
+        _test_squared_difference: (0, 225e2),
     }
 
     return qnn_out_range[qnn_op]
@@ -2408,6 +2421,7 @@ def test_all_elemwise():
     _test_forward_elemwise_quantized(_test_minimum)
     _test_forward_elemwise(_test_greater)
     _test_forward_elemwise(_test_squared_difference)
+    _test_forward_elemwise_quantized(_test_squared_difference)
     _test_forward_elemwise(_test_greater_equal)
     _test_forward_elemwise(_test_less)
     _test_forward_elemwise(_test_less_equal)