From 04d5057099be2a7a37480c2ba31de9d9a6d4fd44 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 28 Aug 2020 11:58:41 +0100 Subject: [PATCH 01/10] [tvmc] command line driver 'compile' (part 2/4) * Add 'compile' subcommand into tvmc (tvm.driver.tvmc) * Add frontends: Keras, ONNX, TensorFlow, tflite, PyTorch * Add tests for the 'compile' subcommand * Enable command line driver tests as part of integration tests * Skip tests if the cross-compilation toolchain is not installed Co-authored-by: Marcus Shawcroft Co-authored-by: Matthew Barrett Co-authored-by: Dmitriy Smirnov Co-authored-by: Luke Hutton Co-authored-by: Giuseppe Rossini Co-authored-by: Matthew Barrett Co-authored-by: Elen Kalda Co-authored-by: Ramana Radhakrishnan Co-authored-by: Jeremy Johnson Co-authored-by: Ina Dobreva --- python/setup.py | 66 ++-- python/tvm/driver/tvmc/__init__.py | 5 + python/tvm/driver/tvmc/__main__.py | 4 +- python/tvm/driver/tvmc/common.py | 68 ++++ python/tvm/driver/tvmc/compiler.py | 307 ++++++++++++++++ python/tvm/driver/tvmc/frontends.py | 391 +++++++++++++++++++++ tests/python/driver/tvmc/conftest.py | 105 ++++++ tests/python/driver/tvmc/test_common.py | 71 ++++ tests/python/driver/tvmc/test_compiler.py | 139 ++++++++ tests/python/driver/tvmc/test_frontends.py | 216 ++++++++++++ tests/scripts/task_python_integration.sh | 3 + 11 files changed, 1344 insertions(+), 31 deletions(-) create mode 100644 python/tvm/driver/tvmc/compiler.py create mode 100644 python/tvm/driver/tvmc/frontends.py create mode 100644 tests/python/driver/tvmc/conftest.py create mode 100644 tests/python/driver/tvmc/test_common.py create mode 100644 tests/python/driver/tvmc/test_compiler.py create mode 100644 tests/python/driver/tvmc/test_frontends.py diff --git a/python/setup.py b/python/setup.py index 402d993820ba..4065b1a239cf 100644 --- a/python/setup.py +++ b/python/setup.py @@ -147,35 +147,43 @@ def is_pure(self): def get_package_data_files(): # Relay standard libraries - return 
["relay/std/prelude.rly", "relay/std/core.rly"] - - -setup( - name="tvm", - version=__version__, - description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems", - zip_safe=False, - entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, - install_requires=[ - "numpy", - "scipy", - "decorator", - "attrs", - "psutil", - "typed_ast", - ], - extras_require={ - "test": ["pillow<7", "matplotlib"], - "extra_feature": ["tornado", "psutil", "xgboost>=1.1.0", "mypy", "orderedset"], - }, - packages=find_packages(), - package_dir={"tvm": "tvm"}, - package_data={"tvm": get_package_data_files()}, - distclass=BinaryDistribution, - url="https://github.com/apache/incubator-tvm", - ext_modules=config_cython(), - **setup_kwargs, -) + return ['relay/std/prelude.rly', 'relay/std/core.rly'] + + +setup(name='tvm', + version=__version__, + description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems", + zip_safe=False, + entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, + install_requires=[ + 'numpy', + 'scipy', + 'decorator', + 'attrs', + 'psutil', + 'typed_ast', + 'tensorflow==2.1.0', + 'tflite==2.1.0', + 'onnx==1.6.0', + 'onnxruntime==1.0.0', + 'torch==1.4.0', + 'torchvision==0.5.0' + ], + extras_require={'test': ['pillow<7', + 'matplotlib'], + 'extra_feature': ['tornado', + 'psutil', + 'xgboost>=1.1.0', + 'mypy', + 'orderedset']}, + + packages=find_packages(), + package_dir={'tvm': 'tvm'}, + package_data={'tvm': get_package_data_files()}, + distclass=BinaryDistribution, + url='https://github.com/apache/incubator-tvm', + ext_modules=config_cython(), + **setup_kwargs) if wheel_include_libs: diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 13a83393a912..cf35f189d2ba 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -14,3 +14,8 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +""" +TVMC - TVM driver command-line interface +""" + +from . import compiler diff --git a/python/tvm/driver/tvmc/__main__.py b/python/tvm/driver/tvmc/__main__.py index f72e9f4df3ba..55235a6adfdd 100644 --- a/python/tvm/driver/tvmc/__main__.py +++ b/python/tvm/driver/tvmc/__main__.py @@ -18,7 +18,7 @@ TVMC - TVM driver command-line interface """ -from .main import main +from tvm.driver import tvmc if __name__ == "__main__": - main() + tvmc.main.main() diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index aa53ce7134bb..fd48c28491f7 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -17,7 +17,75 @@ """ Common utility functions shared by TVMC modules. """ +import argparse +import re + +from tvm import relay +from tvm import transform class TVMCException(Exception): """TVMC Exception""" + + +def convert_graph_layout(mod, desired_layout): + """Alter the layout of the input graph. + + Parameters + ---------- + mod : tvm.relay.Module + The relay module to convert. + desired_layout : str + The layout to convert to. + + Returns + ------- + mod : tvm.relay.Module + The converted module. + """ + + # Assume for the time being that graphs only have + # conv2d as heavily-sensitive operators. + desired_layouts = { + "nn.conv2d": [desired_layout, "default"], + "qnn.conv2d": [desired_layout, "default"], + } + + # Convert the layout of the graph where possible. + seq = transform.Sequential( + [ + relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout(desired_layouts), + ] + ) + with transform.PassContext(opt_level=3): + return seq(mod) + + +def parse_input_shapes(shapes_str): + """ Parsing function for tensor shape syntax. 
""" + shapes = [] + # Split up string into comma seperated sections ignoring commas in ()s + match = re.findall(r"(\(.*?\)|.+?),?", shapes_str) + if match: + for inp in match: + # Test for and remove brackets + shape = re.match(r"\((.*)\)", inp) + if shape and shape.lastindex == 1: + # Remove white space and extract numbers + strshape = shape[1].replace(" ", "").split(",") + try: + shapes.append([int(i) for i in strshape]) + except ValueError: + raise argparse.ArgumentTypeError( + f"expected numbers in shape '{shape[1]}'" + ) + else: + raise argparse.ArgumentTypeError( + f"missing brackets around shape '{inp}', example '(1,2,3)'" + ) + else: + raise argparse.ArgumentTypeError( + f"unrecognized shapes '{shapes_str}', example '(1,2,3),(1,4),...'" + ) + return shapes diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py new file mode 100644 index 000000000000..1126c26c97c8 --- /dev/null +++ b/python/tvm/driver/tvmc/compiler.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support to compile networks both AOT and JIT. 
+""" +import argparse +import json +import logging +import os.path +import tarfile +from pathlib import Path + +import tvm +from tvm import autotvm +from tvm import relay +from tvm._ffi.runtime_ctypes import TVMContext +from tvm.contrib import cc +from tvm.contrib import util +from tvm.relay.op.contrib import get_pattern_table + +from . import common, frontends +from .main import register_parser + + +@register_parser +def add_compile_parser(subparsers): + """ Include parser for 'compile' subcommand """ + + parser = subparsers.add_parser("compile", help="compile a model") + parser.set_defaults(func=drive_compile) + parser.add_argument( + "--cross-compiler", + default="", + help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'", + ) + parser.add_argument( + "--dump-code", + metavar="FORMAT", + default="", + help="comma separarated list of formats to export, e.g. 'asm,ll,relay' " + ) + parser.add_argument( + "--model-format", + choices=frontends.get_frontends(), + help="specify input model format", + ) + parser.add_argument( + "--input-shape", + type=common.parse_input_shapes, + metavar="INPUT_SHAPE,[INPUT_SHAPE]...", + help="for pytorch, e.g. 
'(1,3,224,224)'", + ) + parser.add_argument( + "-o", + "--output", + default="module.tar", + help="output the compiled module to an archive", + ) + parser.add_argument( + "--target", + help="compilation target as plain string, inline JSON or path to a JSON file", + required=True + ) + parser.add_argument( + "--tuning-records", + metavar="PATH", + default="", + help="path to an auto-tuning log file from AutoTVM" + ) + parser.add_argument( + "--desired-layout", + choices=["NCHW", "NHWC"], + default=None, + help="change the data layout of the whole graph", + ) + parser.add_argument( + "-v", "--verbose", action="count", default=0, help="increase verbosity" + ) + parser.add_argument("FILE") + + +def drive_compile(args): + """ Invoke tvmc.compiler module with command line arguments """ + + graph, lib, params, dumps = compile_model( + args.FILE, + args.target, + args.dump_code, + "", + args.model_format, + args.input_shape, + args.tuning_records, + args.desired_layout, + ) + + if dumps: + save_dumps(args.output, dumps) + + save_module(args.output, graph, lib, params, args.cross_compiler) + return 0 + + +def compile_model( + path, + target, + dump_sources=None, + target_host=None, + model_format=None, + shapes=None, + tuning_records=None, + alter_layout=None, +): + """Compile a model from a supported framework into a TVM module. + + This function takes a union of the arguments of both frontends.load_model + and compiler.compile_relay. The resulting TVM module can be executed using + the graph runtime. + + Returns + ------- + graph : str + A JSON-serialized TVM execution graph. + lib : tvm.module.Module + A TVM module containing the compiled functions. + params : dict + The parameters (weights) for the TVM module. + dumps : dict + Dictionary containing the dumps specified.
+ + """ + dump_sources = [x.strip() for x in dump_sources.split(',')] if dump_sources else None + mod, params = frontends.load_model(path, model_format, shapes) + + return compile_relay( + mod, + params, + target, + dump_sources=dump_sources, + target_host=target_host, + tuning_records=tuning_records, + alter_layout=alter_layout, + ) + + +def compile_relay( + mod, + params, + target, + dump_sources=None, + target_host=None, + tuning_records=None, + alter_layout=None, +): + """Compile a relay module to a TVM module for the graph runtime. + + Parameters + ---------- + mod : tvm.relay.Module + The relay module to compile. + params : dict + The parameters (weights) for the relay module. + target : str + The target for which to compile. Can be a plain string or + a path. + dump_sources : list, optional + Dump the generated code for the specified source types, on + the requested target. + target_host : Union[str, tvm.target.Target], optional + The target of the host machine if host-side code + needs to be generated. + tuning_records: str, optional + Name of the file produced by the tuning to be used during + compilation. + alter_layout: str, optional + The layout to convert the graph to. Note, the convert layout + pass doesn't currently guarantee the whole of the graph will + be converted to the chosen layout. + + Returns + ------- + graph : str + A JSON-serialized TVM execution graph. + lib : tvm.module.Module + A TVM module containing the compiled functions. + params : dict + The parameters (weights) for the TVM module. + dumps : dict + Dictionary containing the dumps specified. 
+ + """ + + if alter_layout: + mod = common.convert_graph_layout(mod, alter_layout) + + if os.path.exists(str(target)): + with open(target) as target_file: + logging.info(f"using target input from file: {target}") + target = "".join(target_file.readlines()) + + logging.debug(f"creating target from input: {target}") + tvm_target = tvm.target.create(target) + target_host = "" + + if tuning_records: + logging.debug(f"tuning records file provided: {tuning_records}") + with autotvm.apply_history_best(tuning_records): + with tvm.transform.PassContext(opt_level=3): + logging.debug("building relay graph with tuning records") + graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + else: + with tvm.transform.PassContext(opt_level=3): + logging.debug("building relay graph (no tuning records provided)") + graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + + # Generate output dump files with sources + dump_sources = dump_sources or [] + dumps = {} + for source_type in dump_sources: + lib = graph_module.get_lib() + # TODO lib.get_source call here have inconsistent behavior for unsupported + # formats. This is an open discussion (@leandron). + source = str(mod) if source_type == "relay" else lib.get_source(source_type) + dumps[source_type] = source + + return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps + + +def save_module(module_path, graph, lib, params, cross=None): + """ + Create a tarball containing the generated TVM graph, + exported library and parameters + + Parameters + ---------- + module_path : str + path to the target tar.gz file to be created, + including the file name + graph : str + A JSON-serialized TVM execution graph. + lib : tvm.module.Module + A TVM module containing the compiled functions. + params : dict + The parameters (weights) for the TVM module. 
+ cross : Union[str, Callable[[str, str, Optional[str]], None]] + Function that performs the actual compilation + + """ + lib_name = "mod.so" + graph_name = "mod.json" + param_name = "mod.params" + temp = util.tempdir() + path_lib = temp.relpath(lib_name) + if not cross: + logging.debug(f"exporting library to {path_lib}") + lib.export_library(path_lib) + else: + logging.debug(f"exporting library to {path_lib}, using cross compiler {cross}") + lib.export_library(path_lib, cc.cross_compiler(cross)) + + with open(temp.relpath(graph_name), "w") as graph_file: + logging.debug(f"writing graph to file to {graph_file.name}") + graph_file.write(graph) + + with open(temp.relpath(param_name), "wb") as params_file: + logging.debug(f"writing params to file to {params_file.name}") + params_file.write(relay.save_param_dict(params)) + + logging.debug(f"saving module as tar file to {module_path}") + with tarfile.open(module_path, "w") as tar: + tar.add(path_lib, lib_name) + tar.add(temp.relpath(graph_name), graph_name) + tar.add(temp.relpath(param_name), param_name) + + +def save_dumps(module_name, dumps, dump_root="."): + """ + Serialize dump files to the disk. + + Parameters + ---------- + module_name : list(Union[str, tvm.target.Target]) + file name, referring to the module that generated + the dump contents + dumps : dict + the output contents to be saved into the files + dump_root : str + path in which dump files will be created + """ + + for dump_format in dumps: + dump_name = module_name + "." + dump_format + with open(Path(dump_root, dump_name), "w") as f: + f.write(dumps[dump_format]) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py new file mode 100644 index 000000000000..635199dbc07d --- /dev/null +++ b/python/tvm/driver/tvmc/frontends.py @@ -0,0 +1,391 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support to parse models from different frameworks into Relay networks. + +Frontend classes do lazy-loading of modules on purpose, to reduce time spent on +loading the tool. +""" +import logging +import os +import sys +from abc import ABC +from abc import abstractmethod +from pathlib import Path + +from tvm.driver.tvmc.common import TVMCException + + +class Frontend(ABC): + """Abstract class for frontend""" + + @staticmethod + @abstractmethod + def name(): + """Frontend name""" + + @staticmethod + @abstractmethod + def suffixes(): + """File suffixes (extensions) used by this frontend""" + + @abstractmethod + def load(self, path, shapes): + """Load network""" + + +def import_keras(): + """ Lazy import function for Keras""" + # Keras writes the message "Using TensorFlow backend." 
to stderr + # Redirect stderr during the import to disable this + stderr = sys.stderr + sys.stderr = open(os.devnull, "w") + try: + # pylint: disable=C0415 + import tensorflow as tf + from tensorflow import keras + + return tf, keras + finally: + sys.stderr = stderr + + +class KerasFrontend(Frontend): + """ Keras frontend for TVMC """ + + @staticmethod + def name(): + return "keras" + + @staticmethod + def suffixes(): + return ["h5"] + + def load(self, path, shapes): + # pylint: disable=C0415 + import numpy as np + from tvm import relay + + # pylint: disable=C0103 + tf, keras = import_keras() + + if shapes: + raise TVMCException( + "--input-shape is not supported for {}".format(self.name()) + ) + + # tvm build currently imports keras directly instead of tensorflow.keras + try: + model = keras.models.load_model(path) + except ValueError as err: + raise TVMCException(str(err)) + + # There are two flavours of keras model, sequential and + # functional, TVM expects a functional model, so convert + # if required: + if self.is_sequential_p(model): + model = self.sequential_to_functional(model) + + in_shapes = [] + for layer in model._input_layers: + if tf.executing_eagerly(): + in_shapes.append( + tuple(dim if dim is not None else 1 for dim in layer.input.shape) + ) + else: + in_shapes.append( + tuple( + dim.value if dim.value is not None else 1 + for dim in layer.input.shape + ) + ) + + inputs = [ + np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes + ] + shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)} + return relay.frontend.from_keras(model, shape_dict, layout="NHWC") + + def is_sequential_p(self, model): + _, keras = import_keras() + return isinstance(model, keras.models.Sequential) + + def sequential_to_functional(self, model): + _, keras = import_keras() + assert self.is_sequential_p(model) + input_layer = keras.layers.Input(batch_shape=model.layers[0].input_shape) + prev_layer = input_layer + for layer in 
model.layers: + prev_layer = layer(prev_layer) + model = keras.models.Model([input_layer], [prev_layer]) + return model + + +class OnnxFrontend(Frontend): + """ ONNX frontend for TVMC """ + + @staticmethod + def name(): + return "onnx" + + @staticmethod + def suffixes(): + return ["onnx"] + + def load(self, path, shapes): + # pylint: disable=C0415 + import onnx + from tvm import relay + + if shapes: + raise TVMCException( + "--input-shape is not supported for {}".format(self.name()) + ) + + model = onnx.load(path) + + # Find the name and shape of the first input in the graph + + # pylint: disable=E1101 + name = model.graph.input[0].name + + # pylint: disable=E1101 + proto_shape = model.graph.input[0].type.tensor_type.shape.dim + shape = [d.dim_value for d in proto_shape] + + shape_dict = {name: shape} + + return relay.frontend.from_onnx(model, shape_dict) + + +class TensorflowFrontend(Frontend): + """ TensorFlow frontend for TVMC """ + + @staticmethod + def name(): + return "pb" + + @staticmethod + def suffixes(): + return ["pb"] + + def load(self, path, shapes): + # pylint: disable=C0415 + from tvm import relay + import tensorflow as tf + import tvm.relay.testing.tf as tf_testing + + if shapes: + raise TVMCException( + "--input-shape is not supported for {}".format(self.name()) + ) + + with tf.io.gfile.GFile(path, "rb") as tf_graph: + content = tf_graph.read() + + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(content) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + logging.debug("relay.frontend.from_tensorflow") + return relay.frontend.from_tensorflow(graph_def) + + +class TFLiteFrontend(Frontend): + """ TFLite frontend for TVMC """ + + _tflite_m = { + 0: "float32", + 1: "float16", + 2: "int32", + 3: "uint8", + 4: "int64", + 5: "string", + 6: "bool", + 7: "int16", + 8: "complex64", + 9: "int8", + } + + @staticmethod + def name(): + return "tflite" + + @staticmethod + def suffixes(): + return ["tflite"] + + def load(self, path, 
shapes): + # pylint: disable=C0415 + import tflite.Model as model + from tvm import relay + + if shapes: + raise TVMCException( + "--input-shape is not supported for {}".format(self.name()) + ) + + with open(path, "rb") as tf_graph: + content = tf_graph.read() + + # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0 + try: + tflite_model = model.Model.GetRootAsModel(content, 0) + except AttributeError: + tflite_model = model.GetRootAsModel(content, 0) + + try: + version = tflite_model.Version() + logging.debug("tflite version %s", version) + except Exception: + raise TVMCException("input file not tflite") + + if version != 3: + raise TVMCException("input file not tflite version 3") + + logging.debug("tflite_input_type") + shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model) + + # parse TFLite model and convert into Relay computation graph + logging.debug("relay.frontend.from_tflite") + mod, params = relay.frontend.from_tflite( + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + ) + return mod, params + + @staticmethod + def _decode_type(n): + return TFLiteFrontend._tflite_m[n] + + @staticmethod + def _input_type(model): + subgraph_count = model.SubgraphsLength() + assert subgraph_count > 0 + shape_dict = {} + dtype_dict = {} + for subgraph_index in range(subgraph_count): + subgraph = model.Subgraphs(subgraph_index) + inputs_count = subgraph.InputsLength() + assert inputs_count >= 1 + for input_index in range(inputs_count): + input_ = subgraph.Inputs(input_index) + assert subgraph.TensorsLength() > input_ + tensor = subgraph.Tensors(input_) + input_shape = tuple(tensor.ShapeAsNumpy()) + tensor_type = tensor.Type() + input_name = tensor.Name().decode("utf8") + shape_dict[input_name] = input_shape + dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type) + + return shape_dict, dtype_dict + + +class PyTorchFrontend(Frontend): + """ PyTorch frontend for TVMC """ + + @staticmethod + def name(): + return "pytorch" + + @staticmethod + 
def suffixes(): + # Torch Script is a zip file, but can be named pth + return ["pth", "zip"] + + def load(self, path, shapes): + # pylint: disable=C0415 + import torch + from tvm import relay + + if not shapes: + raise TVMCException( + "--input-shape must be specified for {}".format(self.name()) + ) + + traced_model = torch.jit.load(path) + traced_model.eval() # Switch to inference mode + input_shapes = [ + ("input{}".format(idx), shape) for idx, shape in enumerate(shapes) + ] + logging.debug("relay.frontend.from_pytorch") + return relay.frontend.from_pytorch(traced_model, input_shapes) + + +ALL_FRONTENDS = [ + KerasFrontend, + OnnxFrontend, + TensorflowFrontend, + TFLiteFrontend, + PyTorchFrontend, +] + + +def get_frontends(): + """Return the names of all supported frontends""" + return [frontend.name() for frontend in ALL_FRONTENDS] + + +def lookup_frontend(name): + for frontend in ALL_FRONTENDS: + if name == frontend.name(): + return frontend() + raise TVMCException("unrecognized frontend") + + +def guess_input_language(path): + suffix = Path(path).suffix.lower() + if suffix.startswith("."): + suffix = suffix[1:] + + for frontend in ALL_FRONTENDS: + if suffix in frontend.suffixes(): + return frontend() + + raise TVMCException("cannot guess input language") + + +def load_model(path, language=None, shapes=None): + """Load a model from a supported framework and convert it + into an equivalent relay representation. + + Parameters + ---------- + path : str + The path to the model file. + language : str, optional + The language of the model file. + If not specified, this will be inferred from the file type. + + Returns + ------- + mod : tvm.relay.Module + The produced relay module. + params : dict + The parameters (weights) for the relay module. 
+ + """ + # pylint: disable=C0415 + import tvm.error + + if language is not None: + frontend = lookup_frontend(language) + else: + frontend = guess_input_language(path) + + mod, params = frontend.load(path, shapes) + + return mod, params diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py new file mode 100644 index 000000000000..4efece55011f --- /dev/null +++ b/tests/python/driver/tvmc/conftest.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +import pytest +import tarfile + +from tvm.driver import tvmc + +from tensorflow.keras.applications.resnet50 import ResNet50 + +from tvm.contrib.download import download_testdata + +from tvm.driver.tvmc.common import convert_graph_layout + +# Support functions + +def download_and_untar(model_url, model_sub_path, temp_dir): + model_tar_name = os.path.basename(model_url) + model_path = download_testdata(model_url, model_tar_name, module=['tvmc']) + + if model_path.endswith("tgz") or model_path.endswith("gz"): + tar = tarfile.open(model_path) + tar.extractall(path=temp_dir) + tar.close() + + return os.path.join(temp_dir, model_sub_path) + + +def get_sample_compiled_module(target_dir): + """Support function that returns a TFLite compiled module""" + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" + model_file = download_and_untar("{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_quant.tflite", + temp_dir=target_dir) + + return tvmc.compiler.compile_model(model_file, target="llvm") + + +# PyTest fixtures + + +@pytest.fixture(scope="session") +def tflite_mobilenet_v1_1_quant(tmpdir_factory): + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" + model_file = download_and_untar("{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_quant.tflite", + temp_dir=tmpdir_factory.mktemp("data")) + + return model_file + + +@pytest.fixture(scope="session") +def pb_mobilenet_v1_1_quant(tmpdir_factory): + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" + model_file = download_and_untar("{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_frozen.pb", + temp_dir=tmpdir_factory.mktemp("data")) + + return model_file + + +@pytest.fixture(scope="session") +def
keras_resnet50(tmpdir_factory): + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet50.h5") + model = ResNet50(include_top=True, weights='imagenet', input_shape=(224, 224, 3), classes=1000) + model.save(model_file_name) + + return model_file_name + + +@pytest.fixture(scope="session") +def onnx_resnet50(): + base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" + file_to_download = "resnet50-v2-7.onnx" + model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_to_download, module=['tvmc']) + + return model_file + + +@pytest.fixture(scope="session") +def tflite_compiled_module_as_tarfile(tmpdir_factory): + target_dir = tmpdir_factory.mktemp("data") + graph, lib, params, _ = get_sample_compiled_module(target_dir) + + module_file = os.path.join(target_dir, 'mock.tar') + tvmc.compiler.save_module(module_file, graph, lib, params) + + return module_file diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py new file mode 100644 index 000000000000..a3be9ff51a4f --- /dev/null +++ b/tests/python/driver/tvmc/test_common.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import argparse +import os +from os import path + +import pytest + +from tvm.driver import tvmc + + +def test_parse_input_shapes__list_lengths(): + shape_string = "(1,224,224,3)" + sut = tvmc.common.parse_input_shapes(shape_string) + + # output is a list with a list [[1, 224, 224, 3]] + assert type(sut) is list + assert len(sut) == 1 + assert type(sut[0]) is list + assert len(sut[0]) == 4 + + +def test_parse_input_shapes__lists_match(): + shape_string = "(1,224,224,3)" + sut = tvmc.common.parse_input_shapes(shape_string) + + assert sut[0] == [1, 224, 224, 3] + + +def test_parse_input_shapes__spaces_are_ignored(): + shape_string = "(1, 224, 224, 3)" + sut = tvmc.common.parse_input_shapes(shape_string) + + assert type(sut) is list + assert len(sut) == 1 + assert type(sut[0]) is list + assert len(sut[0]) == 4 + + +def test_parse_input_shapes__missing(): + shape_string = "(1,224,,3)" + with pytest.raises(argparse.ArgumentTypeError) as e: + def f(): + _ = tvmc.common.parse_input_shapes(shape_string) + f() + + assert 'expected numbers in shape' in str(e.value) + + +def test_parse_input_shapes_no_brackets(): + shape_string = "1,224,224,3" + with pytest.raises(argparse.ArgumentTypeError) as e: + def f(): + _ = tvmc.common.parse_input_shapes(shape_string) + f() + + assert 'missing brackets around shape' in str(e.value) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py new file mode 100644 index 000000000000..a03ce243d858 --- /dev/null +++ b/tests/python/driver/tvmc/test_compiler.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import os +import shutil +from os import path + +import pytest + +import tvm + +from tvm.driver import tvmc + + +def test_save_dumps(tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("data") + dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} + tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir) + + assert path.exists("{}/{}".format(tmpdir, "fake_module.ll")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.asm")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) + + +# End to end tests for compilation + + +def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_1_quant, + target="llvm", + dump_sources="ll", + alter_layout="NCHW", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
+@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), + reason="cross-compilation toolchain not installed") +def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_1_quant, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_sources="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + + +def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): + graph, lib, params, dumps = tvmc.compiler.compile_model( + keras_resnet50, target="llvm", dump_sources="ll" + ) + + expected_temp_dir = tmpdir_factory.mktemp("saved_output") + expected_file_name = "saved.tar" + module_file = os.path.join(expected_temp_dir, expected_file_name) + tvmc.compiler.save_module(module_file, graph, lib, params) + + assert os.path.exists(module_file), "output file {0} should exist".format(module_file) + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
+@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), + reason="cross-compilation toolchain not installed") +def test_cross_compile_aarch64_keras_module(keras_resnet50): + graph, lib, params, dumps = tvmc.compiler.compile_model( + keras_resnet50, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_sources="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "asm" in dumps.keys() + + +def test_compile_onnx_module(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + graph, lib, params, dumps = tvmc.compiler.compile_model( + onnx_resnet50, target="llvm", dump_sources="ll" + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "ll" in dumps.keys() + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
+@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), + reason="cross-compilation toolchain not installed") +def test_cross_compile_aarch64_onnx_module(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + graph, lib, params, dumps = tvmc.compiler.compile_model( + onnx_resnet50, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_sources="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "asm" in dumps.keys() diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py new file mode 100644 index 000000000000..f0c83c222525 --- /dev/null +++ b/tests/python/driver/tvmc/test_frontends.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +import tarfile + +import pytest + +from tvm.ir.module import IRModule + +from tvm.driver import tvmc +from tvm.driver.tvmc.common import TVMCException + + +def test_get_frontends_is_list(): + sut = tvmc.frontends.get_frontends() + assert type(sut) is list + + +def test_get_frontends_contains_only_strings(): + sut = tvmc.frontends.get_frontends() + assert all([type(x) is str for x in sut]) is True + + +def test_lookup_frontend_valid(): + sut = tvmc.frontends.lookup_frontend("keras") + assert type(sut) is tvmc.frontends.KerasFrontend + + +def test_lookup_frontend_invalid(): + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.lookup_frontend("unsupported_thingy") + f() + assert 'unrecognized frontend' in str(e.value) + + +def test_guess_frontend_tflite(): + sut = tvmc.frontends.guess_input_language("a_model.tflite") + assert type(sut) is tvmc.frontends.TFLiteFrontend + + +def test_guess_frontend_onnx(): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + sut = tvmc.frontends.guess_input_language("a_model.onnx") + assert type(sut) is tvmc.frontends.OnnxFrontend + + +def test_guess_frontend_pytorch(): + # some CI environments wont offer pytorch, so skip in case it is not present + pytest.importorskip('torch') + + sut = tvmc.frontends.guess_input_language("a_model.pth") + assert type(sut) is tvmc.frontends.PyTorchFrontend + + +def test_guess_frontend_keras(): + sut = tvmc.frontends.guess_input_language("a_model.h5") + assert type(sut) is tvmc.frontends.KerasFrontend + + +def test_guess_frontend_tensorflow(): + sut = tvmc.frontends.guess_input_language("a_model.pb") + assert type(sut) is tvmc.frontends.TensorflowFrontend + + +def test_guess_frontend_invalid(): + with pytest.raises(TVMCException): + tvmc.frontends.guess_input_language("not/a/file.txt") + + +def test_load_model__invalid_path__no_language(): + with pytest.raises(FileNotFoundError): + 
tvmc.frontends.load_model("not/a/file.tflite") + + +def test_load_model__invalid_path__with_language(): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + with pytest.raises(FileNotFoundError): + tvmc.frontends.load_model("not/a/file.txt", language="onnx") + + +def test_load_model__tflite(tflite_mobilenet_v1_1_quant): + mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert '_param_1' in params.keys() + + +def test_load_model__keras(keras_resnet50): + mod, params = tvmc.frontends.load_model(keras_resnet50) + assert type(mod) is IRModule + assert type(params) is dict + ## check whether one known value is part of the params dict + assert '_param_1' in params.keys() + + +def test_load_model__onnx(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + mod, params = tvmc.frontends.load_model(onnx_resnet50) + assert type(mod) is IRModule + assert type(params) is dict + ## check whether one known value is part of the params dict + assert 'resnetv24_batchnorm0_gamma' in params.keys() + + +def test_load_model__pb(pb_mobilenet_v1_1_quant): + mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert 'MobilenetV1/Conv2d_0/weights' in params.keys() + + +def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): + with pytest.raises(OSError): + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="keras") + + +def test_load_model___wrong_language__to_tflite(keras_resnet50): + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model(keras_resnet50, language="tflite") + f() + assert 'input file not tflite' in str(e.value) + + 
+def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + from google.protobuf.message import DecodeError + + with pytest.raises(DecodeError): + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="onnx") + + +def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer pytorch, so skip in case it is not present + pytest.importorskip('torch') + + with pytest.raises(RuntimeError) as e: + def f(): + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="pytorch", shapes=(1,1,1,1)) + f() + assert 'PytorchStreamReader' in str(e.value) + + +def test_load_model__pytorch__no_shapes(): + # some CI environments wont offer pytorch, so skip in case it is not present + pytest.importorskip('torch') + + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model("a/fake/path.pth", language="pytorch") + f() + assert '--input-shape must be specified for pytorch' in str(e.value) + + +def test_load_model__keras__with_shapes(): + # sending shapes to Keras should fail (only supported by Torch) + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model("a/fake/path.h5", language="keras", shapes=(1,1,1,1)) + f() + assert '--input-shape is not supported for keras' in str(e.value) + + +def test_load_model__onnx__with_shapes(): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip('onnx') + + # sending shapes to ONNX should fail (only supported by Torch) + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model("a/fake/path.onnx", language="onnx", shapes=(1,1,1,1)) + f() + assert '--input-shape is not supported for onnx' in str(e.value) + + +def test_load_model__tensorflow__with_shapes(): + # sending shapes to TF should fail (only supported by Torch) + with 
pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model("a/fake/path.pb", language="pb", shapes=(1,1,1,1)) + f() + assert '--input-shape is not supported for pb' in str(e.value) + + +def test_load_model__tflite__with_shapes(): + # sending shapes to TFLite should fail (only supported by Torch) + with pytest.raises(TVMCException) as e: + def f(): + tvmc.frontends.load_model("a/fake/path.tflite", language="tflite", shapes=(1,1,1,1)) + f() + assert '--input-shape is not supported for tflite' in str(e.value) diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 35a81e508643..ef86d6917424 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -59,6 +59,9 @@ TVM_FFI=ctypes python3 -m pytest tests/python/contrib TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" TVM_FFI=ctypes python3 -m pytest tests/python/relay +# Command line driver test +TVM_FFI=ctypes python3 -m pytest tests/python/driver + # Do not enable OpenGL # TVM_FFI=cython python -m pytest tests/webgl # TVM_FFI=ctypes python3 -m pytest tests/webgl From 143f7e03f65b26615fe5602798d6f31373ede357 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 28 Aug 2020 16:17:33 +0100 Subject: [PATCH 02/10] tvmc: adjust TODOs --- python/tvm/driver/tvmc/compiler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 1126c26c97c8..8815c2284148 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -211,6 +211,8 @@ def compile_relay( logging.info(f"using target input from file: {target}") target = "".join(target_file.readlines()) + # TODO: We don't have an API to collect a list of supported + # targets yet. 
(@leandron) logging.debug(f"creating target from input: {target}") tvm_target = tvm.target.create(target) target_host = "" @@ -231,8 +233,8 @@ def compile_relay( dumps = {} for source_type in dump_sources: lib = graph_module.get_lib() - # TODO lib.get_source call here have inconsistent behavior for unsupported - # formats. This is an open discussion (@leandron). + # TODO lib.get_source call have inconsistent behavior for unsupported + # formats (@leandron). source = str(mod) if source_type == "relay" else lib.get_source(source_type) dumps[source_type] = source From 5ec65510ec86b6e001d399ad78a15a05ad0464f0 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 28 Aug 2020 16:26:13 +0100 Subject: [PATCH 03/10] tvmc: fix linting errors --- python/tvm/driver/tvmc/compiler.py | 22 +++++++++------------- python/tvm/driver/tvmc/frontends.py | 2 -- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 8815c2284148..dcb15386ff3e 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -17,8 +17,6 @@ """ Provides support to compile networks both AOT and JIT. """ -import argparse -import json import logging import os.path import tarfile @@ -27,10 +25,8 @@ import tvm from tvm import autotvm from tvm import relay -from tvm._ffi.runtime_ctypes import TVMContext from tvm.contrib import cc from tvm.contrib import util -from tvm.relay.op.contrib import get_pattern_table from . import common, frontends from .main import register_parser @@ -208,17 +204,17 @@ def compile_relay( if os.path.exists(str(target)): with open(target) as target_file: - logging.info(f"using target input from file: {target}") + logging.info("using target input from file: %s", target) target = "".join(target_file.readlines()) # TODO: We don't have an API to collect a list of supported # targets yet. 
(@leandron) - logging.debug(f"creating target from input: {target}") + logging.debug("creating target from input: %s", target) tvm_target = tvm.target.create(target) - target_host = "" + target_host = target_host or "" if tuning_records: - logging.debug(f"tuning records file provided: {tuning_records}") + logging.debug("tuning records file provided: %s", tuning_records) with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext(opt_level=3): logging.debug("building relay graph with tuning records") @@ -267,21 +263,21 @@ def save_module(module_path, graph, lib, params, cross=None): temp = util.tempdir() path_lib = temp.relpath(lib_name) if not cross: - logging.debug(f"exporting library to {path_lib}") + logging.debug("exporting library to %s", path_lib) lib.export_library(path_lib) else: - logging.debug(f"exporting library to {path_lib}, using cross compiler {cross}") + logging.debug("exporting library to %s , using cross compiler %s", path_lib, cross) lib.export_library(path_lib, cc.cross_compiler(cross)) with open(temp.relpath(graph_name), "w") as graph_file: - logging.debug(f"writing graph to file to {graph_file.name}") + logging.debug("writing graph to file to %s", graph_file.name) graph_file.write(graph) with open(temp.relpath(param_name), "wb") as params_file: - logging.debug(f"writing params to file to {params_file.name}") + logging.debug("writing params to file to %s", params_file.name) params_file.write(relay.save_param_dict(params)) - logging.debug(f"saving module as tar file to {module_path}") + logging.debug("saving module as tar file to %s", module_path) with tarfile.open(module_path, "w") as tar: tar.add(path_lib, lib_name) tar.add(temp.relpath(graph_name), graph_name) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 635199dbc07d..b335963ded94 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -378,8 +378,6 @@ def load_model(path, language=None, 
shapes=None): The parameters (weights) for the relay module. """ - # pylint: disable=C0415 - import tvm.error if language is not None: frontend = lookup_frontend(language) From 1e6ef0c5029e702d17f49ef8a54635b461e0dfb0 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Mon, 7 Sep 2020 14:18:29 +0100 Subject: [PATCH 04/10] Address code-review comments --- python/tvm/driver/tvmc/common.py | 39 +----- python/tvm/driver/tvmc/compiler.py | 117 +++++++----------- python/tvm/driver/tvmc/frontends.py | 136 +++++++++++++-------- tests/python/driver/tvmc/test_common.py | 107 +++++++++++----- tests/python/driver/tvmc/test_compiler.py | 22 +++- tests/python/driver/tvmc/test_frontends.py | 86 +++---------- 6 files changed, 242 insertions(+), 265 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index fd48c28491f7..c8fe5953f76a 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -17,9 +17,6 @@ """ Common utility functions shared by TVMC modules. """ -import argparse -import re - from tvm import relay from tvm import transform @@ -58,34 +55,10 @@ def convert_graph_layout(mod, desired_layout): relay.transform.ConvertLayout(desired_layouts), ] ) - with transform.PassContext(opt_level=3): - return seq(mod) - -def parse_input_shapes(shapes_str): - """ Parsing function for tensor shape syntax. 
""" - shapes = [] - # Split up string into comma seperated sections ignoring commas in ()s - match = re.findall(r"(\(.*?\)|.+?),?", shapes_str) - if match: - for inp in match: - # Test for and remove brackets - shape = re.match(r"\((.*)\)", inp) - if shape and shape.lastindex == 1: - # Remove white space and extract numbers - strshape = shape[1].replace(" ", "").split(",") - try: - shapes.append([int(i) for i in strshape]) - except ValueError: - raise argparse.ArgumentTypeError( - f"expected numbers in shape '{shape[1]}'" - ) - else: - raise argparse.ArgumentTypeError( - f"missing brackets around shape '{inp}', example '(1,2,3)'" - ) - else: - raise argparse.ArgumentTypeError( - f"unrecognized shapes '{shapes_str}', example '(1,2,3),(1,4),...'" - ) - return shapes + with transform.PassContext(opt_level=3): + try: + return seq(mod) + except Exception as err: + raise TVMCException( + "Error converting layout to {0}: {1}".format(desired_layout, str(err))) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index dcb15386ff3e..e5a3cb3936c9 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -43,6 +43,12 @@ def add_compile_parser(subparsers): default="", help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'", ) + parser.add_argument( + "--desired-layout", + choices=["NCHW", "NHWC"], + default=None, + help="change the data layout of the whole graph", + ) parser.add_argument( "--dump-code", metavar="FORMAT", @@ -51,15 +57,9 @@ def add_compile_parser(subparsers): ) parser.add_argument( "--model-format", - choices=frontends.get_frontends(), + choices=frontends.get_frontend_names(), help="specify input model format", ) - parser.add_argument( - "--input-shape", - type=common.parse_input_shapes, - metavar="INPUT_SHAPE,[INPUT_SHAPE]...", - help="for pytorch, e.g. 
'(1,3,224,224)'", - ) parser.add_argument( "-o", "--output", @@ -77,20 +77,29 @@ def add_compile_parser(subparsers): default="", help="path to an auto-tuning log file from AutoTVM" ) - parser.add_argument( - "--desired-layout", - choices=["NCHW", "NHWC"], - default=None, - help="change the data layout of the whole graph", - ) parser.add_argument( "-v", "--verbose", action="count", default=0, help="increase verbosity" ) - parser.add_argument("FILE") + #TODO (@leandron) This is a path to a physical file, but + # can be improved in future to add integration with a modelzoo + # or URL, for example. + parser.add_argument("FILE", help="path to the input model file") def drive_compile(args): - """ Invoke tvmc.compiler module with command line arguments """ + """ Invoke tvmc.compiler module with command line arguments + + Parameters + ---------- + args: argparse.Namespace + Arguments from command line parser. + + Returns + -------- + int + Zero if successfully completed + + """ graph, lib, params, dumps = compile_model( args.FILE, @@ -98,7 +107,6 @@ def drive_compile(args): args.dump_code, "", args.model_format, - args.input_shape, args.tuning_records, args.tensor_layout, ) @@ -113,10 +121,9 @@ def drive_compile(args): def compile_model( path, target, - dump_sources=None, + dump_code=None, target_host=None, model_format=None, - shapes=None, tuning_records=None, alter_layout=None, ): @@ -126,58 +133,21 @@ def compile_model( and compiler.compile_relay. The resulting TVM module can be executed using the graph runtime. - Returns - ------- - graph : str - A JSON-serialized TVM execution graph. - lib : tvm.module.Module - A TVM module containing the compiled functions. - params : dict - The parameters (weights) for the TVM module. - dumps : dict - Dictionary containing the dumps specified. 
- - """ - dump_sources = [x.strip() for x in dump_sources.split(',')] if dump_sources else None - mod, params = frontends.load_model(path, model_format, shapes) - - return compile_relay( - mod, - params, - target, - dump_sources=dump_sources, - target_host=target_host, - tuning_records=tuning_records, - alter_layout=alter_layout, - ) - - -def compile_relay( - mod, - params, - target, - dump_sources=None, - target_host=None, - tuning_records=None, - alter_layout=None, -): - """Compile a relay module to a TVM module for the graph runtime. - Parameters ---------- - mod : tvm.relay.Module - The relay module to compile. - params : dict - The parameters (weights) for the relay module. + path: str + Path to a file target : str The target for which to compile. Can be a plain string or a path. - dump_sources : list, optional + dump_code : list, optional Dump the generated code for the specified source types, on the requested target. - target_host : Union[str, tvm.target.Target], optional + target_host : str, optional The target of the host machine if host-side code needs to be generated. + model_format: str, optional + A string representing a name of a frontend to be used tuning_records: str, optional Name of the file produced by the tuning to be used during compilation. @@ -198,6 +168,8 @@ def compile_relay( Dictionary containing the dumps specified. """ + dump_code = [x.strip() for x in dump_code.split(',')] if dump_code else None + mod, params = frontends.load_model(path, model_format) if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) @@ -207,13 +179,15 @@ def compile_relay( logging.info("using target input from file: %s", target) target = "".join(target_file.readlines()) - # TODO: We don't have an API to collect a list of supported - # targets yet. 
(@leandron) + # TODO(@leandron) We don't have an API to collect a list of supported + # targets yet logging.debug("creating target from input: %s", target) tvm_target = tvm.target.create(target) target_host = target_host or "" if tuning_records: + # TODO (@leandron) a new PR will introduce the 'tune' subcommand + # the is used to generate the tuning records file logging.debug("tuning records file provided: %s", tuning_records) with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext(opt_level=3): @@ -225,9 +199,9 @@ def compile_relay( graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) # Generate output dump files with sources - dump_sources = dump_sources or [] + dump_code = dump_code or [] dumps = {} - for source_type in dump_sources: + for source_type in dump_code: lib = graph_module.get_lib() # TODO lib.get_source call have inconsistent behavior for unsupported # formats (@leandron). @@ -253,7 +227,7 @@ def save_module(module_path, graph, lib, params, cross=None): A TVM module containing the compiled functions. params : dict The parameters (weights) for the TVM module. 
- cross : Union[str, Callable[[str, str, Optional[str]], None]] + cross : str or callable object, optional Function that performs the actual compilation """ @@ -290,13 +264,14 @@ def save_dumps(module_name, dumps, dump_root="."): Parameters ---------- - module_name : list(Union[str, tvm.target.Target]) - file name, referring to the module that generated + module_name : str + File name, referring to the module that generated the dump contents dumps : dict - the output contents to be saved into the files - dump_root : str - path in which dump files will be created + The output contents to be saved into the files + dump_root : str, optional + Path in which dump files will be created + """ for dump_format in dumps: diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index b335963ded94..6d9842cf9341 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -27,6 +27,9 @@ from abc import abstractmethod from pathlib import Path +import numpy as np + +from tvm import relay from tvm.driver.tvmc.common import TVMCException @@ -44,8 +47,22 @@ def suffixes(): """File suffixes (extensions) used by this frontend""" @abstractmethod - def load(self, path, shapes): - """Load network""" + def load(self, path): + """Load a model from a given path. + + Parameters + ---------- + path: str + Path to a file + + Returns + ------- + mod : tvm.relay.Module + The produced relay module. + params : dict + The parameters (weights) for the relay module. 
+ + """ def import_keras(): @@ -75,19 +92,10 @@ def name(): def suffixes(): return ["h5"] - def load(self, path, shapes): - # pylint: disable=C0415 - import numpy as np - from tvm import relay - + def load(self, path): # pylint: disable=C0103 tf, keras = import_keras() - if shapes: - raise TVMCException( - "--input-shape is not supported for {}".format(self.name()) - ) - # tvm build currently imports keras directly instead of tensorflow.keras try: model = keras.models.load_model(path) @@ -146,20 +154,12 @@ def name(): def suffixes(): return ["onnx"] - def load(self, path, shapes): + def load(self, path): # pylint: disable=C0415 import onnx - from tvm import relay - - if shapes: - raise TVMCException( - "--input-shape is not supported for {}".format(self.name()) - ) model = onnx.load(path) - # Find the name and shape of the first input in the graph - # pylint: disable=E1101 name = model.graph.input[0].name @@ -183,17 +183,11 @@ def name(): def suffixes(): return ["pb"] - def load(self, path, shapes): + def load(self, path): # pylint: disable=C0415 - from tvm import relay import tensorflow as tf import tvm.relay.testing.tf as tf_testing - if shapes: - raise TVMCException( - "--input-shape is not supported for {}".format(self.name()) - ) - with tf.io.gfile.GFile(path, "rb") as tf_graph: content = tf_graph.read() @@ -229,15 +223,9 @@ def name(): def suffixes(): return ["tflite"] - def load(self, path, shapes): + def load(self, path): # pylint: disable=C0415 import tflite.Model as model - from tvm import relay - - if shapes: - raise TVMCException( - "--input-shape is not supported for {}".format(self.name()) - ) with open(path, "rb") as tf_graph: content = tf_graph.read() @@ -306,17 +294,15 @@ def suffixes(): # Torch Script is a zip file, but can be named pth return ["pth", "zip"] - def load(self, path, shapes): + def load(self, path): # pylint: disable=C0415 import torch - from tvm import relay - - if not shapes: - raise TVMCException( - "--input-shape must be 
specified for {}".format(self.name()) - ) traced_model = torch.jit.load(path) + + inputs = list(traced_model.graph.inputs())[1:] + input_shapes = [inp.type().sizes() for inp in inputs] + traced_model.eval() # Switch to inference mode input_shapes = [ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes) @@ -334,19 +320,61 @@ def load(self, path, shapes): ] -def get_frontends(): - """Return the names of all supported frontends""" +def get_frontend_names(): + """Return the names of all supported frontends + + Returns + ------- + list : list of str + A list of frontend names as strings + + """ return [frontend.name() for frontend in ALL_FRONTENDS] -def lookup_frontend(name): +def get_frontend_by_name(name): + """ + This function will try to get a frontend instance, based + on the name provided. + + Parameters + ---------- + name : str + the name of a given frontend + + Returns + ------- + frontend : tvm.driver.tvmc.Frontend + An instance of the frontend that matches with + the file extension provided in `path`. + + """ + for frontend in ALL_FRONTENDS: if name == frontend.name(): return frontend() + raise TVMCException("unrecognized frontend") -def guess_input_language(path): +def guess_frontend(path): + """ + This function will try to imply which framework is being used, + based on the extension of the file provided in the path parameter. + + Parameters + ---------- + path : str + The path to the model file. + + Returns + ------- + frontend : tvm.driver.tvmc.Frontend + An instance of the frontend that matches with + the file extension provided in `path`. 
+ + """ + suffix = Path(path).suffix.lower() if suffix.startswith("."): suffix = suffix[1:] @@ -355,10 +383,10 @@ def guess_input_language(path): if suffix in frontend.suffixes(): return frontend() - raise TVMCException("cannot guess input language") + raise TVMCException("cannot guess model format") -def load_model(path, language=None, shapes=None): +def load_model(path, model_format=None): """Load a model from a supported framework and convert it into an equivalent relay representation. @@ -366,8 +394,8 @@ def load_model(path, language=None, shapes=None): ---------- path : str The path to the model file. - language : str, optional - The language of the model file. + model_format : str, optional + The underlying framework used to create the model. If not specified, this will be inferred from the file type. Returns @@ -379,11 +407,11 @@ def load_model(path, language=None, shapes=None): """ - if language is not None: - frontend = lookup_frontend(language) + if model_format is not None: + frontend = get_frontend_by_name(model_format) else: - frontend = guess_input_language(path) + frontend = guess_frontend(path) - mod, params = frontend.load(path, shapes) + mod, params = frontend.load(path) return mod, params diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index a3be9ff51a4f..2a80167d6592 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -20,52 +20,93 @@ import pytest +import tvm from tvm.driver import tvmc -def test_parse_input_shapes__list_lengths(): - shape_string = "(1,224,224,3)" - sut = tvmc.common.parse_input_shapes(shape_string) +def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') - # output is a list with a list [[1, 224, 224, 3]] - assert type(sut) is list - assert len(sut) == 1 - assert type(sut[0]) is list - assert len(sut[0]) 
== 4 + before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + expected_layout="NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) -def test_parse_input_shapes__lists_match(): - shape_string = "(1,224,224,3)" - sut = tvmc.common.parse_input_shapes(shape_string) + layout_transform_calls = [] + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" \ + and node.attrs.src_layout == 'NHWC' \ + and node.attrs.dst_layout == 'NCHW') - assert sut[0] == [1, 224, 224, 3] + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" -def test_parse_input_shapes__spaces_are_ignored(): - shape_string = "(1, 224, 224, 3)" - sut = tvmc.common.parse_input_shapes(shape_string) - assert type(sut) is list - assert len(sut) == 1 - assert type(sut[0]) is list - assert len(sut[0]) == 4 +def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip('onnx') + before, _ = tvmc.frontends.load_model(onnx_resnet50) -def test_parse_input_shapes__missing(): - shape_string = "(1,224,,3)" - with pytest.raises(argparse.ArgumentTypeError) as e: - def f(): - _ = tvmc.common.parse_input_shapes(shape_string) - f() + expected_layout="NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) - assert 'expected numbers in shape' in str(e.value) + layout_transform_calls = [] + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" \ + and node.attrs.src_layout == 'NCHW' \ + and node.attrs.dst_layout == 'NHWC') + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) -def test_parse_input_shapes_no_brackets(): - shape_string = "1,224,224,3" - with 
pytest.raises(argparse.ArgumentTypeError) as e: - def f(): - _ = tvmc.common.parse_input_shapes(shape_string) - f() + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - assert 'missing brackets around shape' in str(e.value) + +def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') + + before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + + expected_layout="NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" \ + and node.attrs.src_layout == 'NHWC' \ + and node.attrs.dst_layout == 'NHWC') + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip('onnx') + + before, _ = tvmc.frontends.load_model(onnx_resnet50) + + expected_layout="NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" \ + and node.attrs.src_layout == 'NCHW' \ + and node.attrs.dst_layout == 'NCHW') + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" \ No newline at end of file diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index a03ce243d858..d9652d1a482a 100644 --- 
a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -40,10 +40,12 @@ def test_save_dumps(tmpdir_factory): def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): + pytest.importorskip('tflite') + graph, lib, params, dumps = tvmc.compiler.compile_model( tflite_mobilenet_v1_1_quant, target="llvm", - dump_sources="ll", + dump_code="ll", alter_layout="NCHW", ) @@ -58,10 +60,12 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): @pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed") def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): + pytest.importorskip('tflite') + graph, lib, params, dumps = tvmc.compiler.compile_model( tflite_mobilenet_v1_1_quant, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", - dump_sources="asm", + dump_code="asm", ) # check for output types @@ -72,8 +76,11 @@ def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): + # some CI environments wont offer tensorflow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + graph, lib, params, dumps = tvmc.compiler.compile_model( - keras_resnet50, target="llvm", dump_sources="ll" + keras_resnet50, target="llvm", dump_code="ll" ) expected_temp_dir = tmpdir_factory.mktemp("saved_output") @@ -88,10 +95,13 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): @pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed") def test_cross_compile_aarch64_keras_module(keras_resnet50): + # some CI environments wont offer tensorflow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + graph, lib, params, dumps = tvmc.compiler.compile_model( keras_resnet50, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", - dump_sources="asm", 
+ dump_code="asm", ) # check for output types @@ -107,7 +117,7 @@ def test_compile_onnx_module(onnx_resnet50): pytest.importorskip('onnx') graph, lib, params, dumps = tvmc.compiler.compile_model( - onnx_resnet50, target="llvm", dump_sources="ll" + onnx_resnet50, target="llvm", dump_code="ll" ) # check for output types @@ -128,7 +138,7 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): graph, lib, params, dumps = tvmc.compiler.compile_model( onnx_resnet50, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", - dump_sources="asm", + dump_code="asm", ) # check for output types diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index f0c83c222525..c9b160135dd1 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -25,31 +25,31 @@ from tvm.driver.tvmc.common import TVMCException -def test_get_frontends_is_list(): - sut = tvmc.frontends.get_frontends() +def test_get_frontend_names_is_list(): + sut = tvmc.frontends.get_frontend_names() assert type(sut) is list def test_get_frontends_contains_only_strings(): - sut = tvmc.frontends.get_frontends() + sut = tvmc.frontends.get_frontend_names() assert all([type(x) is str for x in sut]) is True -def test_lookup_frontend_valid(): - sut = tvmc.frontends.lookup_frontend("keras") +def test_get_frontend_by_name_valid(): + sut = tvmc.frontends.get_frontend_by_name("keras") assert type(sut) is tvmc.frontends.KerasFrontend -def test_lookup_frontend_invalid(): +def test_get_frontend_by_name_invalid(): with pytest.raises(TVMCException) as e: def f(): - tvmc.frontends.lookup_frontend("unsupported_thingy") + tvmc.frontends.get_frontend_by_name("unsupported_thing") f() assert 'unrecognized frontend' in str(e.value) def test_guess_frontend_tflite(): - sut = tvmc.frontends.guess_input_language("a_model.tflite") + sut = tvmc.frontends.guess_frontend("a_model.tflite") assert type(sut) is 
tvmc.frontends.TFLiteFrontend @@ -57,7 +57,7 @@ def test_guess_frontend_onnx(): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip('onnx') - sut = tvmc.frontends.guess_input_language("a_model.onnx") + sut = tvmc.frontends.guess_frontend("a_model.onnx") assert type(sut) is tvmc.frontends.OnnxFrontend @@ -65,23 +65,23 @@ def test_guess_frontend_pytorch(): # some CI environments wont offer pytorch, so skip in case it is not present pytest.importorskip('torch') - sut = tvmc.frontends.guess_input_language("a_model.pth") + sut = tvmc.frontends.guess_frontend("a_model.pth") assert type(sut) is tvmc.frontends.PyTorchFrontend def test_guess_frontend_keras(): - sut = tvmc.frontends.guess_input_language("a_model.h5") + sut = tvmc.frontends.guess_frontend("a_model.h5") assert type(sut) is tvmc.frontends.KerasFrontend def test_guess_frontend_tensorflow(): - sut = tvmc.frontends.guess_input_language("a_model.pb") + sut = tvmc.frontends.guess_frontend("a_model.pb") assert type(sut) is tvmc.frontends.TensorflowFrontend def test_guess_frontend_invalid(): with pytest.raises(TVMCException): - tvmc.frontends.guess_input_language("not/a/file.txt") + tvmc.frontends.guess_frontend("not/a/file.txt") def test_load_model__invalid_path__no_language(): @@ -94,7 +94,7 @@ def test_load_model__invalid_path__with_language(): pytest.importorskip('onnx') with pytest.raises(FileNotFoundError): - tvmc.frontends.load_model("not/a/file.txt", language="onnx") + tvmc.frontends.load_model("not/a/file.txt", model_format="onnx") def test_load_model__tflite(tflite_mobilenet_v1_1_quant): @@ -134,13 +134,13 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant): def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): with pytest.raises(OSError): - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="keras") + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras") def 
test_load_model___wrong_language__to_tflite(keras_resnet50): with pytest.raises(TVMCException) as e: def f(): - tvmc.frontends.load_model(keras_resnet50, language="tflite") + tvmc.frontends.load_model(keras_resnet50, model_format="tflite") f() assert 'input file not tflite' in str(e.value) @@ -152,7 +152,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): from google.protobuf.message import DecodeError with pytest.raises(DecodeError): - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="onnx") + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx") def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): @@ -161,56 +161,6 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): with pytest.raises(RuntimeError) as e: def f(): - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, language="pytorch", shapes=(1,1,1,1)) + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") f() assert 'PytorchStreamReader' in str(e.value) - - -def test_load_model__pytorch__no_shapes(): - # some CI environments wont offer pytorch, so skip in case it is not present - pytest.importorskip('torch') - - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model("a/fake/path.pth", language="pytorch") - f() - assert '--input-shape must be specified for pytorch' in str(e.value) - - -def test_load_model__keras__with_shapes(): - # sending shapes to Keras should fail (only supported by Torch) - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model("a/fake/path.h5", language="keras", shapes=(1,1,1,1)) - f() - assert '--input-shape is not supported for keras' in str(e.value) - - -def test_load_model__onnx__with_shapes(): - # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') - - # sending shapes to ONNX should fail (only supported by Torch) - with 
pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model("a/fake/path.onnx", language="onnx", shapes=(1,1,1,1)) - f() - assert '--input-shape is not supported for onnx' in str(e.value) - - -def test_load_model__tensorflow__with_shapes(): - # sending shapes to TF should fail (only supported by Torch) - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model("a/fake/path.pb", language="pb", shapes=(1,1,1,1)) - f() - assert '--input-shape is not supported for pb' in str(e.value) - - -def test_load_model__tflite__with_shapes(): - # sending shapes to TFLite should fail (only supported by Torch) - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model("a/fake/path.tflite", language="tflite", shapes=(1,1,1,1)) - f() - assert '--input-shape is not supported for tflite' in str(e.value) From 9db5680ec3a3e09c6c0d54c007f749a04f8eaa28 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Mon, 7 Sep 2020 16:21:53 +0100 Subject: [PATCH 05/10] Adjust pytest fixture to not break when there is no tensorflow --- tests/python/driver/tvmc/conftest.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 4efece55011f..52b4d5adbf6d 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -20,8 +20,6 @@ import tvm.driver.tvmc.compiler -from tensorflow.keras.applications.resnet50 import ResNet50 - from tvm.contrib.download import download_testdata from tvm.driver.tvmc.common import convert_graph_layout @@ -78,6 +76,13 @@ def pb_mobilenet_v1_1_quant(tmpdir_factory): @pytest.fixture(scope="session") def keras_resnet50(tmpdir_factory): + try: + from tensorflow.keras.applications.resnet50 import ResNet50 + except ImportError: + # not all environments provide TensorFlow, so skip this fixture + # if that is that case. 
+ return "" + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet50.h5") model = ResNet50(include_top=True, weights='imagenet', input_shape=(224, 224, 3), classes=1000) model.save(model_file_name) From c217622be4e071329dc492103ac6549fb4bf4d7d Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Mon, 7 Sep 2020 19:06:34 +0100 Subject: [PATCH 06/10] Fix frontend tests, to cope with different frameworks in different images --- tests/python/driver/tvmc/test_frontends.py | 30 ++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index c9b160135dd1..60b975f82fb7 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -36,6 +36,9 @@ def test_get_frontends_contains_only_strings(): def test_get_frontend_by_name_valid(): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + sut = tvmc.frontends.get_frontend_by_name("keras") assert type(sut) is tvmc.frontends.KerasFrontend @@ -49,6 +52,9 @@ def f(): def test_guess_frontend_tflite(): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') + sut = tvmc.frontends.guess_frontend("a_model.tflite") assert type(sut) is tvmc.frontends.TFLiteFrontend @@ -70,11 +76,17 @@ def test_guess_frontend_pytorch(): def test_guess_frontend_keras(): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + sut = tvmc.frontends.guess_frontend("a_model.h5") assert type(sut) is tvmc.frontends.KerasFrontend def test_guess_frontend_tensorflow(): + # some CI environments wont offer TensorFlow, so skip in case it is not present + pytest.importorskip('tensorflow') + sut = tvmc.frontends.guess_frontend("a_model.pb") assert type(sut) is tvmc.frontends.TensorflowFrontend @@ -85,6 +97,9 @@ def 
test_guess_frontend_invalid(): def test_load_model__invalid_path__no_language(): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') + with pytest.raises(FileNotFoundError): tvmc.frontends.load_model("not/a/file.tflite") @@ -98,6 +113,9 @@ def test_load_model__invalid_path__with_language(): def test_load_model__tflite(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') + mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) assert type(mod) is IRModule assert type(params) is dict @@ -106,6 +124,9 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant): def test_load_model__keras(keras_resnet50): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + mod, params = tvmc.frontends.load_model(keras_resnet50) assert type(mod) is IRModule assert type(params) is dict @@ -125,6 +146,9 @@ def test_load_model__onnx(onnx_resnet50): def test_load_model__pb(pb_mobilenet_v1_1_quant): + # some CI environments wont offer TensorFlow, so skip in case it is not present + pytest.importorskip('tensorflow') + mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant) assert type(mod) is IRModule assert type(params) is dict @@ -133,11 +157,17 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant): def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip('tensorflow') + with pytest.raises(OSError): tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras") def test_load_model___wrong_language__to_tflite(keras_resnet50): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip('tflite') + with pytest.raises(TVMCException) as e: def f(): 
tvmc.frontends.load_model(keras_resnet50, model_format="tflite") From 98875507be3e486d691c99ea3a729f012340c86e Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 11 Sep 2020 09:15:49 +0100 Subject: [PATCH 07/10] Apply suggestions from code review Co-authored-by: Cody Yu --- python/tvm/driver/tvmc/compiler.py | 6 +++--- python/tvm/driver/tvmc/frontends.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index e5a3cb3936c9..ca534a8d82cc 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -75,7 +75,7 @@ def add_compile_parser(subparsers): "--tuning-records", metavar="PATH", default="", - help="path to an auto-tuning log file from AutoTVM" + help="path to an auto-tuning log file by AutoTVM. If not presented, the fallback/tophub configs will be used" ) parser.add_argument( "-v", "--verbose", action="count", default=0, help="increase verbosity" @@ -105,7 +105,7 @@ def drive_compile(args): args.FILE, args.target, args.dump_code, - "", + None, args.model_format, args.tuning_records, args.tensor_layout, @@ -182,7 +182,7 @@ def compile_model( # TODO(@leandron) We don't have an API to collect a list of supported # targets yet logging.debug("creating target from input: %s", target) - tvm_target = tvm.target.create(target) + tvm_target = tvm.target.Target(target) target_host = target_host or "" if tuning_records: diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 6d9842cf9341..3b8b811ae268 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -383,7 +383,7 @@ def guess_frontend(path): if suffix in frontend.suffixes(): return frontend() - raise TVMCException("cannot guess model format") + raise TVMCException("failed to infer the model format. 
Please specify --model-format") def load_model(path, model_format=None): From 7fd5afa55d53df119f552ba881ba1c62d13a36e8 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 11 Sep 2020 11:04:10 +0100 Subject: [PATCH 08/10] Fix lint and code-review issues --- python/setup.py | 2 +- python/tvm/driver/tvmc/compiler.py | 12 ++++++----- python/tvm/driver/tvmc/frontends.py | 8 +++++-- tests/python/driver/tvmc/test_frontends.py | 25 +++++----------------- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/python/setup.py b/python/setup.py index 4065b1a239cf..c19b9363059a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -164,7 +164,7 @@ def get_package_data_files(): 'typed_ast', 'tensorflow==2.1.0', 'tflite==2.1.0', - 'onnx==1.6.0', + 'onnx==1.7.0', 'onnxruntime==1.0.0', 'torch==1.4.0', 'torchvision==0.5.0' diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index ca534a8d82cc..29b3f649a7ad 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -75,7 +75,8 @@ def add_compile_parser(subparsers): "--tuning-records", metavar="PATH", default="", - help="path to an auto-tuning log file by AutoTVM. If not presented, the fallback/tophub configs will be used" + help="path to an auto-tuning log file by AutoTVM. If not presented, " \ + "the fallback/tophub configs will be used" ) parser.add_argument( "-v", "--verbose", action="count", default=0, help="increase verbosity" @@ -149,7 +150,7 @@ def compile_model( model_format: str, optional A string representing a name of a frontend to be used tuning_records: str, optional - Name of the file produced by the tuning to be used during + Path to the file produced by the tuning to be used during compilation. alter_layout: str, optional The layout to convert the graph to. 
Note, the convert layout @@ -174,7 +175,8 @@ def compile_model( if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) - if os.path.exists(str(target)): + # Handle the case in which target is a path to a JSON file. + if os.path.exists(target): with open(target) as target_file: logging.info("using target input from file: %s", target) target = "".join(target_file.readlines()) @@ -182,10 +184,10 @@ def compile_model( # TODO(@leandron) We don't have an API to collect a list of supported # targets yet logging.debug("creating target from input: %s", target) - tvm_target = tvm.target.Target(target) + tvm_target = tvm.target.create(target) target_host = target_host or "" - if tuning_records: + if tuning_records and os.path.exists(tuning_records): # TODO (@leandron) a new PR will introduce the 'tune' subcommand # the is used to generate the tuning records file logging.debug("tuning records file provided: %s", tuning_records) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 3b8b811ae268..d891bca72b83 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -34,7 +34,10 @@ class Frontend(ABC): - """Abstract class for frontend""" + """Abstract class for command line driver frontend. + + Provide a unified way to import models (as files), and deal + with any required preprocessing to create a TVM module from it.""" @staticmethod @abstractmethod @@ -354,7 +357,8 @@ def get_frontend_by_name(name): if name == frontend.name(): return frontend() - raise TVMCException("unrecognized frontend") + raise TVMCException( + "unrecognized frontend '{0}'. 
Choose from: {1}".format(name, get_frontend_names())) def guess_frontend(path): diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 60b975f82fb7..8604d7594316 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -25,11 +25,6 @@ from tvm.driver.tvmc.common import TVMCException -def test_get_frontend_names_is_list(): - sut = tvmc.frontends.get_frontend_names() - assert type(sut) is list - - def test_get_frontends_contains_only_strings(): sut = tvmc.frontends.get_frontend_names() assert all([type(x) is str for x in sut]) is True @@ -44,12 +39,8 @@ def test_get_frontend_by_name_valid(): def test_get_frontend_by_name_invalid(): - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.get_frontend_by_name("unsupported_thing") - f() - assert 'unrecognized frontend' in str(e.value) - + with pytest.raises(TVMCException): + tvmc.frontends.get_frontend_by_name("unsupported_thing") def test_guess_frontend_tflite(): # some CI environments wont offer TFLite, so skip in case it is not present @@ -168,11 +159,8 @@ def test_load_model___wrong_language__to_tflite(keras_resnet50): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip('tflite') - with pytest.raises(TVMCException) as e: - def f(): - tvmc.frontends.load_model(keras_resnet50, model_format="tflite") - f() - assert 'input file not tflite' in str(e.value) + with pytest.raises(TVMCException): + tvmc.frontends.load_model(keras_resnet50, model_format="tflite") def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): @@ -190,7 +178,4 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): pytest.importorskip('torch') with pytest.raises(RuntimeError) as e: - def f(): - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") - f() - assert 'PytorchStreamReader' in str(e.value) + 
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") From dd2af42ab72ead11bfc52317aa3db07cf9233ab9 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Tue, 15 Sep 2020 09:23:43 +0100 Subject: [PATCH 09/10] Re-format with black. --- python/setup.py | 72 +++++++++++----------- python/tvm/driver/tvmc/common.py | 3 +- python/tvm/driver/tvmc/compiler.py | 34 +++++----- python/tvm/driver/tvmc/frontends.py | 20 ++---- tests/python/driver/tvmc/conftest.py | 29 ++++++--- tests/python/driver/tvmc/test_common.py | 50 ++++++++------- tests/python/driver/tvmc/test_compiler.py | 27 ++++---- tests/python/driver/tvmc/test_frontends.py | 41 ++++++------ 8 files changed, 143 insertions(+), 133 deletions(-) diff --git a/python/setup.py b/python/setup.py index c19b9363059a..5be2ca311fa5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -147,43 +147,41 @@ def is_pure(self): def get_package_data_files(): # Relay standard libraries - return ['relay/std/prelude.rly', 'relay/std/core.rly'] - - -setup(name='tvm', - version=__version__, - description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems", - zip_safe=False, - entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, - install_requires=[ - 'numpy', - 'scipy', - 'decorator', - 'attrs', - 'psutil', - 'typed_ast', - 'tensorflow==2.1.0', - 'tflite==2.1.0', - 'onnx==1.7.0', - 'onnxruntime==1.0.0', - 'torch==1.4.0', - 'torchvision==0.5.0' - ], - extras_require={'test': ['pillow<7', - 'matplotlib'], - 'extra_feature': ['tornado', - 'psutil', - 'xgboost>=1.1.0', - 'mypy', - 'orderedset']}, - - packages=find_packages(), - package_dir={'tvm': 'tvm'}, - package_data={'tvm': get_package_data_files()}, - distclass=BinaryDistribution, - url='https://github.com/apache/incubator-tvm', - ext_modules=config_cython(), - **setup_kwargs) + return ["relay/std/prelude.rly", "relay/std/core.rly"] + + +setup( + name="tvm", + version=__version__, + description="TVM: An End to End Tensor IR/DSL 
Stack for Deep Learning Systems", + zip_safe=False, + entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, + install_requires=[ + "numpy", + "scipy", + "decorator", + "attrs", + "psutil", + "typed_ast", + "tensorflow==2.1.0", + "tflite==2.1.0", + "onnx==1.7.0", + "onnxruntime==1.0.0", + "torch==1.4.0", + "torchvision==0.5.0", + ], + extras_require={ + "test": ["pillow<7", "matplotlib"], + "extra_feature": ["tornado", "psutil", "xgboost>=1.1.0", "mypy", "orderedset"], + }, + packages=find_packages(), + package_dir={"tvm": "tvm"}, + package_data={"tvm": get_package_data_files()}, + distclass=BinaryDistribution, + url="https://github.com/apache/incubator-tvm", + ext_modules=config_cython(), + **setup_kwargs, +) if wheel_include_libs: diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index c8fe5953f76a..f389c81f0337 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -61,4 +61,5 @@ def convert_graph_layout(mod, desired_layout): return seq(mod) except Exception as err: raise TVMCException( - "Error converting layout to {0}: {1}".format(desired_layout, str(err))) + "Error converting layout to {0}: {1}".format(desired_layout, str(err)) + ) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 29b3f649a7ad..77703b2d06e1 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -53,7 +53,7 @@ def add_compile_parser(subparsers): "--dump-code", metavar="FORMAT", default="", - help="comma separarated list of formats to export, e.g. 'asm,ll,relay' " + help="comma separarated list of formats to export, e.g. 
'asm,ll,relay' ", ) parser.add_argument( "--model-format", @@ -69,26 +69,24 @@ def add_compile_parser(subparsers): parser.add_argument( "--target", help="compilation target as plain string, inline JSON or path to a JSON file", - required=True + required=True, ) parser.add_argument( "--tuning-records", metavar="PATH", default="", - help="path to an auto-tuning log file by AutoTVM. If not presented, " \ - "the fallback/tophub configs will be used" + help="path to an auto-tuning log file by AutoTVM. If not presented, " + "the fallback/tophub configs will be used", ) - parser.add_argument( - "-v", "--verbose", action="count", default=0, help="increase verbosity" - ) - #TODO (@leandron) This is a path to a physical file, but + parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") + # TODO (@leandron) This is a path to a physical file, but # can be improved in future to add integration with a modelzoo # or URL, for example. parser.add_argument("FILE", help="path to the input model file") def drive_compile(args): - """ Invoke tvmc.compiler module with command line arguments + """Invoke tvmc.compiler module with command line arguments Parameters ---------- @@ -120,13 +118,13 @@ def drive_compile(args): def compile_model( - path, - target, - dump_code=None, - target_host=None, - model_format=None, - tuning_records=None, - alter_layout=None, + path, + target, + dump_code=None, + target_host=None, + model_format=None, + tuning_records=None, + alter_layout=None, ): """Compile a model from a supported framework into a TVM module. @@ -169,7 +167,7 @@ def compile_model( Dictionary containing the dumps specified. 
""" - dump_code = [x.strip() for x in dump_code.split(',')] if dump_code else None + dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None mod, params = frontends.load_model(path, model_format) if alter_layout: @@ -184,7 +182,7 @@ def compile_model( # TODO(@leandron) We don't have an API to collect a list of supported # targets yet logging.debug("creating target from input: %s", target) - tvm_target = tvm.target.create(target) + tvm_target = tvm.target.Target(target) target_host = target_host or "" if tuning_records and os.path.exists(tuning_records): diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index d891bca72b83..6275f779f778 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -114,20 +114,13 @@ def load(self, path): in_shapes = [] for layer in model._input_layers: if tf.executing_eagerly(): - in_shapes.append( - tuple(dim if dim is not None else 1 for dim in layer.input.shape) - ) + in_shapes.append(tuple(dim if dim is not None else 1 for dim in layer.input.shape)) else: in_shapes.append( - tuple( - dim.value if dim.value is not None else 1 - for dim in layer.input.shape - ) + tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape) ) - inputs = [ - np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes - ] + inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)} return relay.frontend.from_keras(model, shape_dict, layout="NHWC") @@ -307,9 +300,7 @@ def load(self, path): input_shapes = [inp.type().sizes() for inp in inputs] traced_model.eval() # Switch to inference mode - input_shapes = [ - ("input{}".format(idx), shape) for idx, shape in enumerate(shapes) - ] + input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] logging.debug("relay.frontend.from_pytorch") return 
relay.frontend.from_pytorch(traced_model, input_shapes) @@ -358,7 +349,8 @@ def get_frontend_by_name(name): return frontend() raise TVMCException( - "unrecognized frontend '{0}'. Choose from: {1}".format(name, get_frontend_names())) + "unrecognized frontend '{0}'. Choose from: {1}".format(name, get_frontend_names()) + ) def guess_frontend(path): diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 52b4d5adbf6d..ee67cc904aac 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -26,9 +26,10 @@ # Support functions + def download_and_untar(model_url, model_sub_path, temp_dir): model_tar_name = os.path.basename(model_url) - model_path = download_testdata(model_url, model_tar_name, module=['tvmc']) + model_path = download_testdata(model_url, model_tar_name, module=["tvmc"]) if model_path.endswith("tgz") or model_path.endswith("gz"): tar = tarfile.open(model_path) @@ -42,9 +43,11 @@ def get_sample_compiled_module(target_dir): """Support function that retuns a TFLite compiled module""" base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" - model_file = download_and_untar("{}/{}".format(base_url, model_url), + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), "mobilenet_v1_1.0_224_quant.tflite", - temp_dir=target_dir) + temp_dir=target_dir, + ) return tvmc.compiler.compile_model(model_file, targets=["llvm"]) @@ -56,9 +59,11 @@ def get_sample_compiled_module(target_dir): def tflite_mobilenet_v1_1_quant(tmpdir_factory): base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" - model_file = download_and_untar("{}/{}".format(base_url, model_url), + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), "mobilenet_v1_1.0_224_quant.tflite", - temp_dir=tmpdir_factory.mktemp("data")) 
+ temp_dir=tmpdir_factory.mktemp("data"), + ) return model_file @@ -67,9 +72,11 @@ def tflite_mobilenet_v1_1_quant(tmpdir_factory): def pb_mobilenet_v1_1_quant(tmpdir_factory): base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" - model_file = download_and_untar("{}/{}".format(base_url, model_url), + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), "mobilenet_v1_1.0_224_frozen.pb", - temp_dir=tmpdir_factory.mktemp("data")) + temp_dir=tmpdir_factory.mktemp("data"), + ) return model_file @@ -84,7 +91,7 @@ def keras_resnet50(tmpdir_factory): return "" model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet50.h5") - model = ResNet50(include_top=True, weights='imagenet', input_shape=(224, 224, 3), classes=1000) + model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000) model.save(model_file_name) return model_file_name @@ -94,7 +101,9 @@ def keras_resnet50(tmpdir_factory): def onnx_resnet50(): base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" file_to_download = "resnet50-v2-7.onnx" - model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_to_download, module=['tvmc']) + model_file = download_testdata( + "{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"] + ) return model_file @@ -104,7 +113,7 @@ def tflite_compiled_module_as_tarfile(tmpdir_factory): target_dir = tmpdir_factory.mktemp("data") graph, lib, params, _ = get_sample_compiled_module(target_dir) - module_file = os.path.join(target_dir, 'mock.tar') + module_file = os.path.join(target_dir, "mock.tar") tvmc.compiler.save_module(module_file, graph, lib, params) return module_file diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index 2a80167d6592..a9a62c5ef874 100644 --- a/tests/python/driver/tvmc/test_common.py +++ 
b/tests/python/driver/tvmc/test_common.py @@ -26,20 +26,22 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip('tflite') + pytest.importorskip("tflite") before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - expected_layout="NCHW" + expected_layout = "NCHW" after = tvmc.common.convert_graph_layout(before, expected_layout) layout_transform_calls = [] + def _is_layout_transform(node): if isinstance(node, tvm.relay.expr.Call): layout_transform_calls.append( - node.op.name == "layout_transform" \ - and node.attrs.src_layout == 'NHWC' \ - and node.attrs.dst_layout == 'NCHW') + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NCHW" + ) tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) @@ -48,20 +50,22 @@ def _is_layout_transform(node): def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") before, _ = tvmc.frontends.load_model(onnx_resnet50) - expected_layout="NHWC" + expected_layout = "NHWC" after = tvmc.common.convert_graph_layout(before, expected_layout) layout_transform_calls = [] + def _is_layout_transform(node): if isinstance(node, tvm.relay.expr.Call): layout_transform_calls.append( - node.op.name == "layout_transform" \ - and node.attrs.src_layout == 'NCHW' \ - and node.attrs.dst_layout == 'NHWC') + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) @@ -70,20 +74,22 @@ def _is_layout_transform(node): def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present - 
pytest.importorskip('tflite') + pytest.importorskip("tflite") before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - expected_layout="NHWC" + expected_layout = "NHWC" after = tvmc.common.convert_graph_layout(before, expected_layout) layout_transform_calls = [] + def _is_layout_transform(node): if isinstance(node, tvm.relay.expr.Call): layout_transform_calls.append( - node.op.name == "layout_transform" \ - and node.attrs.src_layout == 'NHWC' \ - and node.attrs.dst_layout == 'NHWC') + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NHWC" + ) tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) @@ -92,21 +98,23 @@ def _is_layout_transform(node): def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") before, _ = tvmc.frontends.load_model(onnx_resnet50) - expected_layout="NCHW" + expected_layout = "NCHW" after = tvmc.common.convert_graph_layout(before, expected_layout) layout_transform_calls = [] + def _is_layout_transform(node): if isinstance(node, tvm.relay.expr.Call): layout_transform_calls.append( - node.op.name == "layout_transform" \ - and node.attrs.src_layout == 'NCHW' \ - and node.attrs.dst_layout == 'NCHW') + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NCHW" + ) tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" \ No newline at end of file + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index d9652d1a482a..28a60b19b28e 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -40,7 +40,7 
@@ def test_save_dumps(tmpdir_factory): def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): - pytest.importorskip('tflite') + pytest.importorskip("tflite") graph, lib, params, dumps = tvmc.compiler.compile_model( tflite_mobilenet_v1_1_quant, @@ -57,10 +57,11 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. -@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), - reason="cross-compilation toolchain not installed") +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): - pytest.importorskip('tflite') + pytest.importorskip("tflite") graph, lib, params, dumps = tvmc.compiler.compile_model( tflite_mobilenet_v1_1_quant, @@ -77,7 +78,7 @@ def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") graph, lib, params, dumps = tvmc.compiler.compile_model( keras_resnet50, target="llvm", dump_code="ll" @@ -92,11 +93,12 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
-@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), - reason="cross-compilation toolchain not installed") +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) def test_cross_compile_aarch64_keras_module(keras_resnet50): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") graph, lib, params, dumps = tvmc.compiler.compile_model( keras_resnet50, @@ -114,7 +116,7 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): def test_compile_onnx_module(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") graph, lib, params, dumps = tvmc.compiler.compile_model( onnx_resnet50, target="llvm", dump_code="ll" @@ -129,11 +131,12 @@ def test_compile_onnx_module(onnx_resnet50): # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
-@pytest.mark.skipif(not shutil.which("aarch64-linux-gnu-gcc"), - reason="cross-compilation toolchain not installed") +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) def test_cross_compile_aarch64_onnx_module(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") graph, lib, params, dumps = tvmc.compiler.compile_model( onnx_resnet50, diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 8604d7594316..d77a17addabf 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -32,7 +32,7 @@ def test_get_frontends_contains_only_strings(): def test_get_frontend_by_name_valid(): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") sut = tvmc.frontends.get_frontend_by_name("keras") assert type(sut) is tvmc.frontends.KerasFrontend @@ -42,9 +42,10 @@ def test_get_frontend_by_name_invalid(): with pytest.raises(TVMCException): tvmc.frontends.get_frontend_by_name("unsupported_thing") + def test_guess_frontend_tflite(): # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip('tflite') + pytest.importorskip("tflite") sut = tvmc.frontends.guess_frontend("a_model.tflite") assert type(sut) is tvmc.frontends.TFLiteFrontend @@ -52,7 +53,7 @@ def test_guess_frontend_tflite(): def test_guess_frontend_onnx(): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") sut = tvmc.frontends.guess_frontend("a_model.onnx") assert type(sut) is tvmc.frontends.OnnxFrontend @@ -60,7 +61,7 @@ def test_guess_frontend_onnx(): def test_guess_frontend_pytorch(): # some CI environments wont offer pytorch, so skip in case 
it is not present - pytest.importorskip('torch') + pytest.importorskip("torch") sut = tvmc.frontends.guess_frontend("a_model.pth") assert type(sut) is tvmc.frontends.PyTorchFrontend @@ -68,7 +69,7 @@ def test_guess_frontend_pytorch(): def test_guess_frontend_keras(): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") sut = tvmc.frontends.guess_frontend("a_model.h5") assert type(sut) is tvmc.frontends.KerasFrontend @@ -76,7 +77,7 @@ def test_guess_frontend_keras(): def test_guess_frontend_tensorflow(): # some CI environments wont offer TensorFlow, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") sut = tvmc.frontends.guess_frontend("a_model.pb") assert type(sut) is tvmc.frontends.TensorflowFrontend @@ -89,7 +90,7 @@ def test_guess_frontend_invalid(): def test_load_model__invalid_path__no_language(): # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip('tflite') + pytest.importorskip("tflite") with pytest.raises(FileNotFoundError): tvmc.frontends.load_model("not/a/file.tflite") @@ -97,7 +98,7 @@ def test_load_model__invalid_path__no_language(): def test_load_model__invalid_path__with_language(): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") with pytest.raises(FileNotFoundError): tvmc.frontends.load_model("not/a/file.txt", model_format="onnx") @@ -105,51 +106,51 @@ def test_load_model__invalid_path__with_language(): def test_load_model__tflite(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip('tflite') + pytest.importorskip("tflite") mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) assert type(mod) is IRModule assert type(params) is dict # check whether one known value is part of 
the params dict - assert '_param_1' in params.keys() + assert "_param_1" in params.keys() def test_load_model__keras(keras_resnet50): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") mod, params = tvmc.frontends.load_model(keras_resnet50) assert type(mod) is IRModule assert type(params) is dict ## check whether one known value is part of the params dict - assert '_param_1' in params.keys() + assert "_param_1" in params.keys() def test_load_model__onnx(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") mod, params = tvmc.frontends.load_model(onnx_resnet50) assert type(mod) is IRModule assert type(params) is dict ## check whether one known value is part of the params dict - assert 'resnetv24_batchnorm0_gamma' in params.keys() + assert "resnetv24_batchnorm0_gamma" in params.keys() def test_load_model__pb(pb_mobilenet_v1_1_quant): # some CI environments wont offer TensorFlow, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant) assert type(mod) is IRModule assert type(params) is dict # check whether one known value is part of the params dict - assert 'MobilenetV1/Conv2d_0/weights' in params.keys() + assert "MobilenetV1/Conv2d_0/weights" in params.keys() def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present - pytest.importorskip('tensorflow') + pytest.importorskip("tensorflow") with pytest.raises(OSError): tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras") @@ -157,7 +158,7 @@ def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): def test_load_model___wrong_language__to_tflite(keras_resnet50): # 
some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip('tflite') + pytest.importorskip("tflite") with pytest.raises(TVMCException): tvmc.frontends.load_model(keras_resnet50, model_format="tflite") @@ -165,7 +166,7 @@ def test_load_model___wrong_language__to_tflite(keras_resnet50): def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): # some CI environments wont offer onnx, so skip in case it is not present - pytest.importorskip('onnx') + pytest.importorskip("onnx") from google.protobuf.message import DecodeError @@ -175,7 +176,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): # some CI environments wont offer pytorch, so skip in case it is not present - pytest.importorskip('torch') + pytest.importorskip("torch") with pytest.raises(RuntimeError) as e: tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") From c55f8f29a29a52b0cc16442ff19c1ea6b32f5b27 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Thu, 17 Sep 2020 11:26:10 +0100 Subject: [PATCH 10/10] tvmc: Move dependencies to extras_requires --- python/setup.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/python/setup.py b/python/setup.py index 5be2ca311fa5..fff7a0ed3bb1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -163,16 +163,24 @@ def get_package_data_files(): "attrs", "psutil", "typed_ast", - "tensorflow==2.1.0", - "tflite==2.1.0", - "onnx==1.7.0", - "onnxruntime==1.0.0", - "torch==1.4.0", - "torchvision==0.5.0", ], extras_require={ "test": ["pillow<7", "matplotlib"], - "extra_feature": ["tornado", "psutil", "xgboost>=1.1.0", "mypy", "orderedset"], + "extra_feature": [ + "tornado", + "psutil", + "xgboost>=1.1.0", + "mypy", + "orderedset", + ], + "tvmc": [ + "tensorflow>=2.1.0", + "tflite>=2.1.0", + "onnx>=1.7.0", + "onnxruntime>=1.0.0", + "torch>=1.4.0", + 
"torchvision>=0.5.0", + ], }, packages=find_packages(), package_dir={"tvm": "tvm"},