Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,7 @@ def convert_conv(self, op, conv_type):
_, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor))
assert in_channels == input_c * depth_multiplier
else:
output_channels, kernel_h, kernel_w, _ = to_int_list(
output_channels, kernel_h, kernel_w, in_channels = to_int_list(
self.get_tensor_shape(weight_tensor)
)

Expand All @@ -2170,6 +2170,11 @@ def convert_conv(self, op, conv_type):
else:
params["channels"] = int(output_channels)
params["kernel_layout"] = "HWIO"
if input_c != in_channels:
assert (
input_c % in_channels == 0
), "Input channels is not divisible of kernel in_channels."
params["groups"] = int(input_c / in_channels)

# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
Expand Down
175 changes: 34 additions & 141 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import os
import tempfile
import typing
from packaging import version as package_version
import pytest
import numpy as np
Expand Down Expand Up @@ -292,11 +293,11 @@ def run_tflite_graph(tflite_model_buf, input_data):


def compare_tflite_with_tvm(
in_data,
in_name,
input_tensors,
output_tensors,
init_global_variables=False,
in_data: typing.List[np.ndarray],
in_name: typing.List[str],
input_tensors: typing.List,
output_tensors: typing.List,
init_global_variables: bool = False,
out_names=None,
quantized=False,
input_range=None,
Expand Down Expand Up @@ -5301,140 +5302,32 @@ def _golden():
_test_reshape_span()


#######################################################################
# Main
# ----
class TestConv2d:
"""Import Conv2d operator from TFLite, build with Relay and test."""

input_shape, kernel_shape, padding = tvm.testing.parameters(
((1, 128, 256, 6), (5, 5, 6, 10), "SAME"),
((1, 128, 256, 6), (5, 5, 6, 10), "VALID"),
# conv2d_group cases
((1, 30, 40, 6), (5, 5, 1, 6), "SAME"),
((1, 30, 40, 6), (5, 5, 1, 6), "VALID"),
)

def test_conv2d(self, input_shape: tuple, kernel_shape: tuple, padding: str):
dtype = tf.float32
kernel_in = np.ones(kernel_shape)
with tf.Graph().as_default():
x = array_ops.placeholder(shape=input_shape, dtype=dtype.name, name="input")
kernel = tf.constant(kernel_in, dtype=dtype, name="filter_weight")
out = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding, name="conv2d")
input_data = np.random.randn(*input_shape).astype(dtype.name)
compare_tflite_with_tvm(
[input_data],
["input"],
[x],
[out],
)


if __name__ == "__main__":
# BatchToSpaceND
test_forward_batch_to_space_nd()

# SpaceToBatchND
test_forward_space_to_batch_nd()

# Split
test_forward_split()

# Transpose
test_forward_transpose()

# Cast
test_forward_cast()

# BatchMatMul
test_forward_batch_matmul()

# Tile
test_forward_tile()

# Query
test_forward_shape()

# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_unpack()
test_forward_reshape()
test_all_resize()
test_forward_range()
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_gather_nd()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_reverse_sequence()
test_forward_sparse_to_dense()
test_forward_select()
test_forward_quantize_dequantize()
test_forward_arg_min_max()
test_forward_expand_dims()
test_forward_reverse_v2()
test_forward_matrix_set_diag()
test_forward_matrix_diag()

# NN
test_forward_convolution()
test_forward_transpose_conv()
test_forward_logistic()
test_forward_pooling()
test_forward_l2_pool2d()
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
test_forward_relu6()
test_forward_leaky_relu()
test_forward_relu_n1_to_1()
test_forward_log_softmax()
test_forward_fully_connected()
test_forward_l2_normalization()
test_forward_local_response_normalization()
test_forward_prelu()
test_forward_unidirectional_sequence_lstm()

# Elemwise
test_all_elemwise()
test_forward_add_n()

# Unary elemwise
test_all_unary_elemwise()
# Zeros Like
test_forward_zeros_like()

# Fill
test_forward_fill()

# Reduce
test_all_reduce()

# Logical
test_all_logical()

# Detection_PostProcess
test_detection_postprocess()

# NonMaxSuppressionV5
test_forward_nms_v5()

# Overwrite Converter
test_custom_op_converter()

# test structural_equal and span information
test_structure_and_span()

# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
test_forward_mobilenet_v3()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_inception_v4_net_batched()
test_forward_coco_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()

# End to End Sparse models
test_forward_sparse_mobilenet_v1()
test_forward_sparse_mobilenet_v2()

# End to End quantized
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
test_forward_qnn_mobilenet_v2_net()
# This also fails with a segmentation fault in my run
# with Tflite 1.15.2
test_forward_qnn_mobilenet_v3_net()
test_forward_qnn_coco_ssd_mobilenet_v1()

# TFLite 2.1.0 quantized tests
test_forward_quantized_convolution()
test_forward_quantized_depthwise_convolution()
test_forward_tflite2_qnn_resnet50()
test_forward_tflite2_qnn_inception_v1()
test_forward_tflite2_qnn_mobilenet_v2()

test_forward_tflite_float16()

test_forward_tflite_int16()
test_forward_ds_cnn_int16()
tvm.testing.main()