diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 6acc8554b4dd..42848783967c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -933,9 +933,24 @@ def representative_data_gen(): is_float_output=True, int_quant_dtype=int_quant_dtype, ) + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0) + except ImportError: + raise ImportError("The tflite package must be installed") + + subgraph = tflite_model.Subgraphs(0) + model_input = subgraph.InputsAsNumpy() + input_node = subgraph.Tensors(model_input).Name().decode("utf-8") tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", "")) + tvm_output = run_tvm_graph(tflite_model_quant, data, input_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2 ) @@ -960,8 +975,28 @@ def test_forward_quantized_convolution(): ) +def test_forward_quantized_depthwise_convolution(): + for int_quant_dtype in [tf.int8, tf.int16]: + _test_tflite2_quantized_depthwise_convolution( + [1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], "SAME", "NHWC", 1, int_quant_dtype + ) + _test_tflite2_quantized_depthwise_convolution( + [1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], "VALID", "NHWC", 1, int_quant_dtype + ) + _test_tflite2_quantized_depthwise_convolution( + [1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], "SAME", "NHWC", 8, int_quant_dtype + ) + + def _test_tflite2_quantized_depthwise_convolution( - input_shape, kernel_shape, dilations, strides, padding, data_format, depth_multiplier + input_shape, + kernel_shape, + dilations, + strides, + padding, + data_format, + depth_multiplier, + int_quant_dtype=tf.int8, ): """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes""" @@ -987,10 +1022,32 @@ def representative_data_gen(): for i in range(1): yield [data] - tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + tflite_model_quant = _quantize_keras_model( + keras_model, + representative_data_gen, + is_float_input=True, + is_float_output=True, + int_quant_dtype=int_quant_dtype, + ) + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0) + except ImportError: + raise ImportError("The tflite package must be installed") + + subgraph = tflite_model.Subgraphs(0) + model_input = subgraph.InputsAsNumpy() + input_node = subgraph.Tensors(model_input).Name().decode("utf-8") tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", "")) + tvm_output = run_tvm_graph(tflite_model_quant, data, input_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2 ) @@ -1231,15 +1288,6 @@ def test_forward_convolution(): [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC", quantized=True ) - # Disable as tests are flaky - https://github.com/apache/tvm/issues/6064 - # depthwise convolution - # _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], - # 'SAME', 'NHWC', 1) - # _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], - # 'VALID', 'NHWC', 1) - # _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], - # 'SAME', 'NHWC', 8) - ####################################################################### # Transpose Convolution @@ -5106,6 +5154,8 @@ def test_forward_nms_v5(): test_forward_qnn_coco_ssd_mobilenet_v1() # TFLite 2.1.0 quantized tests + test_forward_quantized_convolution() + test_forward_quantized_depthwise_convolution() test_forward_tflite2_qnn_resnet50() test_forward_tflite2_qnn_inception_v1() test_forward_tflite2_qnn_mobilenet_v2()