From 31cbf323ddb8a39a021c19c6910de8df2bfd0fb6 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 26 Jan 2021 15:46:38 +0000 Subject: [PATCH] [TVMC] Fix PyTorch support A PyTorch model could not be compiled through tvmc because the shape of the input tensor could not be deduced from the model after it has been saved. We've added an --input-shape parameter to tvmc compile and tvmc tune that allows the inputs to be specified for PyTorch models. --- python/tvm/driver/tvmc/autotuner.py | 9 +++-- python/tvm/driver/tvmc/common.py | 39 ++++++++++++++++++++ python/tvm/driver/tvmc/compiler.py | 12 ++++++- python/tvm/driver/tvmc/frontends.py | 42 +++++++++++++++------- tests/python/driver/tvmc/test_common.py | 28 +++++++++++++++ tests/python/driver/tvmc/test_frontends.py | 22 +++++++++++- 6 files changed, 135 insertions(+), 17 deletions(-) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 71ccc8546e8b..2a47e41eca04 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -49,7 +49,12 @@ def add_tune_parser(subparsers): type=int, help="minimum number of trials before early stopping", ) - + parser.add_argument( + "--input-shape", + type=common.parse_input_shapes, + metavar="INPUT_SHAPE,[INPUT_SHAPE]...", + help="for PyTorch, e.g. '(1,3,224,224)'", + ) # There is some extra processing required to define the actual default value # for --min-repeat-ms. This is done in `drive_tune`. parser.add_argument( @@ -235,7 +240,7 @@ def drive_tune(args): ) target = common.target_from_cli(args.target) - mod, params = frontends.load_model(args.FILE, args.model_format) + mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shape) # min_repeat_ms should be: # a. 
the value provided by the user, if any, or diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9db22f3f3390..7e357d68f703 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -19,6 +19,8 @@ """ import logging import os.path +import re +import argparse from urllib.parse import urlparse @@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str): logger.info("RPC tracker port: %s", rpc_port) return rpc_hostname, rpc_port + + +def parse_input_shapes(xs): + """Turn the string from --input-shape into a list. + + Parameters + ---------- + xs : str + The input shapes, in a form "(1,2,3),(1,4),..." + + Returns + ------- + shapes : list + Input shapes as a list of lists + """ + + shapes = [] + # Split up string into comma separated sections ignoring commas in ()s + match = re.findall(r"(\(.*?\)|.+?),?", xs) + if match: + for inp in match: + # Test for and remove brackets + shape = re.match(r"\((.*)\)", inp) + if shape and shape.lastindex == 1: + # Remove white space and extract numbers + strshape = shape[1].replace(" ", "").split(",") + try: + shapes.append([int(i) for i in strshape]) + except ValueError: + raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'") + else: + raise argparse.ArgumentTypeError( + f"missing brackets around shape '{inp}', example '(1,2,3)'" + ) + else: + raise argparse.ArgumentTypeError(f"unrecognized shapes '{xs}', example '(1,2,3),(1,4),...'") + return shapes diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 90b0aceaa17a..10f690194062 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -59,6 +59,12 @@ def add_compile_parser(subparsers): default="", help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ", ) + parser.add_argument( + "--input-shape", + type=common.parse_input_shapes, + metavar="INPUT_SHAPE,[INPUT_SHAPE]...", + help="for PyTorch, e.g. 
'(1,3,224,224)'", + ) parser.add_argument( "--model-format", choices=frontends.get_frontend_names(), @@ -108,6 +114,7 @@ def drive_compile(args): args.FILE, args.target, args.dump_code, + args.input_shape, None, args.model_format, args.tuning_records, @@ -125,6 +132,7 @@ def compile_model( path, target, dump_code=None, + input_shape=None, target_host=None, model_format=None, tuning_records=None, @@ -146,6 +154,8 @@ def compile_model( dump_code : list, optional Dump the generated code for the specified source types, on the requested target. + input_shape : list, optional + Shape of the input tensor for PyTorch models target_host : str, optional The target of the host machine if host-side code needs to be generated. @@ -172,7 +182,7 @@ def compile_model( """ dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None - mod, params = frontends.load_model(path, model_format) + mod, params = frontends.load_model(path, model_format, input_shape) if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index bb54b82cceca..fc4682010a12 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -54,13 +54,15 @@ def suffixes(): """File suffixes (extensions) used by this frontend""" @abstractmethod - def load(self, path): + def load(self, path, input_shape): """Load a model from a given path. 
Parameters ---------- path: str Path to a file + input_shape: list + Shape of the input tensor Returns ------- @@ -99,10 +101,13 @@ def name(): def suffixes(): return ["h5"] - def load(self, path): + def load(self, path, input_shape): # pylint: disable=C0103 tf, keras = import_keras() + if input_shape: + raise TVMCException("--input-shape is not supported for {}".format(self.name())) + # tvm build currently imports keras directly instead of tensorflow.keras try: model = keras.models.load_model(path) @@ -154,10 +159,13 @@ def name(): def suffixes(): return ["onnx"] - def load(self, path): + def load(self, path, input_shape): # pylint: disable=C0415 import onnx + if input_shape: + raise TVMCException("--input-shape is not supported for {}".format(self.name())) + # pylint: disable=E1101 model = onnx.load(path) @@ -175,11 +183,14 @@ def name(): def suffixes(): return ["pb"] - def load(self, path): + def load(self, path, input_shape): # pylint: disable=C0415 import tensorflow as tf import tvm.relay.testing.tf as tf_testing + if input_shape: + raise TVMCException("--input-shape is not supported for {}".format(self.name())) + with tf.io.gfile.GFile(path, "rb") as tf_graph: content = tf_graph.read() @@ -215,10 +226,13 @@ def name(): def suffixes(): return ["tflite"] - def load(self, path): + def load(self, path, input_shape): # pylint: disable=C0415 import tflite.Model as model + if input_shape: + raise TVMCException("--input-shape is not supported for {}".format(self.name())) + with open(path, "rb") as tf_graph: content = tf_graph.read() @@ -285,17 +299,17 @@ def suffixes(): # Torch Script is a zip file, but can be named pth return ["pth", "zip"] - def load(self, path): + def load(self, path, input_shape): # pylint: disable=C0415 import torch - traced_model = torch.jit.load(path) - - inputs = list(traced_model.graph.inputs())[1:] - input_shapes = [inp.type().sizes() for inp in inputs] + if not input_shape: + raise TVMCException("--input-shape must be specified for 
{}".format(self.name())) + traced_model = torch.jit.load(path) traced_model.eval() # Switch to inference mode - input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] + + input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(input_shape)] logger.debug("parse Torch model and convert into Relay computation graph") return relay.frontend.from_pytorch(traced_model, input_shapes) @@ -378,7 +392,7 @@ def guess_frontend(path): raise TVMCException("failed to infer the model format. Please specify --model-format") -def load_model(path, model_format=None): +def load_model(path, model_format=None, input_shape=None): """Load a model from a supported framework and convert it into an equivalent relay representation. @@ -389,6 +403,8 @@ model_format : str, optional The underlying framework used to create the model. If not specified, this will be inferred from the file type. + input_shape : list, optional + The shape of the input tensor for PyTorch models Returns ------- @@ -404,6 +420,6 @@ else: frontend = guess_frontend(path) - mod, params = frontend.load(path) + mod, params = frontend.load(path, input_shape) return mod, params diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index 5ffbc6fe37dd..5a68ffc00d1b 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -149,3 +149,31 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): assert expected_host == actual_host assert expected_port == actual_port + + +def test_parse_input_shapes__no_numbers(): input_str = "(a,b,c,d)" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_input_shapes(input_str) def test_parse_input_shapes__no_brackets(): input_str = "1, 3, 224, 224" with pytest.raises(argparse.ArgumentTypeError): 
tvmc.common.parse_input_shapes(input_str) def test_parse_input_shapes__bad_input(): input_str = "[1, 3, 224, 224]" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_input_shapes(input_str) def test_parse_input_shapes__turn_into_list(): input_str = "(1, 3, 224, 224),(1,4)" output_str = tvmc.common.parse_input_shapes(input_str) assert output_str == [[1, 3, 224, 224], [1, 4]] diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index d77a17addabf..920f1766c270 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -179,4 +179,24 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): pytest.importorskip("torch") with pytest.raises(RuntimeError) as e: - tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") + tvmc.frontends.load_model( + tflite_mobilenet_v1_1_quant, model_format="pytorch", input_shape=[[1, 3, 224, 224]] + ) + + +def test_load_model__pytorch__no_inputs(): + # some CI environments won't offer pytorch, so skip in case it is not present + pytest.importorskip("torch") + + with pytest.raises(TVMCException): + tvmc.frontends.load_model("a_model.pth", model_format="pytorch", input_shape=None) + + +def test_load_model__tflite__with_inputs(): + # some CI environments won't offer tflite, so skip in case it is not present + pytest.importorskip("tflite") + + with pytest.raises(TVMCException): + tvmc.frontends.load_model( + "a_model.tflite", model_format="tflite", input_shape=[[1, 3, 224, 224]] + )