diff --git a/python/gen_requirements.py b/python/gen_requirements.py index d6dd094f6a5b..72a974bf2a7e 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -146,6 +146,7 @@ "future", # Hidden dependency of torch. "onnx", "onnxruntime", + "paddlepaddle", "tensorflow", "tflite", "torch", diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 928259e30f0c..21d3d59fb013 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -265,12 +265,38 @@ def load(self, path, shape_dict=None, **kwargs): return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs) +class PaddleFrontend(Frontend): + """PaddlePaddle frontend for TVMC""" + + @staticmethod + def name(): + return "paddle" + + @staticmethod + def suffixes(): + return ["pdmodel", "pdiparams"] + + def load(self, path, shape_dict=None, **kwargs): + # pylint: disable=C0415 + import paddle + + paddle.enable_static() + paddle.disable_signal_handler() + + # pylint: disable=E1101 + exe = paddle.static.Executor(paddle.CPUPlace()) + prog, _, _ = paddle.static.load_inference_model(path, exe) + + return relay.frontend.from_paddle(prog, shape_dict=shape_dict, **kwargs) + + ALL_FRONTENDS = [ KerasFrontend, OnnxFrontend, TensorflowFrontend, TFLiteFrontend, PyTorchFrontend, + PaddleFrontend, ] diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index d1e090f40bc5..835b2583c725 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -33,7 +33,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir): model_tar_name = os.path.basename(model_url) model_path = download_testdata(model_url, model_tar_name, module=["tvmc"]) - if model_path.endswith("tgz") or model_path.endswith("gz"): + if model_path.endswith("tgz") or model_path.endswith("gz") or model_path.endswith("tar"): tar = tarfile.open(model_path) tar.extractall(path=temp_dir) tar.close() @@ -137,6 +137,18 @@ def onnx_resnet50(): return model_file +@pytest.fixture(scope="session") +def paddle_resnet50(tmpdir_factory): + base_url = "https://bj.bcebos.com/x2paddle/models" + model_url = "paddle_resnet50.tar" + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), + "paddle_resnet50/model", + temp_dir=tmpdir_factory.mktemp("data"), + ) + return model_file + + @pytest.fixture(scope="session") def onnx_mnist(): base_url = "https://github.com/onnx/models/raw/master/vision/classification/mnist/model" diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index defd628c60c9..abc9bd4b3fad 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -273,6 +273,84 @@ def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50): assert os.path.exists(dumps_path) +def verify_compile_paddle_module(model, shape_dict=None): + pytest.importorskip("paddle") + tvmc_model = tvmc.load(model, "paddle", shape_dict=shape_dict) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW") + dumps_path = tvmc_package.package_path + ".ll" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + +def test_compile_paddle_module(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + # Check default compilation. + verify_compile_paddle_module(paddle_resnet50) + # Check with manual shape override + shape_string = "inputs:[1,3,224,224]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + verify_compile_paddle_module(paddle_resnet50, shape_dict) + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_aarch64_paddle_module(paddle_resnet50): + # some CI environments wont offer paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.load(paddle_resnet50, "paddle") + tvmc_package = tvmc.compile( + tvmc_model, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + cross="aarch64-linux-gnu-gcc", + ) + dumps_path = tvmc_package.package_path + ".asm" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_options_aarch64_paddle_module(paddle_resnet50): + # some CI environments wont offer paddle, so skip in case it is not present + pytest.importorskip("paddle") + + fake_sysroot_dir = utils.tempdir().relpath("") + + tvmc_model = tvmc.load(paddle_resnet50, "paddle") + tvmc_package = tvmc.compile( + tvmc_model, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + cross="aarch64-linux-gnu-gcc", + cross_options="--sysroot=" + fake_sysroot_dir, + ) + dumps_path = tvmc_package.package_path + ".asm" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + @tvm.testing.requires_opencl def test_compile_opencl(tflite_mobilenet_v1_0_25_128): pytest.importorskip("tflite") diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index adf62eb5c7e6..569c42020817 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -84,6 +84,14 @@ def test_guess_frontend_tensorflow(): assert type(sut) is tvmc.frontends.TensorflowFrontend +def test_guess_frontend_paddle(): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + sut = tvmc.frontends.guess_frontend("a_model.pdmodel") + assert type(sut) is tvmc.frontends.PaddleFrontend + + def test_guess_frontend_invalid(): with pytest.raises(TVMCException): tvmc.frontends.guess_frontend("not/a/file.txt") @@ -161,6 +169,16 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant): assert "MobilenetV1/Conv2d_0/weights" in tvmc_model.params.keys() +def test_load_model__paddle(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.load(paddle_resnet50, model_format="paddle") + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict + + def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 779611a7a345..bdfdb48ce6a0 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -75,6 +75,31 @@ def _is_layout_transform(node): assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" +def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite")