diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index cfe5a4ac7b2e..2da548356446 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -23,6 +23,7 @@ import logging import os import sys +import re import importlib from abc import ABC from abc import abstractmethod @@ -32,6 +33,7 @@ import numpy as np from tvm import relay +from tvm import parser from tvm.driver.tvmc import TVMCException, TVMCImportError from tvm.driver.tvmc.model import TVMCModel @@ -294,6 +296,73 @@ def load(self, path, shape_dict=None, **kwargs): return relay.frontend.from_paddle(prog, shape_dict=shape_dict, **kwargs) +class RelayFrontend(Frontend): + """Relay frontend for TVMC""" + + @staticmethod + def name(): + return "relay" + + @staticmethod + def suffixes(): + return ["relay"] + + def load(self, path, shape_dict=None, **kwargs): + with open(path, "r", encoding="utf-8") as relay_text: + text = relay_text.read() + if shape_dict is None: + logger.warning( + "Specify --input-shapes to ensure that model inputs " + "will not be considered as constants." + ) + + def _validate_text(text): + """Check the provided file contents. + The relay.txt artifact contained in the MLF is missing the version header and + the metadata which is required to use meta[relay.Constant].""" + + if re.compile(r".*\#\[version\.*").match(text) is None: + raise TVMCException( + "The relay model does not include the required version information." + ) + if re.compile(r".*meta\[.+\].*", re.DOTALL).match(text): + if "#[metadata]" not in text: + raise TVMCException( + "The relay model does not include the required #[metadata] section. " + "Use ir_mod.astext(show_meta_data=True) to export compatible code." + ) + + _validate_text(text) + + ir_mod = parser.fromtext(text) + + if shape_dict: + input_names = shape_dict.keys() + else: + input_names = [] + + def _gen_params(ir_mod, skip_names=None): + """Populate the all the params in the mode with ones.""" + main_func = ir_mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + params = {} + for name, shape in shape_dict.items(): + if skip_names and name in skip_names: + continue + + if "int" in type_dict[name]: + data = np.random.randint(128, size=shape, dtype=type_dict[name]) + else: + data = np.random.uniform(-1, 1, size=shape).astype(type_dict[name]) + params[name] = data + return params + + params = _gen_params(ir_mod, skip_names=input_names) + + return ir_mod, params + + ALL_FRONTENDS = [ KerasFrontend, OnnxFrontend, @@ -301,6 +370,7 @@ def load(self, path, shape_dict=None, **kwargs): TFLiteFrontend, PyTorchFrontend, PaddleFrontend, + RelayFrontend, ] diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index fcf079620e25..48b465e507ae 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -17,6 +17,7 @@ import os import pytest import tarfile +import textwrap import numpy as np @@ -229,3 +230,41 @@ def tflite_cnn_s_quantized(tmpdir_factory): "{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"] ) return model_file + + +@pytest.fixture(scope="session") +def relay_text_conv2d(tmpdir_factory): + file_path = os.path.join(tmpdir_factory.mktemp("model"), "relay.txt") + + RELAY_MODEL = textwrap.dedent( + """\ + #[version = "0.0.5"] + def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8"); + %3 = nn.conv2d( + %2, + %weight, + padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %4 = nn.max_pool2d(%3, pool_size=[3, 3]); + %4 + } + """ + ) + + with open(file_path, "w") as relay_text: + relay_text.write(RELAY_MODEL) + return file_path diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index b76066994cb2..1e6efb4a3b24 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -106,6 +106,12 @@ def test_guess_frontend_paddle(): assert type(sut) is tvmc.frontends.PaddleFrontend +def test_guess_frontend_relay(): + + sut = tvmc.frontends.guess_frontend("relay.relay") + assert type(sut) is tvmc.frontends.RelayFrontend + + def test_guess_frontend_invalid(): with pytest.raises(TVMCException): tvmc.frontends.guess_frontend("not/a/file.txt") @@ -193,6 +199,13 @@ def test_load_model__paddle(paddle_resnet50): assert type(tvmc_model.params) is dict +def test_load_model__relay(relay_text_conv2d): + tvmc_model = tvmc.load(relay_text_conv2d, model_format="relay") + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict + + def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present pytest.importorskip("tensorflow")