From db32f0e7daa8d6564689cb98f47b85380f85b2e6 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Mon, 11 Jan 2021 19:50:56 +0000 Subject: [PATCH] [TFLite] Added ability to infer shapes for arguments Added an ability to infer argument shapes if shapes are not present in TFLite files. The set of networks on which the patch was tested is internal to Arm. Any help with creating unit tests would be appreciated. --- python/tvm/relay/frontend/tflite.py | 35 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 525fb41407d3..316815c980e3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -353,7 +353,7 @@ def get_tensor_value(self, tensor_wrapper, is_sparse=False): data = tensor_wrapper.buffer.DataAsNumpy() if tensor_wrapper.tensor.ShapeLength() != 0: - shape = to_int_list(tensor_wrapper.tensor.ShapeAsNumpy()) + shape = to_int_list(self.get_tensor_shape(tensor_wrapper)) else: shape = [] @@ -1417,7 +1417,7 @@ def convert_gather(self, op): axis = gather_options.Axis() # Check the indices are with in bounds. - data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy()) + data_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) data_dim = len(data_shape) axis = data_dim + axis if axis < 0 else axis @@ -1535,7 +1535,7 @@ def convert_strided_slice(self, op): new_axis_mask = options.NewAxisMask() shrink_axis_mask = options.ShrinkAxisMask() - data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy()) + data_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) data_dim = len(data_shape) stride_dim = len(stride) @@ -1792,7 +1792,7 @@ def convert_fully_connected(self, op): output_tensor_type = output_tensor.tensor.Type() output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) - weight_tensor_shape = to_int_list(weight_tensor.tensor.ShapeAsNumpy()) + weight_tensor_shape = to_int_list(self.get_tensor_shape(weight_tensor)) # Weight should have only 2 dimensions(TFLite convention) assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim" @@ -1987,16 +1987,16 @@ def convert_conv(self, op, conv_type): padding = conv_options.Padding() fused_activation_fn = conv_options.FusedActivationFunction() - _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) if is_depthwise_conv: # TFLite depthwise convolution kernel layout is: # 1 KH KW C(input_c * depth_multiplier) - _, kernel_h, kernel_w, in_channels = to_int_list(weight_tensor.tensor.ShapeAsNumpy()) + _, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor)) assert in_channels == input_c * depth_multiplier else: output_channels, kernel_h, kernel_w, _ = to_int_list( - weight_tensor.tensor.ShapeAsNumpy() + self.get_tensor_shape(weight_tensor) ) dilated_kernel_h = dilation_h * (kernel_h - 1) + 1 @@ -2219,7 +2219,7 @@ def convert_slice(self, op): size = list(self.get_tensor_value(input_tensors[2])) # strided_slice(Relay) needs the slice's end indices, not the size end = size - input_tensor_shape = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + input_tensor_shape = to_int_list(self.get_tensor_shape(input_tensor)) input_tensor_rank = len(input_tensor_shape) for i in range(input_tensor_rank): if size[i] == -1: @@ -2381,7 +2381,8 @@ def convert_pool2d(self, op, pool_type): in_expr = self.get_expr(input_tensor_idx) - _, input_h, input_w, _ = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, _ = to_int_list(self.get_tensor_shape(input_tensor)) + if padding == Padding.VALID: pass elif padding == Padding.SAME: @@ -2771,12 +2772,13 @@ def convert_transpose_conv(self, op): # Input (data) Tensor. NHWC layout input_tensor = input_tensors[2] - _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) # Weights tensor. TFLite uses OHWI layout weights_tensor = input_tensors[1] out_channels, kernel_h, kernel_w, in_channels = to_int_list( - weights_tensor.tensor.ShapeAsNumpy() + self.get_tensor_shape(weights_tensor) ) + assert ( input_c == in_channels ), "Input channel in the filter should match to channel in the input" @@ -3204,7 +3206,7 @@ def convert_matrix_diag(self, op): ), "TFLite MATRIX_DIAG requires diagonal and output tensors' \ scale and zero points to be equal" - shape = to_int_list(diagonal.tensor.ShapeAsNumpy()) + shape = to_int_list(self.get_tensor_shape(diagonal)) shape = np.append(shape, shape[-1]) dtype = self.get_tensor_type_str(diagonal.tensor.Type()) @@ -3265,6 +3267,15 @@ def get_tensor_expr(self, tensor, is_sparse=False): expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str) return expr + def get_tensor_shape(self, tensor_wrapper): + """ Returns tensor shape. Infers shape if the shape is empty. """ + assert isinstance(tensor_wrapper, TensorWrapper), "Expecting TensorWrapper here" + return ( + tensor_wrapper.tensor.ShapeAsNumpy() + if tensor_wrapper.tensor.ShapeLength() > 0 + else _infer_shape(self.get_tensor_expr(tensor_wrapper)) + ) + # pylint: disable=no-else-return def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):