From 4af8536566009432526f695d120f5f2c7f4231e1 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 24 Feb 2023 11:29:01 -0800 Subject: [PATCH 1/2] Fix TFLite frontend bug and add test --- python/tvm/relay/frontend/tflite.py | 7 +- tests/python/frontend/tflite/test_forward.py | 173 ++++--------------- 2 files changed, 38 insertions(+), 142 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 95bdb0ce513c..db21fa6668d1 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2146,7 +2146,7 @@ def convert_conv(self, op, conv_type): _, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor)) assert in_channels == input_c * depth_multiplier else: - output_channels, kernel_h, kernel_w, _ = to_int_list( + output_channels, kernel_h, kernel_w, in_channels = to_int_list( self.get_tensor_shape(weight_tensor) ) @@ -2170,6 +2170,11 @@ def convert_conv(self, op, conv_type): else: params["channels"] = int(output_channels) params["kernel_layout"] = "HWIO" + if input_c != in_channels: + assert ( + input_c % in_channels == 0 + ), "Input channels is not divisible of kernel in_channels." + params["groups"] = int(input_c / in_channels) # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 1d743ceb6938..4f3c811af91a 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -29,6 +29,7 @@ from packaging import version as package_version import pytest import numpy as np +import typing from PIL import Image @@ -292,11 +293,11 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm( - in_data, - in_name, - input_tensors, - output_tensors, - init_global_variables=False, + in_data: typing.List[np.ndarray], + in_name: typing.List[str], + input_tensors: typing.List, + output_tensors: typing.List, + init_global_variables: bool = False, out_names=None, quantized=False, input_range=None, @@ -5301,140 +5302,30 @@ def _golden(): _test_reshape_span() -####################################################################### -# Main -# ---- +class TestConv2d: + input_shape, kernel_shape, padding = tvm.testing.parameters( + ((1, 128, 256, 6), (5, 5, 6, 10), "SAME"), + ((1, 128, 256, 6), (5, 5, 6, 10), "VALID"), + # conv2d_group cases + ((1, 30, 40, 6), (5, 5, 1, 6), "SAME"), + ((1, 30, 40, 6), (5, 5, 1, 6), "VALID"), + ) + + def test_conv2d(self, input_shape: tuple, kernel_shape: tuple, padding: str): + dtype = tf.float32 + kernel_in = np.ones(kernel_shape) + with tf.Graph().as_default(): + x = array_ops.placeholder(shape=input_shape, dtype=dtype.name, name="input") + kernel = tf.constant(kernel_in, dtype=dtype, name="filter_weight") + out = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding, name="conv2d") + input_data = np.random.randn(*input_shape).astype(dtype.name) + compare_tflite_with_tvm( + [input_data], + ["input"], + [x], + [out], + ) + + if __name__ == "__main__": - # BatchToSpaceND - test_forward_batch_to_space_nd() - - # SpaceToBatchND - test_forward_space_to_batch_nd() - - # Split - test_forward_split() - - # Transpose - test_forward_transpose() - - # Cast - test_forward_cast() - - # BatchMatMul - test_forward_batch_matmul() - - # Tile - test_forward_tile() - - # Query - test_forward_shape() - - # Transforms - test_forward_concatenation() - test_forward_pad() - test_forward_pack() - test_forward_unpack() - test_forward_reshape() - test_all_resize() - test_forward_range() - test_forward_squeeze() - test_forward_slice() - test_forward_topk() - test_forward_gather() - test_forward_gather_nd() - test_forward_stridedslice() - test_forward_depthtospace() - test_forward_spacetodepth() - test_forward_reverse_sequence() - test_forward_sparse_to_dense() - test_forward_select() - test_forward_quantize_dequantize() - test_forward_arg_min_max() - test_forward_expand_dims() - test_forward_reverse_v2() - test_forward_matrix_set_diag() - test_forward_matrix_diag() - - # NN - test_forward_convolution() - test_forward_transpose_conv() - test_forward_logistic() - test_forward_pooling() - test_forward_l2_pool2d() - test_forward_softmax() - test_forward_tanh() - test_forward_relu() - test_forward_relu6() - test_forward_leaky_relu() - test_forward_relu_n1_to_1() - test_forward_log_softmax() - test_forward_fully_connected() - test_forward_l2_normalization() - test_forward_local_response_normalization() - test_forward_prelu() - test_forward_unidirectional_sequence_lstm() - - # Elemwise - test_all_elemwise() - test_forward_add_n() - - # Unary elemwise - test_all_unary_elemwise() - # Zeros Like - test_forward_zeros_like() - - # Fill - test_forward_fill() - - # Reduce - test_all_reduce() - - # Logical - test_all_logical() - - # Detection_PostProcess - test_detection_postprocess() - - # NonMaxSuppressionV5 - test_forward_nms_v5() - - # Overwrite Converter - test_custom_op_converter() - - # test structural_equal and span information - test_structure_and_span() - - # End to End - test_forward_mobilenet_v1() - test_forward_mobilenet_v2() - test_forward_mobilenet_v3() - test_forward_inception_v3_net() - test_forward_inception_v4_net() - test_forward_inception_v4_net_batched() - test_forward_coco_ssd_mobilenet_v1() - test_forward_mediapipe_hand_landmark() - - # End to End Sparse models - test_forward_sparse_mobilenet_v1() - test_forward_sparse_mobilenet_v2() - - # End to End quantized - test_forward_qnn_inception_v1_net() - test_forward_qnn_mobilenet_v1_net() - test_forward_qnn_mobilenet_v2_net() - # This also fails with a segmentation fault in my run - # with Tflite 1.15.2 - test_forward_qnn_mobilenet_v3_net() - test_forward_qnn_coco_ssd_mobilenet_v1() - - # TFLite 2.1.0 quantized tests - test_forward_quantized_convolution() - test_forward_quantized_depthwise_convolution() - test_forward_tflite2_qnn_resnet50() - test_forward_tflite2_qnn_inception_v1() - test_forward_tflite2_qnn_mobilenet_v2() - - test_forward_tflite_float16() - - test_forward_tflite_int16() - test_forward_ds_cnn_int16() + tvm.testing.main() From 5dd9816ad70bbc528d228b4f1809de140359d1b1 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Mon, 27 Feb 2023 09:27:58 -0800 Subject: [PATCH 2/2] lint --- tests/python/frontend/tflite/test_forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4f3c811af91a..42a27bbd2671 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -26,10 +26,10 @@ import os import tempfile +import typing from packaging import version as package_version import pytest import numpy as np -import typing from PIL import Image @@ -5303,6 +5303,8 @@ def _golden(): class TestConv2d: + """Import Conv2d operator from TFLite, build with Relay and test.""" + input_shape, kernel_shape, padding = tvm.testing.parameters( ((1, 128, 256, 6), (5, 5, 6, 10), "SAME"), ((1, 128, 256, 6), (5, 5, 6, 10), "VALID"),