Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def convert_square(self, op):

return out

def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False, dequantize=False):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.AddOptions import AddOptions
Expand Down Expand Up @@ -1277,8 +1277,13 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):

# If quantized, extracts qnn params and call QNN add operator.
if not ignore_qnn_params and lhs_tensor.qnn_params:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
if not dequantize:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
else:
lhs_expr = self.dequantize(lhs_expr, lhs_tensor)
rhs_expr = self.dequantize(rhs_expr, rhs_tensor)

out = relay_op(
lhs=lhs_expr,
rhs=rhs_expr,
Expand All @@ -1292,6 +1297,9 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
else:
out = relay_op(lhs_expr, rhs_expr)

if dequantize and output_tensor.qnn_params:
out = self.quantize(out, output_tensor)

# Options (fused_activation_function)
options = None
if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
Expand Down Expand Up @@ -1393,11 +1401,7 @@ def convert_greater(self, op):
def convert_squared_difference(self, op):
"""Convert TFLite SQUARED DIFFERENCE"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized squared difference operator is not supported yet."
)
difference = self._convert_elemwise(_op.subtract, op)
difference = self._convert_elemwise(_op.subtract, op, dequantize=True)
# _convert_elemwise has guaranteed only have one output tensor
exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
out = _op.power(difference, relay.const(2, exp_type))
Expand Down
8 changes: 6 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,9 +2227,11 @@ def _test_not_equal(data):
# ------------------


def _test_squared_difference(data, fused_activation_function=None, quantized=False, qnn_op=None):
    """Run one squared-difference elementwise test iteration.

    Thin wrapper that forwards all arguments to the generic elementwise
    test driver, selecting TensorFlow's squared_difference as the op.
    """
    args = (math_ops.squared_difference, data, fused_activation_function, quantized, qnn_op)
    return _test_elemwise(*args)


#######################################################################
Expand Down Expand Up @@ -2293,6 +2295,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
_test_mul: (-5e3, 5e3),
_test_maximum: (-112, 111),
_test_minimum: (-128, 127),
_test_squared_difference: (0, 225e2),
}

return qnn_out_range[qnn_op]
Expand Down Expand Up @@ -2323,6 +2326,7 @@ def test_all_elemwise():
_test_forward_elemwise_quantized(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)
_test_forward_elemwise_quantized(_test_squared_difference)
_test_forward_elemwise(_test_greater_equal)
_test_forward_elemwise(_test_less)
_test_forward_elemwise(_test_less_equal)
Expand Down