diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 525fb41407d3..316815c980e3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -353,7 +353,7 @@ def get_tensor_value(self, tensor_wrapper, is_sparse=False): data = tensor_wrapper.buffer.DataAsNumpy() if tensor_wrapper.tensor.ShapeLength() != 0: - shape = to_int_list(tensor_wrapper.tensor.ShapeAsNumpy()) + shape = to_int_list(self.get_tensor_shape(tensor_wrapper)) else: shape = [] @@ -1417,7 +1417,7 @@ def convert_gather(self, op): axis = gather_options.Axis() # Check the indices are with in bounds. - data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy()) + data_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) data_dim = len(data_shape) axis = data_dim + axis if axis < 0 else axis @@ -1535,7 +1535,7 @@ def convert_strided_slice(self, op): new_axis_mask = options.NewAxisMask() shrink_axis_mask = options.ShrinkAxisMask() - data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy()) + data_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) data_dim = len(data_shape) stride_dim = len(stride) @@ -1792,7 +1792,7 @@ def convert_fully_connected(self, op): output_tensor_type = output_tensor.tensor.Type() output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) - weight_tensor_shape = to_int_list(weight_tensor.tensor.ShapeAsNumpy()) + weight_tensor_shape = to_int_list(self.get_tensor_shape(weight_tensor)) # Weight should have only 2 dimensions(TFLite convention) assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim" @@ -1987,16 +1987,16 @@ def convert_conv(self, op, conv_type): padding = conv_options.Padding() fused_activation_fn = conv_options.FusedActivationFunction() - _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) if is_depthwise_conv: # TFLite depthwise convolution kernel layout is: # 1 KH KW C(input_c * depth_multiplier) - _, kernel_h, kernel_w, in_channels = to_int_list(weight_tensor.tensor.ShapeAsNumpy()) + _, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor)) assert in_channels == input_c * depth_multiplier else: output_channels, kernel_h, kernel_w, _ = to_int_list( - weight_tensor.tensor.ShapeAsNumpy() + self.get_tensor_shape(weight_tensor) ) dilated_kernel_h = dilation_h * (kernel_h - 1) + 1 @@ -2219,7 +2219,7 @@ def convert_slice(self, op): size = list(self.get_tensor_value(input_tensors[2])) # strided_slice(Relay) needs the slice's end indices, not the size end = size - input_tensor_shape = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + input_tensor_shape = to_int_list(self.get_tensor_shape(input_tensor)) input_tensor_rank = len(input_tensor_shape) for i in range(input_tensor_rank): if size[i] == -1: @@ -2381,7 +2381,8 @@ def convert_pool2d(self, op, pool_type): in_expr = self.get_expr(input_tensor_idx) - _, input_h, input_w, _ = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, _ = to_int_list(self.get_tensor_shape(input_tensor)) + if padding == Padding.VALID: pass elif padding == Padding.SAME: @@ -2771,12 +2772,13 @@ def convert_transpose_conv(self, op): # Input (data) Tensor. NHWC layout input_tensor = input_tensors[2] - _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy()) + _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) # Weights tensor. TFLite uses OHWI layout weights_tensor = input_tensors[1] out_channels, kernel_h, kernel_w, in_channels = to_int_list( - weights_tensor.tensor.ShapeAsNumpy() + self.get_tensor_shape(weights_tensor) ) + assert ( input_c == in_channels ), "Input channel in the filter should match to channel in the input" @@ -3204,7 +3206,7 @@ def convert_matrix_diag(self, op): ), "TFLite MATRIX_DIAG requires diagonal and output tensors' \ scale and zero points to be equal" - shape = to_int_list(diagonal.tensor.ShapeAsNumpy()) + shape = to_int_list(self.get_tensor_shape(diagonal)) shape = np.append(shape, shape[-1]) dtype = self.get_tensor_type_str(diagonal.tensor.Type()) @@ -3265,6 +3267,15 @@ def get_tensor_expr(self, tensor, is_sparse=False): expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str) return expr + def get_tensor_shape(self, tensor_wrapper): + """ Returns tensor shape. Infers shape if the shape is empty. """ + assert isinstance(tensor_wrapper, TensorWrapper), "Expecting TensorWrapper here" + return ( + tensor_wrapper.tensor.ShapeAsNumpy() + if tensor_wrapper.tensor.ShapeLength() > 0 + else _infer_shape(self.get_tensor_expr(tensor_wrapper)) + ) + # pylint: disable=no-else-return def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):