diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3d2f4a2f25e6..09e6523534cf 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1549,7 +1549,7 @@ def convert_gather(self, op): assert axis < data_dim, "Axis out of bounds" if self.has_expr(indices.tensor_idx): - indices_expr = self.get_expr(indices.tensor_idx) + indices_expr = _op.cast(self.get_expr(indices.tensor_idx), "int32") else: indices_val = self.get_tensor_value(indices) indices_expr = self.exp_tab.new_const( diff --git a/tests/python/relay/opencl_texture/conftest.py b/tests/python/relay/opencl_texture/conftest.py new file mode 100644 index 000000000000..6b9c91ec1067 --- /dev/null +++ b/tests/python/relay/opencl_texture/conftest.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import tvm +from tvm import rpc +import pytest + + +@pytest.fixture(scope="session") +def remote(): + if ( + "TVM_TRACKER_HOST" in os.environ + and "TVM_TRACKER_PORT" in os.environ + and "RPC_DEVICE_KEY" in os.environ + ): + + rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] + rpc_tracker_port = int(os.environ["TVM_TRACKER_PORT"]) + rpc_device_key = os.environ["RPC_DEVICE_KEY"] + tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) + remote = tracker.request(rpc_device_key, priority=0, session_timeout=600) + return remote + else: + return None diff --git a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py index 5198cbdf6bc6..a0ca8423478e 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py @@ -30,7 +30,7 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(target, dtype): +def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype): input_shape = (1, 32, 42, 42) filter_shape = (96, 32, 3, 3) bias_shape = (1, 96, 1, 1) @@ -65,12 +65,14 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(target, dtype): +def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, dtype): input_shape = (1, 32, 40, 40) filter_shape = (96, 32, 2, 2) bias_shape = (1, 96, 1, 1) @@ -105,12 +107,14 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_35_35_strides(target, dtype): +def test_conv2d_inceptionv3_35_35_strides(remote, target, dtype): input_shape = (1, 48, 35, 35) filter_shape = (64, 48, 5, 5) bias_shape = (1, 64, 1, 1) @@ -145,12 +149,14 @@ def test_conv2d_inceptionv3_35_35_strides(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_resnet50_v2_nchw_3c(target, dtype): +def test_conv2d_resnet50_v2_nchw_3c(remote, target, dtype): input_shape = (1, 3, 224, 224) filter_shape = (64, 3, 7, 7) bias_shape = (1, 64, 1, 1) @@ -186,12 +192,12 @@ def test_conv2d_resnet50_v2_nchw_3c(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_nchw_3c(target, dtype): +def test_conv2d_inceptionv3_nchw_3c(remote, target, dtype): input_shape = (1, 3, 299, 299) filter_shape = (64, 3, 3, 3) bias_shape = (1, 64, 1, 1) @@ -226,12 +232,12 @@ def test_conv2d_inceptionv3_nchw_3c(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_1x1_16c16spatial(target, dtype): +def test_conv2d_1x1_16c16spatial(remote, target, dtype): input_shape = (1, 16, 256, 256) filter_shape = (32, 16, 4, 4) bias_shape = (1, 32, 1, 1) @@ -266,12 +272,12 @@ def test_conv2d_1x1_16c16spatial(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_4x4_16c16pad(target, dtype): +def test_conv2d_4x4_16c16pad(remote, target, dtype): input_shape = (1, 32, 256, 256) filter_shape = (32, 32, 4, 4) bias_shape = (1, 32, 1, 1) @@ -306,12 +312,12 @@ def test_conv2d_4x4_16c16pad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_4x4x4_16c16pad(target, dtype): +def test_conv2d_4x4x4_16c16pad(remote, target, dtype): input_shape = (1, 32, 256, 256) filter_shape = (4, 32, 4, 4) bias_shape = (1, 4, 1, 1) @@ -346,12 +352,12 @@ def test_conv2d_4x4x4_16c16pad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_yolov3_v2_nchw_3c(target, dtype): +def test_conv2d_yolov3_v2_nchw_3c(remote, target, dtype): input_shape = (1, 1024, 13, 13) filter_shape = (255, 1024, 1, 1) A = relay.var("data", shape=input_shape, dtype=dtype) @@ -379,12 +385,12 @@ def test_conv2d_yolov3_v2_nchw_3c(target, dtype): "weight": tvm.nd.array(filter_data), } - build_run_compare(mod, params, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_vgg16_winograd_4d(target, dtype): +def test_conv2d_vgg16_winograd_4d(remote, target, dtype): input_shape = (1, 512, 28, 28) filter_shape = (512, 512, 3, 3) bias_shape = (1, 512, 1, 1) @@ -424,7 +430,7 @@ def test_conv2d_vgg16_winograd_4d(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 512, 28, 28], "{dtype}"], ["TENSOR", [512, 512, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 @@ -432,7 +438,7 @@ def test_conv2d_vgg16_winograd_4d(target, dtype): @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_winograd_conv(target, dtype): +def test_conv2d_winograd_conv(remote, target, dtype): input_shape = (1, 4, 3, 3) A = relay.var("data", shape=input_shape, dtype=dtype) filter_shape3 = (8, 4, 3, 3) @@ -471,7 +477,7 @@ def test_conv2d_winograd_conv(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 4, 3, 3], "{dtype}"], ["TENSOR", [8, 4, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 @@ -479,7 +485,7 @@ def test_conv2d_winograd_conv(target, dtype): @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_residual_block(target, dtype): +def test_residual_block(remote, target, dtype): """ - some kind of residual block followed by convolution to have texture after residual block - scalar data type verification which should be mapped to global memory scope @@ -596,12 +602,14 @@ def test_residual_block(target, dtype): "", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_concat(target, dtype): +def test_concat(remote, target, dtype): """ layout_transform (NCHW->NCHW4c) | <- buffer @@ -708,12 +716,14 @@ def test_concat(target, dtype): static_memory_scope = [] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_pooling_branching_texture_params(target, dtype): +def test_pooling_branching_texture_params(remote, target, dtype): """ Verification of the pooling and many branches having textures layout_transform (NCHW->NCHW4c) @@ -834,12 +844,14 @@ def test_pooling_branching_texture_params(target, dtype): "", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_branching_texture_params(target, dtype): +def test_branching_texture_params(remote, target, dtype): """ Verification of passing texture to several consumers markup of relay variables in primary functions + on_device @@ -958,13 +970,15 @@ def test_branching_texture_params(target, dtype): "", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) # function repeat, params scope are different in reused functions @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_different_lowering_same_op(target, dtype): +def test_conv2d_different_lowering_same_op(remote, target, dtype): """ Use case for verification of caching compiled functions Three convolutions following by each other in this case should be @@ -1040,12 +1054,14 @@ def test_conv2d_different_lowering_same_op(target, dtype): "", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_winograd_non_rect(target, dtype): +def test_conv2d_winograd_non_rect(remote, target, dtype): input_shape = (1, 771, 36, 64) A = relay.var("data", shape=input_shape, dtype=dtype) filter_shape = (128, 771, 3, 3) @@ -1070,7 +1086,7 @@ def test_conv2d_winograd_non_rect(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256 -texture_spatial_limit=16384 -thread_warp_size=1", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 771, 36, 64], "{dtype}"], ["TENSOR", [128, 771, 3, 3], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 5399, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 16], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 8]], ["tile_rc", "sp", [-1, 193]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 @@ -1079,7 +1095,7 @@ def test_conv2d_winograd_non_rect(target, dtype): # function repeat, params scope are different in reused functions @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_injective_nwo_inputs1(target, dtype): +def test_injective_nwo_inputs1(remote, target, dtype): """ Use case for verification of stability of annotation primary functions having several ops accepting data outside of Primary function @@ -1170,13 +1186,15 @@ def test_injective_nwo_inputs1(target, dtype): "global", "global", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) # function repeat, params scope are different in reused functions @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_injective_nwo_inputs2(target, dtype): +def test_injective_nwo_inputs2(remote, target, dtype): """ Use case for verification of stability of annotation primary functions having several ops accepting data outside of Primary function @@ -1266,4 +1284,10 @@ def test_injective_nwo_inputs2(target, dtype): "global.texture", "global", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, static_memory_scope + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py index 0b89e3dc9c7f..43979cc79a68 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py @@ -31,7 +31,7 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(target, dtype): +def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(remote, target, dtype): input_shape = (1, 257, 257, 32) filter_shape = (1, 1, 32, 16) bias_shape = (filter_shape[-1],) @@ -63,12 +63,12 @@ def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(target, dtype): +def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(remote, target, dtype): input_shape = (1, 257, 257, 32) filter_shape = (1, 1, 32, 16) bias_shape = (filter_shape[-1],) @@ -103,12 +103,12 @@ def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_4_35_35_32x3_3_144_16(target, dtype): +def test_conv2d_4_35_35_32x3_3_144_16(remote, target, dtype): input_shape = (4, 35, 35, 32) filter_shape = (3, 3, 32, 16) bias_shape = (filter_shape[-1],) @@ -141,12 +141,12 @@ def test_conv2d_4_35_35_32x3_3_144_16(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(target, dtype): +def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(remote, target, dtype): input_shape = (1, 513, 513, 3) filter_shape = (3, 3, 3, 32) bias_shape = (filter_shape[-1],) @@ -179,12 +179,12 @@ def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(target, dtype): +def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(remote, target, dtype): input_shape = (1, 42, 42, 32) filter_shape = (3, 3, 32, 96) bias_shape = (1, 1, 1, 96) @@ -219,12 +219,14 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(target, dtype): +def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(remote, target, dtype): input_shape = (1, 40, 40, 32) filter_shape = (2, 2, 32, 96) bias_shape = (1, 1, 1, 96) @@ -259,12 +261,14 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_35_35_strides(target, dtype): +def test_conv2d_inceptionv3_35_35_strides(remote, target, dtype): input_shape = (1, 35, 35, 48) filter_shape = (5, 5, 48, 64) bias_shape = (1, 1, 1, 64) @@ -299,12 +303,14 @@ def test_conv2d_inceptionv3_35_35_strides(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_resnet50_v2_nhwc_3c(target, dtype): +def test_conv2d_resnet50_v2_nhwc_3c(remote, target, dtype): input_shape = (1, 224, 224, 3) filter_shape = (7, 7, 3, 64) bias_shape = (1, 1, 1, 64) @@ -340,12 +346,12 @@ def test_conv2d_resnet50_v2_nhwc_3c(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_inceptionv3_nhwc_3c(target, dtype): +def test_conv2d_inceptionv3_nhwc_3c(remote, target, dtype): input_shape = (1, 299, 299, 3) filter_shape = (3, 3, 3, 64) bias_shape = (1, 1, 1, 64) @@ -380,12 +386,12 @@ def test_conv2d_inceptionv3_nhwc_3c(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_1x1_16c16spatial(target, dtype): +def test_conv2d_1x1_16c16spatial(remote, target, dtype): input_shape = (1, 128, 128, 16) filter_shape = (4, 4, 16, 32) bias_shape = (1, 1, 1, 32) @@ -420,12 +426,12 @@ def test_conv2d_1x1_16c16spatial(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_4x4_16c16pad(target, dtype): +def test_conv2d_4x4_16c16pad(remote, target, dtype): input_shape = (1, 256, 256, 32) filter_shape = (4, 4, 32, 32) bias_shape = (1, 1, 1, 32) @@ -460,12 +466,12 @@ def test_conv2d_4x4_16c16pad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_4x4x4_16c16pad(target, dtype): +def test_conv2d_4x4x4_16c16pad(remote, target, dtype): input_shape = (1, 256, 256, 32) filter_shape = (4, 4, 32, 4) bias_shape = (1, 1, 1, 4) @@ -499,12 +505,12 @@ def test_conv2d_4x4x4_16c16pad(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_yolov3_v2_nhwc_3c(target, dtype): +def test_conv2d_yolov3_v2_nhwc_3c(remote, target, dtype): input_shape = (1, 13, 13, 1024) filter_shape = (1, 1, 1024, 255) A = relay.var("data", shape=input_shape, dtype=dtype) @@ -532,12 +538,12 @@ def test_conv2d_yolov3_v2_nhwc_3c(target, dtype): "weight": tvm.nd.array(filter_data), } - build_run_compare(mod, params, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_vgg16_winograd_4d(target, dtype): +def test_conv2d_vgg16_winograd_4d(remote, target, dtype): input_shape = (1, 28, 28, 512) filter_shape = (3, 3, 512, 512) bias_shape = (1, 1, 1, 512) @@ -577,7 +583,7 @@ def test_conv2d_vgg16_winograd_4d(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 28, 28, 512], "{dtype}"], ["TENSOR", [3, 3, 512, 512], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 @@ -585,7 +591,7 @@ def test_conv2d_vgg16_winograd_4d(target, dtype): @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_winograd_conv(target, dtype): +def test_conv2d_winograd_conv(remote, target, dtype): input_shape = (1, 3, 3, 4) A = relay.var("data", shape=input_shape, dtype=dtype) filter_shape3 = (3, 3, 4, 8) @@ -638,7 +644,7 @@ def test_conv2d_winograd_conv(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 3, 3, 4], "{dtype}"], ["TENSOR", [3, 3, 4, 8], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 @@ -646,7 +652,7 @@ def test_conv2d_winograd_conv(target, dtype): @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_conv2d_winograd_non_rect(target, dtype): +def test_conv2d_winograd_non_rect(remote, target, dtype): input_shape = (1, 36, 64, 771) A = relay.var("data", shape=input_shape, dtype=dtype) filter_shape = (3, 3, 771, 128) @@ -678,7 +684,11 @@ def test_conv2d_winograd_non_rect(target, dtype): f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256 -texture_spatial_limit=16384 -thread_warp_size=1", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 36, 64, 771], "{dtype}"], ["TENSOR", [3, 3, 771, 128], "{dtype}"], [1, 1], [1, 1, 1, 1], [1, 1], "{dtype}"], {{}}], "config": {{"index": 5399, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 16], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 8]], ["tile_rc", "sp", [-1, 193]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n' ) graph = build_run_compare( - mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file ) matches = re.findall("winograd", graph) assert len(matches) > 0 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py index 0ac92d03b6f9..00e2c5a8c069 100644 --- a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nchw_texture.py @@ -27,7 +27,7 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_bias_nchwc(target, dtype): +def test_depthwise_conv2d_bias_nchwc(remote, target, dtype): input_shape = (1, 64, 112, 112) filter_shape = (64, 1, 3, 3) bias_shape = (1, 64, 1, 1) @@ -64,12 +64,14 @@ def test_depthwise_conv2d_bias_nchwc(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_nchwc(target, dtype): +def test_depthwise_conv2d_nchwc(remote, target, dtype): input_shape = (1, 64, 112, 112) filter_shape = (64, 1, 3, 3) bias_shape = (1, 64, 1, 1) @@ -101,12 +103,14 @@ def test_depthwise_conv2d_nchwc(target, dtype): "weight": tvm.nd.array(filter_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_bias_nchw(target, dtype): +def test_depthwise_conv2d_bias_nchw(remote, target, dtype): input_shape = (1, 64, 112, 112) filter_shape = (64, 1, 3, 3) bias_shape = (1, 64, 1, 1) @@ -143,12 +147,12 @@ def test_depthwise_conv2d_bias_nchw(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_repack_bias_nchw(target, dtype): +def test_depthwise_conv2d_repack_bias_nchw(remote, target, dtype): input_shape = (1, 63, 112, 112) filter_shape = (63, 1, 3, 3) bias_shape = (1, 63, 1, 1) @@ -185,4 +189,8 @@ def test_depthwise_conv2d_repack_bias_nchw(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py index 3af7db3a4e1f..7d7f640294ce 100644 --- a/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py +++ b/tests/python/relay/opencl_texture/test_depthwise_conv2d_nhwc_texture.py @@ -27,7 +27,7 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(target, dtype): +def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(remote, target, dtype): input_shape = (1, 129, 129, 144) filter_shape = (3, 3, 144, 1) kernel_size = (filter_shape[0], filter_shape[1]) @@ -62,12 +62,12 @@ def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(target, dtype): +def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(remote, target, dtype): input_shape = (4, 35, 35, 576) filter_shape = (3, 3, 576, 1) kernel_size = (filter_shape[0], filter_shape[1]) @@ -102,12 +102,12 @@ def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(target, dtype): +def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(remote, target, dtype): input_shape = (1, 129, 129, 144) filter_shape = (3, 3, 144, 1) kernel_size = (filter_shape[0], filter_shape[1]) @@ -144,12 +144,12 @@ def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding(target, "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_1_513_513_7x3_3_7_1(target, dtype): +def test_depthwise_conv2d_1_513_513_7x3_3_7_1(remote, target, dtype): input_shape = (1, 513, 513, 7) filter_shape = (3, 3, 7, 1) bias_shape = (filter_shape[2],) @@ -183,12 +183,12 @@ def test_depthwise_conv2d_1_513_513_7x3_3_7_1(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_depthwise_conv2d_1_513_513_3x3_3_3_1(target, dtype): +def test_depthwise_conv2d_1_513_513_3x3_3_3_1(remote, target, dtype): input_shape = (1, 513, 513, 3) filter_shape = (3, 3, 3, 1) bias_shape = (filter_shape[2],) @@ -222,4 +222,8 @@ def test_depthwise_conv2d_1_513_513_3x3_3_3_1(target, dtype): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py new file mode 100644 index 000000000000..638be477d06c --- /dev/null +++ b/tests/python/relay/opencl_texture/test_network.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare, get_model +import pytest +from tvm.relay.op import register_mixed_precision_conversion + + +def convert_to_fp16(mod, dtype): + from tvm.ir import IRModule + + mod = IRModule.from_expr(mod) + seq = tvm.transform.Sequential( + [relay.transform.InferType(), relay.transform.ToMixedPrecision()] + ) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + return mod + + +def _test_mobilenet_v1(remote, target, dtype): + mod, params, inputs, dtypes = get_model( + "https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/mobilenet_edgetpu_224_1.0_float.tflite", + "mobilenet_edgetpu_224_1.0_float.tflite", + "tflite", + ) + if dtype == "float16": + mod = convert_to_fp16(mod["main"], dtype) + build_run_compare(remote, mod, params, inputs, dtypes, target, []) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +@pytest.mark.skipif(tvm.testing.utils.IS_IN_CI, reason="CI doesn't support fp16(half datatypes)") +def test_mobilenet_v1_fp16(remote, target): + _test_mobilenet_v1(remote, target, "float16") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_mobilenet_v1_fp32(remote, target): + _test_mobilenet_v1(remote, target, "float32") + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/test_reduction_texture.py b/tests/python/relay/opencl_texture/test_reduction_texture.py index b14aefd2f9ab..9dc8a8992d27 100644 --- a/tests/python/relay/opencl_texture/test_reduction_texture.py +++ b/tests/python/relay/opencl_texture/test_reduction_texture.py @@ -29,23 +29,27 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_mean(target, dtype): +def test_mean(remote, target, dtype): # NCHW input_shape = (1, 3, 720, 1280) A = relay.var("data", shape=input_shape, dtype=dtype) mean = relay.mean(A, axis=1, keepdims=True) mod = relay.Function([A], mean) - build_run_compare(mod, {}, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") -def test_argmax(target, dtype): +def test_argmax(remote, target, dtype): # NCHW input_shape = (1, 3, 720, 1280) A = relay.var("data", shape=input_shape, dtype=dtype) argmax = relay.op.argmax(A, axis=[1]) mod = relay.Function([A], argmax) - build_run_compare(mod, {}, {"data": input_shape}, dtype, target) + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py b/tests/python/relay/opencl_texture/utils/adreno_utils.py index 27768c3d0cec..e2a271d9f68d 100644 --- a/tests/python/relay/opencl_texture/utils/adreno_utils.py +++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py @@ -21,6 +21,8 @@ import numpy as np from tvm import relay from tvm import autotvm +from tvm import rpc +from tvm.contrib import utils, ndk from tvm.relay import testing from tvm.relay.transform import recast from tvm.contrib import graph_runtime @@ -47,25 +49,20 @@ def get_cpu_reference(mod, params1, input_shape, inputs): # build module run with opencl and cpu, compare results def build_run_compare( + remote, tvm_mod, params1, input_shape, - dtype="float32", + dtypes, target="llvm", static_mem_scopes=[], gpu_preprocess=None, stat_file=None, ): - - if "TVM_TRACKER_HOST" in os.environ and "TVM_TRACKER_PORT" in os.environ: - rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] - rpc_tracker_port = os.environ["TVM_TRACKER_PORT"] - run_on_host = 0 - target_host = "llvm -mtriple=arm64-linux-android" - rpc_tracker_port = int(rpc_tracker_port) - else: - run_on_host = 1 + if remote is None: target_host = "llvm" + else: + target_host = "llvm -mtriple=arm64-linux-android" if gpu_preprocess: tvm_mod_nchwc = gpu_preprocess(tvm_mod) @@ -97,16 +94,10 @@ def build_run_compare( for i in range(0, len(static_mem_scopes)): assert static_mem_scopes[i] == graph_json["attrs"]["storage_scope"][1][i] - if run_on_host: + if remote is None: ctx = tvm.opencl() m = graph_runtime.create(graph, lib, ctx) else: - from tvm import rpc - from tvm.contrib import utils, ndk - - rpc_key = "android" - tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) - remote = tracker.request(rpc_key, priority=0, session_timeout=600) temp = utils.tempdir() dso_binary = "dev_lib_cl.so" dso_binary_path = temp.relpath(dso_binary) @@ -117,22 +108,15 @@ def build_run_compare( m = graph_runtime.create(graph, rlib, ctx) m.set_input(**params) inputs = [] - if isinstance(input_shape, dict): - for key in input_shape: - inputs.append(np.random.normal(size=input_shape[key]).astype(dtype)) - m.set_input(key, inputs[-1]) - else: - inputs.append(np.random.normal(size=input_shape).astype(dtype)) - m.set_input("data", inputs[-1]) + for key in input_shape: + inputs.append(np.random.normal(size=input_shape[key]).astype(dtypes[key])) + m.set_input(key, inputs[-1]) m.run() ref_outputs = get_cpu_reference(tvm_mod, params1, input_shape, inputs) for i, ref_output in enumerate(ref_outputs): tvm_output = m.get_output(i) output = tvm_output.asnumpy() - # for index, x in np.ndenumerate(ref_output): - # if abs(output[index] - x) > 0.01: - # print(index, output[index], x) np.testing.assert_allclose(output, ref_output, rtol=1e-1, atol=1e-1) return graph @@ -147,3 +131,95 @@ def gpu_preprocess(tvm_mod): mod = tvm.IRModule.from_expr(tvm_mod) tvm_mod_nchwc = seq(mod) return tvm_mod_nchwc + + +def get_model(url, local_file, module): + def get_tensor_type_str(tensor_type): + """Get tensor type string representation when given TFLite tensor type""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + if tensor_type == TensorType.INT8: + return "int8" + if tensor_type == TensorType.INT16: + return "int16" + if tensor_type == TensorType.UINT8: + return "uint8" + if tensor_type == TensorType.FLOAT16: + return "float16" + if tensor_type == TensorType.FLOAT32: + return "float32" + if tensor_type == TensorType.INT32: + return "int32" + if tensor_type == TensorType.INT64: + return "int64" + if tensor_type == TensorType.BOOL: + return "bool" + raise NotImplementedError( + "Tensor type {} is currently not supported".format(str(tensor_type)) + ) + + if url is None: + model_path = local_file + else: + model_path = tvm.contrib.download.download_testdata(url, local_file, module=module) + + with open(model_path, "rb") as f: + tflite_model_buf = f.read() + + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except ImportError: + raise ImportError("The tflite package must be installed") + + # keep the same as tflite + assert tflite_model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" + subgraph = tflite_model.Subgraphs(0) + + # model inputs + model_inputs = subgraph.InputsAsNumpy() + shape_dict = {} + dtype_dict = {} + for model_input in model_inputs: + model_input_name = subgraph.Tensors(model_input).Name().decode("utf-8") + model_shape_length = subgraph.Tensors(model_input).ShapeLength() + model_input_shape = [ + subgraph.Tensors(model_input).Shape(i) for i in range(model_shape_length) + ] + shape_dict[model_input_name] = model_input_shape + dtype_dict[model_input_name] = get_tensor_type_str(subgraph.Tensors(model_input).Type()) + + # model Outputs + model_outputs = subgraph.OutputsAsNumpy() + shape_dict_out = {} + dtype_dict_out = {} + for model_output in model_outputs: + model_output_name = subgraph.Tensors(model_output).Name().decode("utf-8") + model_shape_length = subgraph.Tensors(model_output).ShapeLength() + model_output_shape = [ + subgraph.Tensors(model_output).Shape(i) for i in range(model_shape_length) + ] + shape_dict_out[model_output_name] = model_output_shape + dtype_dict_out[model_output_name] = get_tensor_type_str( + subgraph.Tensors(model_input).Type() + ) + + mod, params = relay.frontend.from_tflite( + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + ) + + layout_config = relay.transform.LayoutConfig(skip_layers=[]) + desired_layouts = {"nn.conv2d": ["NCHW", "default"]} + seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + + return mod, params, shape_dict, dtype_dict