python/tvm/relay/frontend/tflite.py (15 changes: 12 additions & 3 deletions)
@@ -1466,10 +1466,19 @@ def convert_greater(self, op):
     def convert_squared_difference(self, op):
         """Convert TFLite SQUARED DIFFERENCE"""
         # Check if the input tensor is quantized, call QNN op
+        # (https://github.com/tensorflow/tflite-micro/blob/bc35c3ed9c7ab93b3a13b46fce936f854bcfce2c/tensorflow/lite/micro/kernels/squared_difference.cc#L157)  # pylint: disable=line-too-long
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                "TFlite quantized squared difference operator is not supported yet."
-            )
+            input_tensors = self.get_input_tensors(op)
+            output_tensors = self.get_output_tensors(op)
+            assert len(input_tensors) == 2, "input tensors length should be 2"
+            assert len(output_tensors) == 1, "output tensors length should be 1"
+            lhs_expr = self.get_tensor_expr(input_tensors[0])
+            rhs_expr = self.get_tensor_expr(input_tensors[1])
+            lhs_expr_f32 = self.dequantize(lhs_expr, input_tensors[0])
+            rhs_expr_f32 = self.dequantize(rhs_expr, input_tensors[1])
+            out_f32 = _op.subtract(lhs_expr_f32, rhs_expr_f32)
+            return self.quantize(out_f32 * out_f32, output_tensors[0])
 
         difference = self._convert_elemwise(_op.subtract, op)
         # _convert_elemwise guarantees there is only one output tensor
         exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
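Note: the quantized path above is simply dequantize -> subtract -> square -> requantize. As a sanity check, the same arithmetic can be reproduced in plain NumPy. This is a minimal sketch assuming per-tensor affine quantization; the scales and zero points below are made-up placeholders (in the converter they come from each tensor's quantization params, and self.dequantize/self.quantize emit the equivalent qnn ops).

import numpy as np

# Hypothetical per-tensor quantization parameters (illustrative only).
lhs_scale, lhs_zp = 0.5, 128
rhs_scale, rhs_zp = 0.5, 128
out_scale, out_zp = 256.0, 0

lhs_q = np.array([10, 200], dtype=np.uint8)
rhs_q = np.array([50, 60], dtype=np.uint8)

# Dequantize: real = scale * (quantized - zero_point)
lhs_f32 = lhs_scale * (lhs_q.astype(np.float32) - lhs_zp)
rhs_f32 = rhs_scale * (rhs_q.astype(np.float32) - rhs_zp)

# The same float computation the converter now emits.
out_f32 = (lhs_f32 - rhs_f32) ** 2

# Requantize: quantized = round(real / scale) + zero_point, clipped to uint8.
out_q = np.clip(np.round(out_f32 / out_scale) + out_zp, 0, 255).astype(np.uint8)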
tests/python/frontend/tflite/test_forward.py (37 changes: 31 additions & 6 deletions)
@@ -305,7 +305,7 @@ def compare_tflite_with_tvm(
     mode="graph_executor",
     experimental_new_converter=False,
     fp16_quantized=False,
-    int_quant_dtype=tf.int8,
+    int_quant_dtype=tf.uint8,
 ):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
@@ -334,6 +334,8 @@ def compare_tflite_with_tvm(
             converter.target_spec.supported_ops = [
                 tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
             ]
+        elif int_quant_dtype == tf.int8:
+            converter.inference_type = tf.lite.constants.INT8
         else:
             # default to uint8 quantization
             converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
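Note: with the old tf.int8 default, an unspecified int_quant_dtype fell through to the else branch and produced a uint8 model anyway; making tf.uint8 the explicit default frees tf.int8 to select true int8 inference. A hypothetical call requesting the new int8 path (the data/tensor arguments are placeholders, not from this diff):

compare_tflite_with_tvm(
    in_data,            # list of input numpy arrays
    in_name,            # input tensor names, e.g. ["inq_0:0", "inq_1:0"]
    input_tensors,
    output_tensors,
    quantized=True,
    input_range=input_range,
    int_quant_dtype=tf.int8,  # now maps to tf.lite.constants.INT8
)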
Expand Down Expand Up @@ -2327,6 +2329,16 @@ def _test_elemwise(
     def __test_elemwise(in_data):
         assert len(in_data) == 2
         if quantized:
+            int_quant_dtype = None
+            if data[0].dtype == "int8":
+                int_quant_dtype = tf.int8
+            elif data[0].dtype == "uint8":
+                int_quant_dtype = tf.uint8
+            elif data[0].dtype == "int16":
+                int_quant_dtype = tf.int16
+            else:
+                assert False, "Unsupported conversion from numpy to tflite dtype!"
+
             # set the fp32 output range with respect to the operation
             out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
             inq0_min, inq0_max = (-100, 100)
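Note: the dtype dispatch above could equally be table-driven; a sketch (the dict name is invented here for illustration, not part of the PR):

_NP_TO_TF_QTYPE = {"int8": tf.int8, "uint8": tf.uint8, "int16": tf.int16}

dtype_key = str(data[0].dtype)
assert dtype_key in _NP_TO_TF_QTYPE, "Unsupported conversion from numpy to tflite dtype!"
int_quant_dtype = _NP_TO_TF_QTYPE[dtype_key]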
@@ -2375,6 +2387,7 @@ def __test_elemwise(in_data):
                     quantized=True,
                     input_range=input_range,
                     experimental_new_converter=same_qnn_params,
+                    int_quant_dtype=int_quant_dtype,
                 )
             else:
                 out = math_op(inq_data[0], inq_data[1])
@@ -2392,6 +2405,7 @@ def __test_elemwise(in_data):
                     quantized=True,
                     input_range=input_range,
                     experimental_new_converter=same_qnn_params,
+                    int_quant_dtype=int_quant_dtype,
                 )
             else:
                 out = math_op(
@@ -2585,9 +2599,16 @@ def _test_not_equal(data):
 # ------------------
 
 
-def _test_squared_difference(data):
+def _test_squared_difference(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """One iteration of squared difference"""
-    return _test_elemwise(math_ops.squared_difference, data)
+    return _test_elemwise(
+        math_ops.squared_difference,
+        data,
+        fused_activation_function,
+        quantized,
+        qnn_op,
+        same_qnn_params=True,
+    )
 
 
 #######################################################################
@@ -2632,11 +2653,13 @@ def _test_forward_elemwise(testop):
     )
 
 
-def _test_forward_elemwise_quantized(testop):
+def _test_forward_elemwise_quantized(testop, dtype=np.uint8):
+    type_info = np.iinfo(dtype)
+    _min, _max = type_info.min, type_info.max
     testop(
         [
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+            np.array(np.random.uniform(_min, _max, (3, 6)), dtype=dtype),
+            np.array(np.random.uniform(_min, _max, (3, 6)), dtype=dtype),
         ],
         quantized=True,
         qnn_op=testop,
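Note: the new dtype parameter derives the sampling range via np.iinfo instead of hard-coding (0, 255). For reference, the bounds it supplies (plain NumPy, nothing PR-specific):

import numpy as np

for dt in (np.uint8, np.int8, np.int16):
    info = np.iinfo(dt)
    print(dt.__name__, info.min, info.max)
# uint8 0 255
# int8 -128 127
# int16 -32768 32767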
@@ -2653,6 +2676,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
         _test_minimum: (-128, 127),
         _test_equal: (-150, 150),
         _test_greater: (-150, 150),
+        _test_squared_difference: (0, 65025),
     }
 
     return qnn_out_range[qnn_op]
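Note: (0, 65025) is the widest possible fp32 output range for 8-bit inputs: the difference of two values drawn from an 8-bit range has magnitude at most 255 (127 - (-128) for int8, 255 - 0 for uint8), so its square lies in [0, 255 ** 2] = [0, 65025]. A quick check:

import numpy as np

assert (np.iinfo(np.int8).max - np.iinfo(np.int8).min) ** 2 == 65025  # 255 ** 2
assert np.iinfo(np.uint8).max ** 2 == 65025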
@@ -2685,6 +2709,7 @@ def test_all_elemwise():
     _test_forward_elemwise(_test_greater)
     _test_forward_elemwise_quantized(_test_greater)
     _test_forward_elemwise(_test_squared_difference)
+    _test_forward_elemwise_quantized(_test_squared_difference, np.int8)
     _test_forward_elemwise(_test_greater_equal)
     _test_forward_elemwise(_test_less)
     _test_forward_elemwise(_test_less_equal)
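Note: to exercise the new quantized path locally, the element-wise suite can be selected with standard pytest filtering (the command is illustrative, not part of the diff):

pytest tests/python/frontend/tflite/test_forward.py -k test_all_elemwise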