Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import os
import sys
import re
import importlib
from abc import ABC
from abc import abstractmethod
Expand All @@ -32,6 +33,7 @@
import numpy as np

from tvm import relay
from tvm import parser
from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel

Expand Down Expand Up @@ -294,13 +296,81 @@ def load(self, path, shape_dict=None, **kwargs):
return relay.frontend.from_paddle(prog, shape_dict=shape_dict, **kwargs)


class RelayFrontend(Frontend):
"""Relay frontend for TVMC"""

@staticmethod
def name():
return "relay"

@staticmethod
def suffixes():
return ["relay"]

def load(self, path, shape_dict=None, **kwargs):
with open(path, "r", encoding="utf-8") as relay_text:
text = relay_text.read()
if shape_dict is None:
logger.warning(
"Specify --input-shapes to ensure that model inputs "
"will not be considered as constants."
)

def _validate_text(text):
"""Check the provided file contents.
The relay.txt artifact contained in the MLF is missing the version header and
the metadata which is required to use meta[relay.Constant]."""

if re.compile(r".*\#\[version\.*").match(text) is None:
raise TVMCException(
"The relay model does not include the required version information."
)
if re.compile(r".*meta\[.+\].*", re.DOTALL).match(text):
if "#[metadata]" not in text:
raise TVMCException(
"The relay model does not include the required #[metadata] section. "
"Use ir_mod.astext(show_meta_data=True) to export compatible code."
)

_validate_text(text)

ir_mod = parser.fromtext(text)

if shape_dict:
input_names = shape_dict.keys()
else:
input_names = []

def _gen_params(ir_mod, skip_names=None):
"""Populate the all the params in the mode with ones."""
main_func = ir_mod["main"]
shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}
params = {}
for name, shape in shape_dict.items():
if skip_names and name in skip_names:
continue

if "int" in type_dict[name]:
data = np.random.randint(128, size=shape, dtype=type_dict[name])
else:
data = np.random.uniform(-1, 1, size=shape).astype(type_dict[name])
params[name] = data
return params

params = _gen_params(ir_mod, skip_names=input_names)

return ir_mod, params


ALL_FRONTENDS = [
KerasFrontend,
OnnxFrontend,
TensorflowFrontend,
TFLiteFrontend,
PyTorchFrontend,
PaddleFrontend,
RelayFrontend,
]


Expand Down
39 changes: 39 additions & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import pytest
import tarfile
import textwrap

import numpy as np

Expand Down Expand Up @@ -229,3 +230,41 @@ def tflite_cnn_s_quantized(tmpdir_factory):
"{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"]
)
return model_file


@pytest.fixture(scope="session")
def relay_text_conv2d(tmpdir_factory):
file_path = os.path.join(tmpdir_factory.mktemp("model"), "relay.txt")

RELAY_MODEL = textwrap.dedent(
"""\
#[version = "0.0.5"]
def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) {
%1 = nn.conv2d(
%data,
%weight,
padding=[2, 2],
channels=3,
kernel_size=[5, 5],
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32");
%2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8");
%3 = nn.conv2d(
%2,
%weight,
padding=[2, 2],
channels=3,
kernel_size=[5, 5],
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32");
%4 = nn.max_pool2d(%3, pool_size=[3, 3]);
%4
}
"""
)

with open(file_path, "w") as relay_text:
relay_text.write(RELAY_MODEL)
return file_path
13 changes: 13 additions & 0 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def test_guess_frontend_paddle():
assert type(sut) is tvmc.frontends.PaddleFrontend


def test_guess_frontend_relay():

sut = tvmc.frontends.guess_frontend("relay.relay")
assert type(sut) is tvmc.frontends.RelayFrontend


def test_guess_frontend_invalid():
with pytest.raises(TVMCException):
tvmc.frontends.guess_frontend("not/a/file.txt")
Expand Down Expand Up @@ -193,6 +199,13 @@ def test_load_model__paddle(paddle_resnet50):
assert type(tvmc_model.params) is dict


def test_load_model__relay(relay_text_conv2d):
tvmc_model = tvmc.load(relay_text_conv2d, model_format="relay")
assert type(tvmc_model) is TVMCModel
assert type(tvmc_model.mod) is IRModule
assert type(tvmc_model.params) is dict


def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
Expand Down