diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 71ccc8546e8b..fe5bebcabcbc 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -210,6 +210,13 @@ def add_tune_parser(subparsers): # can be improved in future to add integration with a modelzoo # or URL, for example. parser.add_argument("FILE", help="path to the input model file") + parser.add_argument( + "--input-shapes", + help="specify non-generic shapes for model to run, format is " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', + type=common.parse_shape_string, + default=None, + ) def drive_tune(args): @@ -235,7 +242,7 @@ def drive_tune(args): ) target = common.target_from_cli(args.target) - mod, params = frontends.load_model(args.FILE, args.model_format) + mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes) # min_repeat_ms should be: # a. the value provided by the user, if any, or diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9db22f3f3390..1845915bcbd1 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -17,8 +17,10 @@ """ Common utility functions shared by TVMC modules. """ +import re import logging import os.path +import argparse from urllib.parse import urlparse @@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str): logger.info("RPC tracker port: %s", rpc_port) return rpc_hostname, rpc_port + + +def parse_shape_string(inputs_string): + """Parse an input shape dictionary string to a usable dictionary. + + Parameters + ---------- + inputs_string: str + A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that + indicates the desired shape for specific model inputs. + + Returns + ------- + shape_dict: dict + A dictionary mapping input names to their shape for use in relay frontend converters. + """ + + # Create a regex pattern that extracts each separate input mapping. + pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + input_mappings = re.findall(pattern, inputs_string) + if not input_mappings: + raise argparse.ArgumentTypeError( + "--input-shapes argument must be of the form " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' + ) + shape_dict = {} + for mapping in input_mappings: + # Remove whitespace. + mapping = mapping.replace(" ", "") + # Split mapping into name and shape. + name, shape_string = mapping.split(":") + # Convert shape string into a list of integers or Anys if negative. + shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] + # Add parsed mapping to shape dictionary. + shape_dict[name] = shape + + return shape_dict diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 90b0aceaa17a..282ae6a76b56 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -87,6 +87,13 @@ def add_compile_parser(subparsers): # can be improved in future to add integration with a modelzoo # or URL, for example. parser.add_argument("FILE", help="path to the input model file") + parser.add_argument( + "--input-shapes", + help="specify non-generic shapes for model to run, format is " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', + type=common.parse_shape_string, + default=None, + ) def drive_compile(args): @@ -98,7 +105,7 @@ def drive_compile(args): Arguments from command line parser. 
Returns - -------- + ------- int Zero if successfully completed @@ -112,6 +119,7 @@ def drive_compile(args): args.model_format, args.tuning_records, args.desired_layout, + args.input_shapes, ) if dumps: @@ -129,6 +137,7 @@ def compile_model( model_format=None, tuning_records=None, alter_layout=None, + shape_dict=None, ): """Compile a model from a supported framework into a TVM module. @@ -158,6 +167,9 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. + shape_dict: dict, optional + A mapping from input names to their shape. When present, + the default shapes in the model will be overwritten. Returns ------- @@ -172,7 +184,7 @@ def compile_model( """ dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None - mod, params = frontends.load_model(path, model_format) + mod, params = frontends.load_model(path, model_format, shape_dict) if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index bb54b82cceca..53fbed66c8fc 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -54,13 +54,15 @@ def suffixes(): """File suffixes (extensions) used by this frontend""" @abstractmethod - def load(self, path): + def load(self, path, shape_dict=None): """Load a model from a given path. Parameters ---------- path: str Path to a file + shape_dict: dict, optional + Mapping from input names to their shapes. Returns ------- @@ -99,7 +101,7 @@ def name(): def suffixes(): return ["h5"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0103 tf, keras = import_keras() @@ -125,8 +127,10 @@ def load(self, path): ) inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] - shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)} - return relay.frontend.from_keras(model, shape_dict, layout="NHWC") + input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)} + if shape_dict is not None: + input_shapes.update(shape_dict) + return relay.frontend.from_keras(model, input_shapes, layout="NHWC") def is_sequential_p(self, model): _, keras = import_keras() @@ -154,14 +158,14 @@ def name(): def suffixes(): return ["onnx"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import onnx # pylint: disable=E1101 model = onnx.load(path) - return relay.frontend.from_onnx(model) + return relay.frontend.from_onnx(model, shape=shape_dict) class TensorflowFrontend(Frontend): @@ -175,7 +179,7 @@ def name(): def suffixes(): return ["pb"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import tensorflow as tf import tvm.relay.testing.tf as tf_testing @@ -188,7 +192,7 @@ def load(self, path): graph_def = tf_testing.ProcessGraphDefParam(graph_def) logger.debug("parse TensorFlow model and convert into Relay computation graph") - return relay.frontend.from_tensorflow(graph_def) + return relay.frontend.from_tensorflow(graph_def, shape=shape_dict) class TFLiteFrontend(Frontend): @@ -215,7 +219,7 @@ def name(): def suffixes(): return ["tflite"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import tflite.Model as model @@ -238,11 +242,13 @@ def load(self, path): raise TVMCException("input file not tflite version 3") 
logger.debug("tflite_input_type") - shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model) + input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model) + if shape_dict is not None: + input_shapes.update(shape_dict) logger.debug("parse TFLite model and convert into Relay computation graph") mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict ) return mod, params @@ -285,17 +291,18 @@ def suffixes(): # Torch Script is a zip file, but can be named pth return ["pth", "zip"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import torch - traced_model = torch.jit.load(path) - - inputs = list(traced_model.graph.inputs())[1:] - input_shapes = [inp.type().sizes() for inp in inputs] + if shape_dict is None: + raise TVMCException("--input-shapes must be specified for %s" % self.name()) + traced_model = torch.jit.load(path) traced_model.eval() # Switch to inference mode - input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] + + # Convert shape dictionary to list for Pytorch frontend compatibility + input_shapes = list(shape_dict.items()) logger.debug("parse Torch model and convert into Relay computation graph") return relay.frontend.from_pytorch(traced_model, input_shapes) @@ -378,7 +385,7 @@ def guess_frontend(path): raise TVMCException("failed to infer the model format. Please specify --model-format") -def load_model(path, model_format=None): +def load_model(path, model_format=None, shape_dict=None): """Load a model from a supported framework and convert it into an equivalent relay representation. @@ -389,6 +396,8 @@ def load_model(path, model_format=None): model_format : str, optional The underlying framework used to create the model. If not specified, this will be inferred from the file type. + shape_dict : dict, optional + Mapping from input names to their shapes. Returns ------- @@ -404,6 +413,6 @@ def load_model(path, model_format=None): else: frontend = guess_frontend(path) - mod, params = frontend.load(path) + mod, params = frontend.load(path, shape_dict) return mod, params diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 882d793ccebd..534953deecbc 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -99,6 +99,23 @@ def keras_resnet50(tmpdir_factory): return model_file_name +@pytest.fixture(scope="session") +def pytorch_resnet18(tmpdir_factory): + try: + import torch + import torchvision.models as models + except ImportError: + # Not all environments provide Pytorch, so skip if that's the case. + return "" + model = models.resnet18() + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet18.pth") + # Trace model into torchscript. 
+    traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
+    torch.jit.save(traced_cpu, model_file_name)
+
+    return model_file_name
+
+
 @pytest.fixture(scope="session")
 def onnx_resnet50():
     base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model"
diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py
index 5ffbc6fe37dd..f30949b54497 100644
--- a/tests/python/driver/tvmc/test_common.py
+++ b/tests/python/driver/tvmc/test_common.py
@@ -21,6 +21,8 @@
+import argparse
 import pytest
 
 import tvm
+from tvm import relay
 from tvm.driver import tvmc
 
 
@@ -149,3 +151,35 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
 
     assert expected_host == actual_host
     assert expected_port == actual_port
+
+
+def test_shape_parser():
+    # Check that a valid input is parsed correctly
+    shape_string = "input:[10,10,10]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10]}
+    # Check that multiple valid input shapes are parsed correctly
+    shape_string = "input:[10,10,10] input2:[20,20,20,20]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+    # Check that alternate syntax parses correctly
+    shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+    shape_string = "input:[10,10,10],input2:[20,20,20,20]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+    # Check that negative dimensions parse to Any correctly.
+    shape_string = "input:[-1,3,224,224]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    # Convert to strings to allow comparison with Any.
+    assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"
+
+    # Check that an invalid pattern raises the expected error.
+    shape_string = "input:[a,10]"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)
+    # Check that an input with invalid separators raises an error.
+    shape_string = "input:5,10 input2:10,10"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 4bbb6fbf2cf8..4cb342c2e967 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -39,14 +39,11 @@ def test_save_dumps(tmpdir_factory):
 
 
 # End to end tests for compilation
-def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
+def verify_compile_tflite_module(model, shape_dict=None):
     pytest.importorskip("tflite")
 
     graph, lib, params, dumps = tvmc.compiler.compile_model(
-        tflite_mobilenet_v1_1_quant,
-        target="llvm",
-        dump_code="ll",
-        alter_layout="NCHW",
+        model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict
     )
 
     # check for output types
@@ -56,6 +53,17 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
     assert type(dumps) is dict
 
 
+def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
+    # some CI environments wont offer tflite, so skip in case it is not present
+    pytest.importorskip("tflite")
+    # Check default compilation.
+ verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) + # Check with manual shape override + shape_string = "input:[1,224,224,3]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" @@ -114,12 +122,12 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): assert "asm" in dumps.keys() -def test_compile_onnx_module(onnx_resnet50): +def verify_compile_onnx_module(model, shape_dict=None): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") graph, lib, params, dumps = tvmc.compiler.compile_model( - onnx_resnet50, target="llvm", dump_code="ll" + model, target="llvm", dump_code="ll", shape_dict=shape_dict ) # check for output types @@ -130,6 +138,15 @@ def test_compile_onnx_module(onnx_resnet50): assert "ll" in dumps.keys() +def test_compile_onnx_module(onnx_resnet50): + # Test default compilation + verify_compile_onnx_module(onnx_resnet50) + # Test with manual shape dict + shape_string = "data:[1,3,200,200]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + verify_compile_onnx_module(onnx_resnet50, shape_dict) + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index d77a17addabf..04c85b1eb8f3 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -174,9 +174,27 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx") +def test_load_model__pth(pytorch_resnet18): + # some CI environments wont offer torch, so skip in case it is not present + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + mod, params = tvmc.frontends.load_model( + pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]} + ) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert "layer1.0.conv1.weight" in params.keys() + + def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): # some CI environments wont offer pytorch, so skip in case it is not present pytest.importorskip("torch") with pytest.raises(RuntimeError) as e: - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") + tvmc.frontends.load_model( + tflite_mobilenet_v1_1_quant, + model_format="pytorch", + shape_dict={"input": [1, 3, 224, 224]}, + ) diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh index b77d3f37cd3e..5f108e9355fc 100755 --- a/tests/scripts/setup-pytest-env.sh +++ b/tests/scripts/setup-pytest-env.sh @@ -20,9 +20,9 @@ set +u if [[ ! -z $CI_PYTEST_ADD_OPTIONS ]]; then - export PYTEST_ADDOPTS="-v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" else - export PYTEST_ADDOPTS="-v $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -v $PYTEST_ADDOPTS" fi set -u
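For context, a minimal usage sketch of the shape-override path introduced by this patch. The model path "resnet50.onnx" and the input names "data" and "mask" are hypothetical placeholders; parse_shape_string, the --input-shapes flag, and the shape_dict argument to load_model are the additions shown in the diff above.

    # Sketch only: exercise the new shape-override plumbing end to end.
    from tvm.driver import tvmc

    # Roughly equivalent to the CLI form:
    #   tvmc compile --input-shapes "data:[1,3,224,224] mask:[-1,64]" resnet50.onnx
    shape_dict = tvmc.common.parse_shape_string("data:[1,3,224,224] mask:[-1,64]")
    # Negative dimensions are parsed to relay.Any(),
    # i.e. {"data": [1, 3, 224, 224], "mask": [?, 64]}

    # The parsed dictionary overrides the frontend's default input shapes.
    mod, params = tvmc.frontends.load_model("resnet50.onnx", shape_dict=shape_dict)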