From 42d19098905e1dab38ff928fcf2eaf4fc85362b8 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Wed, 27 Jan 2021 16:23:48 -0500 Subject: [PATCH 01/13] add ability to optionally overide tvm shapes --- python/tvm/driver/tvmc/compiler.py | 28 +++++++++++++- python/tvm/driver/tvmc/frontends.py | 45 ++++++++++++++++------- tests/python/driver/tvmc/test_compiler.py | 26 +++++++++++-- 3 files changed, 80 insertions(+), 19 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 90b0aceaa17a..bbc84373533d 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -35,6 +35,21 @@ # pylint: disable=invalid-name logger = logging.getLogger("TVMC") +#turn data:32x3x224x224 to {'data':[32,3,224,224]} +def parse_shape(inputs): + d = {} #final dictionary + inputs = inputs.split(",") #multiple data inputs + for string in inputs: + string = string.split(":") #seperate name from ints + shapelist = [] + string[1] = string[1].split("x") #make list + for x in string[1]: + x = int(x) #make int list + if x < 0: #negative ints dynamic + x = relay.Any() + shapelist.append(x) + d[string[0]] = shapelist + return d @register_parser def add_compile_parser(subparsers): @@ -87,6 +102,12 @@ def add_compile_parser(subparsers): # can be improved in future to add integration with a modelzoo # or URL, for example. parser.add_argument("FILE", help="path to the input model file") + parser.add_argument( + "--shapes", + help="", + type=parse_shape, + default=None, + ) def drive_compile(args): @@ -112,6 +133,7 @@ def drive_compile(args): args.model_format, args.tuning_records, args.desired_layout, + args.shapes, ) if dumps: @@ -129,6 +151,7 @@ def compile_model( model_format=None, tuning_records=None, alter_layout=None, + shape_dict=None ): """Compile a model from a supported framework into a TVM module. @@ -158,6 +181,9 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. + shape_dict: dict, optional + A mapping between input names and their shape. This is useful + to override the default values in a model if needed. Returns ------- @@ -172,7 +198,7 @@ def compile_model( """ dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None - mod, params = frontends.load_model(path, model_format) + mod, params = frontends.load_model(path, model_format, shape_dict) if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index bb54b82cceca..9369ad771476 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -54,13 +54,15 @@ def suffixes(): """File suffixes (extensions) used by this frontend""" @abstractmethod - def load(self, path): + def load(self, path, shape_dict=None): """Load a model from a given path. Parameters ---------- path: str Path to a file + shape_dict: dict, optional + A dictionary mapping input names to shapes. Returns ------- @@ -99,7 +101,7 @@ def name(): def suffixes(): return ["h5"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0103 tf, keras = import_keras() @@ -125,8 +127,10 @@ def load(self, path): ) inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] - shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)} - return relay.frontend.from_keras(model, shape_dict, layout="NHWC") + input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)} + if shape_dict is not None: + input_shapes.update(shape_dict) + return relay.frontend.from_keras(model, input_shapes, layout="NHWC") def is_sequential_p(self, model): _, keras = import_keras() @@ -154,14 +158,14 @@ def name(): def suffixes(): return ["onnx"] - def load(self, path): + def load(self, path, shape_dict = None): # pylint: disable=C0415 import onnx # pylint: disable=E1101 model = onnx.load(path) - return relay.frontend.from_onnx(model) + return relay.frontend.from_onnx(model, shape = shape_dict) class TensorflowFrontend(Frontend): @@ -175,7 +179,7 @@ def name(): def suffixes(): return ["pb"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import tensorflow as tf import tvm.relay.testing.tf as tf_testing @@ -188,7 +192,7 @@ def load(self, path): graph_def = tf_testing.ProcessGraphDefParam(graph_def) logger.debug("parse TensorFlow model and convert into Relay computation graph") - return relay.frontend.from_tensorflow(graph_def) + return relay.frontend.from_tensorflow(graph_def, shape=shape_dict) class TFLiteFrontend(Frontend): @@ -215,7 +219,7 @@ def name(): def suffixes(): return ["tflite"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import tflite.Model as model @@ -238,11 +242,13 @@ def load(self, path): raise TVMCException("input file not tflite version 3") logger.debug("tflite_input_type") - shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model) + input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model) + if shape_dict is not None: + input_shapes.update(shape_dict) logger.debug("parse TFLite model and convert into Relay computation graph") mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict ) return mod, params @@ -285,7 +291,7 @@ def suffixes(): # Torch Script is a zip file, but can be named pth return ["pth", "zip"] - def load(self, path): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import torch @@ -297,6 +303,15 @@ def load(self, path): traced_model.eval() # Switch to inference mode input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] + # Update input shapes with manual override and prevent duplication. + if shape_dict is not None: + input_shape_dict = {} + for name, shape in input_shapes: + input_shape_dict[name] = shape + input_shape_dict.update(shape_dict) + # Convert back to list for torch importer. + input_shapes = list(input_shape_dict.items()) + logger.debug("parse Torch model and convert into Relay computation graph") return relay.frontend.from_pytorch(traced_model, input_shapes) @@ -378,7 +393,7 @@ def guess_frontend(path): raise TVMCException("failed to infer the model format. Please specify --model-format") -def load_model(path, model_format=None): +def load_model(path, model_format=None, shape_dict = None): """Load a model from a supported framework and convert it into an equivalent relay representation. @@ -389,6 +404,8 @@ def load_model(path, model_format=None): model_format : str, optional The underlying framework used to create the model. If not specified, this will be inferred from the file type. + shape_dict : dict, optional + A mapping between input names and their desired shape. Returns ------- @@ -404,6 +421,6 @@ def load_model(path, model_format=None): else: frontend = guess_frontend(path) - mod, params = frontend.load(path) + mod, params = frontend.load(path, shape_dict) return mod, params diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 4bbb6fbf2cf8..bf88a6616858 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -39,14 +39,15 @@ def test_save_dumps(tmpdir_factory): # End to end tests for compilation -def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): +def verify_compile_tflite_module(model, shape_dict=None): pytest.importorskip("tflite") graph, lib, params, dumps = tvmc.compiler.compile_model( - tflite_mobilenet_v1_1_quant, + model, target="llvm", dump_code="ll", alter_layout="NCHW", + shape_dict=shape_dict ) # check for output types @@ -56,6 +57,14 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): assert type(dumps) is dict +def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): + # Check default compilation. + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) + # Check with manual shape override + shape_string = "input:1x224x224x3" + shape_dict = tvmc.compiler.parse_shape(shape_string) + verify_compile_onnx_module(tflite_mobilenet_v1_1_quant, shape_dict) + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" @@ -114,12 +123,12 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): assert "asm" in dumps.keys() -def test_compile_onnx_module(onnx_resnet50): +def verify_compile_onnx_module(model, shape_dict=None): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") graph, lib, params, dumps = tvmc.compiler.compile_model( - onnx_resnet50, target="llvm", dump_code="ll" + model, target="llvm", dump_code="ll", shape_dict=shape_dict ) # check for output types @@ -130,6 +139,15 @@ def test_compile_onnx_module(onnx_resnet50): assert "ll" in dumps.keys() +def test_compile_onnx_module(onnx_resnet50): + # Test default compilation + verify_compile_onnx_module(onnx_resnet50) + # Test with manual shape dict + shape_string = "data:1x3x200x200" + shape_dict = tvmc.compiler.parse_shape(shape_string) + verify_compile_onnx_module(onnx_resnet50, shape_dict) + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" From d547ac12ddd179b57312ca856c157ee070de52e0 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Thu, 28 Jan 2021 14:41:57 -0500 Subject: [PATCH 02/13] add help documentation for --shapes --- python/tvm/driver/tvmc/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index bbc84373533d..724bdc74815a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -104,7 +104,8 @@ def add_compile_parser(subparsers): parser.add_argument("FILE", help="path to the input model file") parser.add_argument( "--shapes", - help="", + help="specify non-generic shapes for model to run, format is" + "name:num1xnum2xnum3,name2:num1xnum2xnum3", type=parse_shape, default=None, ) From 1ec3b48faeddb71edf00e4187e8073a9464deeb1 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Thu, 28 Jan 2021 16:59:05 -0500 Subject: [PATCH 03/13] improve documentation --- python/tvm/driver/tvmc/compiler.py | 40 +++++++++++++++++++++-------- python/tvm/driver/tvmc/frontends.py | 6 ++--- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 724bdc74815a..f50f6e8e9ba8 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -35,22 +35,42 @@ # pylint: disable=invalid-name logger = logging.getLogger("TVMC") -#turn data:32x3x224x224 to {'data':[32,3,224,224]} + def parse_shape(inputs): - d = {} #final dictionary - inputs = inputs.split(",") #multiple data inputs + """Parse an input shape dictionary string to a usable dictionary. + + Parameters + ---------- + inputs: str + A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates + the desired shape for specific model inputs. + + Returns + ------- + shape_dict: dict + A dictionary mapping input names to their shape for use in relay frontend converters. + """ + d = {} + # Break apart each specific input string + inputs = inputs.split(",") for string in inputs: - string = string.split(":") #seperate name from ints + # Split name from shape string. + string = string.split(":") shapelist = [] - string[1] = string[1].split("x") #make list + # Separate each dimension in the shape. + string[1] = string[1].split("x") + # Parse each dimension into an integer. for x in string[1]: - x = int(x) #make int list - if x < 0: #negative ints dynamic + x = int(x) + # Negative numbers are converted to dynamic axes. + if x < 0: x = relay.Any() shapelist.append(x) + # Assign dictionary key value pair. d[string[0]] = shapelist return d + @register_parser def add_compile_parser(subparsers): """ Include parser for 'compile' subcommand """ @@ -105,7 +125,7 @@ def add_compile_parser(subparsers): parser.add_argument( "--shapes", help="specify non-generic shapes for model to run, format is" - "name:num1xnum2xnum3,name2:num1xnum2xnum3", + "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", type=parse_shape, default=None, ) @@ -120,7 +140,7 @@ def drive_compile(args): Arguments from command line parser. Returns - -------- + ------- int Zero if successfully completed @@ -152,7 +172,7 @@ def compile_model( model_format=None, tuning_records=None, alter_layout=None, - shape_dict=None + shape_dict=None, ): """Compile a model from a supported framework into a TVM module. diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 9369ad771476..66843622ea2b 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -158,14 +158,14 @@ def name(): def suffixes(): return ["onnx"] - def load(self, path, shape_dict = None): + def load(self, path, shape_dict=None): # pylint: disable=C0415 import onnx # pylint: disable=E1101 model = onnx.load(path) - return relay.frontend.from_onnx(model, shape = shape_dict) + return relay.frontend.from_onnx(model, shape=shape_dict) class TensorflowFrontend(Frontend): @@ -393,7 +393,7 @@ def guess_frontend(path): raise TVMCException("failed to infer the model format. Please specify --model-format") -def load_model(path, model_format=None, shape_dict = None): +def load_model(path, model_format=None, shape_dict=None): """Load a model from a supported framework and convert it into an equivalent relay representation. From d35e1fad8bb86e1256ae2f7901ca837b9976096c Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Thu, 28 Jan 2021 20:18:05 -0500 Subject: [PATCH 04/13] reformat test_compiler using black --- tests/python/driver/tvmc/test_compiler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index bf88a6616858..d353acf23342 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -43,11 +43,7 @@ def verify_compile_tflite_module(model, shape_dict=None): pytest.importorskip("tflite") graph, lib, params, dumps = tvmc.compiler.compile_model( - model, - target="llvm", - dump_code="ll", - alter_layout="NCHW", - shape_dict=shape_dict + model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict ) # check for output types @@ -65,6 +61,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): shape_dict = tvmc.compiler.parse_shape(shape_string) verify_compile_onnx_module(tflite_mobilenet_v1_1_quant, shape_dict) + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" From d1f74d55bd5fbd1b63a59ba989759beb9d5c5758 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 29 Jan 2021 19:14:03 -0500 Subject: [PATCH 05/13] Incorporate feedback from ekalda for better pytorch support and testing. --- python/tvm/driver/tvmc/autotuner.py | 9 ++++- python/tvm/driver/tvmc/common.py | 45 ++++++++++++++++++++++ python/tvm/driver/tvmc/compiler.py | 37 +----------------- python/tvm/driver/tvmc/frontends.py | 18 +++------ tests/python/driver/tvmc/conftest.py | 17 ++++++++ tests/python/driver/tvmc/test_common.py | 24 ++++++++++++ tests/python/driver/tvmc/test_compiler.py | 6 +-- tests/python/driver/tvmc/test_frontends.py | 19 ++++++++- 8 files changed, 121 insertions(+), 54 deletions(-) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 71ccc8546e8b..e824b33cc897 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -210,6 +210,13 @@ def add_tune_parser(subparsers): # can be improved in future to add integration with a modelzoo # or URL, for example. parser.add_argument("FILE", help="path to the input model file") + parser.add_argument( + "--shapes", + help="specify non-generic shapes for model to run, format is" + "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", + type=common.parse_shape_string, + default=None, + ) def drive_tune(args): @@ -235,7 +242,7 @@ def drive_tune(args): ) target = common.target_from_cli(args.target) - mod, params = frontends.load_model(args.FILE, args.model_format) + mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.shapes) # min_repeat_ms should be: # a. the value provided by the user, if any, or diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9db22f3f3390..d9f61ac74a70 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -17,8 +17,10 @@ """ Common utility functions shared by TVMC modules. """ +import re import logging import os.path +import argparse from urllib.parse import urlparse @@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str): logger.info("RPC tracker port: %s", rpc_port) return rpc_hostname, rpc_port + + +def parse_shape_string(inputs): + """Parse an input shape dictionary string to a usable dictionary. + + Parameters + ---------- + inputs: str + A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates + the desired shape for specific model inputs. + + Returns + ------- + shape_dict: dict + A dictionary mapping input names to their shape for use in relay frontend converters. + """ + inputs = inputs.replace(" ", "") + # Check if the passed input is in the proper format. + valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*") + result = re.fullmatch(valid_pattern, inputs) + if result is None: + raise argparse.ArgumentTypeError( + "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2" + ) + d = {} + # Break apart each specific input string + inputs = inputs.split(",") + for string in inputs: + # Split name from shape string. + string = string.split(":") + shapelist = [] + # Separate each dimension in the shape. + string[1] = string[1].lower().split("x") + # Parse each dimension into an integer. + for x in string[1]: + x = int(x) + # Negative numbers are converted to dynamic axes. + if x < 0: + x = relay.Any() + shapelist.append(x) + # Assign dictionary key value pair. + d[string[0]] = shapelist + return d diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index f50f6e8e9ba8..6c9360d2edd3 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -36,41 +36,6 @@ logger = logging.getLogger("TVMC") -def parse_shape(inputs): - """Parse an input shape dictionary string to a usable dictionary. - - Parameters - ---------- - inputs: str - A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates - the desired shape for specific model inputs. - - Returns - ------- - shape_dict: dict - A dictionary mapping input names to their shape for use in relay frontend converters. - """ - d = {} - # Break apart each specific input string - inputs = inputs.split(",") - for string in inputs: - # Split name from shape string. - string = string.split(":") - shapelist = [] - # Separate each dimension in the shape. - string[1] = string[1].split("x") - # Parse each dimension into an integer. - for x in string[1]: - x = int(x) - # Negative numbers are converted to dynamic axes. - if x < 0: - x = relay.Any() - shapelist.append(x) - # Assign dictionary key value pair. - d[string[0]] = shapelist - return d - - @register_parser def add_compile_parser(subparsers): """ Include parser for 'compile' subcommand """ @@ -126,7 +91,7 @@ def add_compile_parser(subparsers): "--shapes", help="specify non-generic shapes for model to run, format is" "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", - type=parse_shape, + type=common.parse_shape_string, default=None, ) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 66843622ea2b..930a21c99d1e 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -296,21 +296,13 @@ def load(self, path, shape_dict=None): import torch traced_model = torch.jit.load(path) - - inputs = list(traced_model.graph.inputs())[1:] - input_shapes = [inp.type().sizes() for inp in inputs] - traced_model.eval() # Switch to inference mode - input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] - # Update input shapes with manual override and prevent duplication. - if shape_dict is not None: - input_shape_dict = {} - for name, shape in input_shapes: - input_shape_dict[name] = shape - input_shape_dict.update(shape_dict) - # Convert back to list for torch importer. - input_shapes = list(input_shape_dict.items()) + if shape_dict is None: + raise TVMCException("--shapes must be specified for {}".format(self.name())) + + # Convert shape dictionary to list for Pytorch frontend compatibility + input_shapes = list(shape_dict.items()) logger.debug("parse Torch model and convert into Relay computation graph") return relay.frontend.from_pytorch(traced_model, input_shapes) diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 882d793ccebd..534953deecbc 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -99,6 +99,23 @@ def keras_resnet50(tmpdir_factory): return model_file_name +@pytest.fixture(scope="session") +def pytorch_resnet18(tmpdir_factory): + try: + import torch + import torchvision.models as models + except ImportError: + # Not all environments provide Pytorch, so skip if that's the case. + return "" + model = models.resnet18() + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet18.pth") + # Trace model into torchscript. + traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224)) + torch.jit.save(traced_cpu, model_file_name) + + return model_file_name + + @pytest.fixture(scope="session") def onnx_resnet50(): base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index 5ffbc6fe37dd..e43b3ece05c4 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -149,3 +149,27 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): assert expected_host == actual_host assert expected_port == actual_port + + +def test_shape_parser(): + # Check that a valid input is parsed correctly + shape_string = "input:10x10x10" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10]} + # Check that multiple valid input shapes are parse correctly + shape_string = "input:10x10x10,input2:20x20x20x20" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + # Check that alternate syntax parses correctly + shape_string = "input:10X10X10, input2:20X20X20X20" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + + # Check that invalid pattern raises expected error. + shape_string = "input:ax10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with invalid separators raises error. + shape_string = "input:5,10 input2:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index d353acf23342..d47c6bd33e40 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -58,8 +58,8 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) # Check with manual shape override shape_string = "input:1x224x224x3" - shape_dict = tvmc.compiler.parse_shape(shape_string) - verify_compile_onnx_module(tflite_mobilenet_v1_1_quant, shape_dict) + shape_dict = tvmc.common.parse_shape_string(shape_string) + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @@ -141,7 +141,7 @@ def test_compile_onnx_module(onnx_resnet50): verify_compile_onnx_module(onnx_resnet50) # Test with manual shape dict shape_string = "data:1x3x200x200" - shape_dict = tvmc.compiler.parse_shape(shape_string) + shape_dict = tvmc.common.parse_shape_string(shape_string) verify_compile_onnx_module(onnx_resnet50, shape_dict) diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index d77a17addabf..54f6f0cd50db 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -174,9 +174,26 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx") +def test_load_model__pth(pytorch_resnet18): + # some CI environments wont offer torch, so skip in case it is not present + pytest.importorskip("torch") + + mod, params = tvmc.frontends.load_model( + pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]} + ) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert "layer1.0.conv1.weight" in params.keys() + + def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): # some CI environments wont offer pytorch, so skip in case it is not present pytest.importorskip("torch") with pytest.raises(RuntimeError) as e: - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") + tvmc.frontends.load_model( + tflite_mobilenet_v1_1_quant, + model_format="pytorch", + shape_dict={"input": [1, 3, 224, 224]}, + ) From 7fdf9d32b8a252b7e24a280cc9147e1567a14f26 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 1 Feb 2021 15:21:56 -0500 Subject: [PATCH 06/13] address feedback --- python/tvm/driver/tvmc/autotuner.py | 2 +- python/tvm/driver/tvmc/common.py | 38 ++++++++++++----------- python/tvm/driver/tvmc/compiler.py | 6 ++-- python/tvm/driver/tvmc/frontends.py | 6 ++-- tests/python/driver/tvmc/test_common.py | 6 ++++ tests/python/driver/tvmc/test_compiler.py | 2 ++ 6 files changed, 35 insertions(+), 25 deletions(-) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index e824b33cc897..83feb1fb1d76 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -211,7 +211,7 @@ def add_tune_parser(subparsers): # or URL, for example. parser.add_argument("FILE", help="path to the input model file") parser.add_argument( - "--shapes", + "--input-shapes", help="specify non-generic shapes for model to run, format is" "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", type=common.parse_shape_string, diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index d9f61ac74a70..180be1df545b 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -140,12 +140,12 @@ def tracker_host_port_from_cli(rpc_tracker_str): return rpc_hostname, rpc_port -def parse_shape_string(inputs): +def parse_shape_string(inputs_string): """Parse an input shape dictionary string to a usable dictionary. Parameters ---------- - inputs: str + inputs_string: str A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates the desired shape for specific model inputs. @@ -154,30 +154,32 @@ def parse_shape_string(inputs): shape_dict: dict A dictionary mapping input names to their shape for use in relay frontend converters. """ - inputs = inputs.replace(" ", "") + # Simplify passed input string by removing spaces. + inputs_string = inputs_string.replace(" ", "") # Check if the passed input is in the proper format. - valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*") - result = re.fullmatch(valid_pattern, inputs) + valid_pattern = re.compile( + r"(\w+:(\-{0,1}\d+(x|X))*\-{0,1}(\d)+)(,(\w+:(\-{0,1}\d+(x|X))*\-{0,1}(\d)+))*" + ) + result = re.fullmatch(valid_pattern, inputs_string) if result is None: raise argparse.ArgumentTypeError( - "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2" + "--input-shapes argument must be of the form " + "input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2" ) - d = {} + shape_dict = {} # Break apart each specific input string - inputs = inputs.split(",") - for string in inputs: + inputs_list = inputs_string.split(",") + for shape_mapping in inputs_list: # Split name from shape string. - string = string.split(":") - shapelist = [] + input_name, input_shape_string = shape_mapping.split(":") # Separate each dimension in the shape. - string[1] = string[1].lower().split("x") + input_shape_chars = input_shape_string.lower().split("x") # Parse each dimension into an integer. - for x in string[1]: + input_shape = [] + for x in input_shape_chars: x = int(x) # Negative numbers are converted to dynamic axes. - if x < 0: - x = relay.Any() - shapelist.append(x) + input_shape.append(x if x >= 0 else relay.Any()) # Assign dictionary key value pair. - d[string[0]] = shapelist - return d + shape_dict[input_name] = input_shape + return shape_dict diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 6c9360d2edd3..879e6788afc3 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -88,7 +88,7 @@ def add_compile_parser(subparsers): # or URL, for example. parser.add_argument("FILE", help="path to the input model file") parser.add_argument( - "--shapes", + "--input-shapes", help="specify non-generic shapes for model to run, format is" "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", type=common.parse_shape_string, @@ -168,8 +168,8 @@ def compile_model( pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. shape_dict: dict, optional - A mapping between input names and their shape. This is useful - to override the default values in a model if needed. + A mapping from input names to their shape. When present, + the default shapes in the model will be overwritten. Returns ------- diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 930a21c99d1e..9a3ed13028fe 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -62,7 +62,7 @@ def load(self, path, shape_dict=None): path: str Path to a file shape_dict: dict, optional - A dictionary mapping input names to shapes. + Mapping from input names to their shapes. Returns ------- @@ -299,7 +299,7 @@ def load(self, path, shape_dict=None): traced_model.eval() # Switch to inference mode if shape_dict is None: - raise TVMCException("--shapes must be specified for {}".format(self.name())) + raise TVMCException("--input-shapes must be specified for %s" % self.name()) # Convert shape dictionary to list for Pytorch frontend compatibility input_shapes = list(shape_dict.items()) @@ -397,7 +397,7 @@ def load_model(path, model_format=None, shape_dict=None): The underlying framework used to create the model. If not specified, this will be inferred from the file type. shape_dict : dict, optional - A mapping between input names and their desired shape. + Mapping from input names to their shapes. Returns ------- diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index e43b3ece05c4..7dc221cb161a 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -21,6 +21,7 @@ import pytest import tvm +from tvm import relay from tvm.driver import tvmc @@ -164,6 +165,11 @@ def test_shape_parser(): shape_string = "input:10X10X10, input2:20X20X20X20" shape_dict = tvmc.common.parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + # Check that negative dimensions parse to Any correctly. + shape_string = "input:-1x3x224x224" + shape_dict = tvmc.common.parse_shape_string(shape_string) + # Convert to strings to allow comparison with Any. + assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" # Check that invalid pattern raises expected error. shape_string = "input:ax10" diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index d47c6bd33e40..a08da4dfeab6 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -54,6 +54,8 @@ def verify_compile_tflite_module(model, shape_dict=None): def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer flute, so skip in case it is not present + pytest.importorskip("tflite") # Check default compilation. verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) # Check with manual shape override From 150b6fb5fc4356f30a6aa03eb69ce89e48d48796 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 1 Feb 2021 16:43:27 -0500 Subject: [PATCH 07/13] switch input shape syntax to be more pythonic --- python/tvm/driver/tvmc/autotuner.py | 6 ++-- python/tvm/driver/tvmc/common.py | 40 +++++++++-------------- python/tvm/driver/tvmc/compiler.py | 6 ++-- tests/python/driver/tvmc/test_common.py | 13 +++++--- tests/python/driver/tvmc/test_compiler.py | 6 ++-- 5 files changed, 33 insertions(+), 38 deletions(-) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 83feb1fb1d76..fe5bebcabcbc 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -212,8 +212,8 @@ def add_tune_parser(subparsers): parser.add_argument("FILE", help="path to the input model file") parser.add_argument( "--input-shapes", - help="specify non-generic shapes for model to run, format is" - "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", + help="specify non-generic shapes for model to run, format is " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', type=common.parse_shape_string, default=None, ) @@ -242,7 +242,7 @@ def drive_tune(args): ) target = common.target_from_cli(args.target) - mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.shapes) + mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes) # min_repeat_ms should be: # a. the value provided by the user, if any, or diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 180be1df545b..f00350cc9f1c 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -154,32 +154,24 @@ def parse_shape_string(inputs_string): shape_dict: dict A dictionary mapping input names to their shape for use in relay frontend converters. """ - # Simplify passed input string by removing spaces. - inputs_string = inputs_string.replace(" ", "") - # Check if the passed input is in the proper format. - valid_pattern = re.compile( - r"(\w+:(\-{0,1}\d+(x|X))*\-{0,1}(\d)+)(,(\w+:(\-{0,1}\d+(x|X))*\-{0,1}(\d)+))*" - ) - result = re.fullmatch(valid_pattern, inputs_string) - if result is None: + + # Create a regex pattern that extracts each separate input mapping. + pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + input_mappings = re.findall(pattern, inputs_string) + if not input_mappings: raise argparse.ArgumentTypeError( "--input-shapes argument must be of the form " - "input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2" + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' ) shape_dict = {} - # Break apart each specific input string - inputs_list = inputs_string.split(",") - for shape_mapping in inputs_list: - # Split name from shape string. - input_name, input_shape_string = shape_mapping.split(":") - # Separate each dimension in the shape. - input_shape_chars = input_shape_string.lower().split("x") - # Parse each dimension into an integer. - input_shape = [] - for x in input_shape_chars: - x = int(x) - # Negative numbers are converted to dynamic axes. - input_shape.append(x if x >= 0 else relay.Any()) - # Assign dictionary key value pair. - shape_dict[input_name] = input_shape + for mapping in input_mappings: + # Remove whitespace. + mapping = mapping.replace(" ", "") + # Split mapping into name and shape. + name, shape_string = mapping.split(":") + # Convert shape string into a list of integers or Anys if negative. + shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] + # Add parsed mapping to shape dictionary. + shape_dict[name] = shape + return shape_dict diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 879e6788afc3..282ae6a76b56 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -89,8 +89,8 @@ def add_compile_parser(subparsers): parser.add_argument("FILE", help="path to the input model file") parser.add_argument( "--input-shapes", - help="specify non-generic shapes for model to run, format is" - "name:num1xnum2x...xnumN,name2:num1xnum2xnum3", + help="specify non-generic shapes for model to run, format is " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', type=common.parse_shape_string, default=None, ) @@ -119,7 +119,7 @@ def drive_compile(args): args.model_format, args.tuning_records, args.desired_layout, - args.shapes, + args.input_shapes, ) if dumps: diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index 7dc221cb161a..f30949b54497 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -154,25 +154,28 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): def test_shape_parser(): # Check that a valid input is parsed correctly - shape_string = "input:10x10x10" + shape_string = "input:[10,10,10]" shape_dict = tvmc.common.parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10]} # Check that multiple valid input shapes are parse correctly - shape_string = "input:10x10x10,input2:20x20x20x20" + shape_string = "input:[10,10,10] input2:[20,20,20,20]" shape_dict = tvmc.common.parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} # Check that alternate syntax parses correctly - shape_string = "input:10X10X10, input2:20X20X20X20" + shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + shape_string = "input:[10,10,10],input2:[20,20,20,20]" shape_dict = tvmc.common.parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} # Check that negative dimensions parse to Any correctly. - shape_string = "input:-1x3x224x224" + shape_string = "input:[-1,3,224,224]" shape_dict = tvmc.common.parse_shape_string(shape_string) # Convert to strings to allow comparison with Any. assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" # Check that invalid pattern raises expected error. - shape_string = "input:ax10" + shape_string = "input:[a,10]" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) # Check that input with invalid separators raises error. diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index a08da4dfeab6..4cb342c2e967 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -54,12 +54,12 @@ def verify_compile_tflite_module(model, shape_dict=None): def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer flute, so skip in case it is not present + # some CI environments wont offer tflite, so skip in case it is not present pytest.importorskip("tflite") # Check default compilation. verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) # Check with manual shape override - shape_string = "input:1x224x224x3" + shape_string = "input:[1,224,224,3]" shape_dict = tvmc.common.parse_shape_string(shape_string) verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) @@ -142,7 +142,7 @@ def test_compile_onnx_module(onnx_resnet50): # Test default compilation verify_compile_onnx_module(onnx_resnet50) # Test with manual shape dict - shape_string = "data:1x3x200x200" + shape_string = "data:[1,3,200,200]" shape_dict = tvmc.common.parse_shape_string(shape_string) verify_compile_onnx_module(onnx_resnet50, shape_dict) From 636dacfe8f3bd3bdc38305727bb26d08742f8251 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Tue, 2 Feb 2021 15:45:55 -0500 Subject: [PATCH 08/13] add commentary --- python/tvm/driver/tvmc/common.py | 8 ++++---- python/tvm/driver/tvmc/frontends.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index f00350cc9f1c..52b3ba3ea9a9 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -146,7 +146,7 @@ def parse_shape_string(inputs_string): Parameters ---------- inputs_string: str - A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates + A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that indicates the desired shape for specific model inputs. Returns @@ -161,16 +161,16 @@ def parse_shape_string(inputs_string): if not input_mappings: raise argparse.ArgumentTypeError( "--input-shapes argument must be of the form " - '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' + "\"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]\"" ) shape_dict = {} for mapping in input_mappings: # Remove whitespace. mapping = mapping.replace(" ", "") # Split mapping into name and shape. - name, shape_string = mapping.split(":") + name, shape_string = mapping.split(':') # Convert shape string into a list of integers or Anys if negative. - shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] + shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip('][').split(',')] # Add parsed mapping to shape dictionary. shape_dict[name] = shape diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 9a3ed13028fe..53fbed66c8fc 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -295,12 +295,12 @@ def load(self, path, shape_dict=None): # pylint: disable=C0415 import torch - traced_model = torch.jit.load(path) - traced_model.eval() # Switch to inference mode - if shape_dict is None: raise TVMCException("--input-shapes must be specified for %s" % self.name()) + traced_model = torch.jit.load(path) + traced_model.eval() # Switch to inference mode + # Convert shape dictionary to list for Pytorch frontend compatibility input_shapes = list(shape_dict.items()) From 2ff50ee4da742e5e724ac761bbe5ccf638270a73 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Tue, 2 Feb 2021 18:18:06 -0500 Subject: [PATCH 09/13] reformat common.py --- python/tvm/driver/tvmc/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 52b3ba3ea9a9..1f29298110f4 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -161,16 +161,16 @@ def parse_shape_string(inputs_string): if not input_mappings: raise argparse.ArgumentTypeError( "--input-shapes argument must be of the form " - "\"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]\"" + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' ) shape_dict = {} for mapping in input_mappings: # Remove whitespace. mapping = mapping.replace(" ", "") # Split mapping into name and shape. - name, shape_string = mapping.split(':') + name, shape_string = mapping.split(":") # Convert shape string into a list of integers or Anys if negative. - shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip('][').split(',')] + shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] # Add parsed mapping to shape dictionary. shape_dict[name] = shape From 2829753d08e1685f3832b6e93e73437f58fa871b Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Tue, 2 Feb 2021 19:13:19 -0500 Subject: [PATCH 10/13] fix lint issue --- python/tvm/driver/tvmc/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 1f29298110f4..b6b670618572 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -146,8 +146,8 @@ def parse_shape_string(inputs_string): Parameters ---------- inputs_string: str - A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that indicates - the desired shape for specific model inputs. + A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that + indicates the desired shape for specific model inputs. Returns ------- From d358593850c03984209403b92bbe28543ed4586a Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Wed, 3 Feb 2021 15:30:37 -0500 Subject: [PATCH 11/13] format common.py with black --- python/tvm/driver/tvmc/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index b6b670618572..1845915bcbd1 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -146,7 +146,7 @@ def parse_shape_string(inputs_string): Parameters ---------- inputs_string: str - A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that + A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that indicates the desired shape for specific model inputs. Returns From b218fce020671e9119e4b39aa8d54d3578f47b3b Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 5 Feb 2021 18:08:20 -0500 Subject: [PATCH 12/13] torch/pytorch test hiccup --- tests/python/driver/tvmc/test_frontends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 54f6f0cd50db..04c85b1eb8f3 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -177,6 +177,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): def test_load_model__pth(pytorch_resnet18): # some CI environments wont offer torch, so skip in case it is not present pytest.importorskip("torch") + pytest.importorskip("torchvision") mod, params = tvmc.frontends.load_model( pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]} From ac2dbd6dfd29a81c44e6db081b535d10e7036b29 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 8 Feb 2021 19:06:35 -0500 Subject: [PATCH 13/13] add -s to setup-pytest-env.sh for clearer error msgs --- tests/scripts/setup-pytest-env.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh index b77d3f37cd3e..5f108e9355fc 100755 --- a/tests/scripts/setup-pytest-env.sh +++ b/tests/scripts/setup-pytest-env.sh @@ -20,9 +20,9 @@ set +u if [[ ! -z $CI_PYTEST_ADD_OPTIONS ]]; then - export PYTEST_ADDOPTS="-v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" else - export PYTEST_ADDOPTS="-v $PYTEST_ADDOPTS" + export PYTEST_ADDOPTS="-s -v $PYTEST_ADDOPTS" fi set -u