diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index e1a4a7481f6a..57071476b073 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -178,18 +178,18 @@ def compile_model( mod = common.convert_graph_layout(mod, alter_layout) tvm_target = common.target_from_cli(target) - target_host = target_host or "" + target_host = tvm_target if not target_host else target_host if tuning_records and os.path.exists(tuning_records): logger.debug("tuning records file provided: %s", tuning_records) with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext(opt_level=3): logger.debug("building relay graph with tuning records") - graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + graph_module = relay.build(mod, tvm_target, params=params, target_host=target_host) else: with tvm.transform.PassContext(opt_level=3): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + graph_module = relay.build(mod, tvm_target, params=params, target_host=target_host) # Generate output dump files with sources dump_code = dump_code or [] diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 62af34ee7758..882d793ccebd 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -148,3 +148,16 @@ def imagenet_cat(tmpdir_factory): np.savez(cat_file_full_path, input=image_data) return cat_file_full_path + + +@pytest.fixture(scope="session") +def tflite_mobilenet_v1_0_25_128(tmpdir_factory): + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz" + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), + "mobilenet_v1_0.25_128.tflite", + temp_dir=tmpdir_factory.mktemp("data"), + ) + + return model_file diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 28a60b19b28e..4bbb6fbf2cf8 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -150,3 +150,21 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): assert type(params) is dict assert type(dumps) is dict assert "asm" in dumps.keys() + + +@tvm.testing.requires_opencl +def test_compile_opencl(tflite_mobilenet_v1_0_25_128): + pytest.importorskip("tflite") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_0_25_128, + target="opencl", + target_host="llvm", + alter_layout="NCHW", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict