-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TVMC] Fix PyTorch support #7359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,8 @@ | |
| """ | ||
| import logging | ||
| import os.path | ||
| import re | ||
| import argparse | ||
|
|
||
| from urllib.parse import urlparse | ||
|
|
||
|
|
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str): | |
| logger.info("RPC tracker port: %s", rpc_port) | ||
|
|
||
| return rpc_hostname, rpc_port | ||
|
|
||
|
|
||
| def parse_input_shapes(xs): | ||
| """Turn the string from --input-shape into a list. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| xs : str | ||
| The input shapes, in a form "(1,2,3),(1,4),..." | ||
|
|
||
| Returns | ||
| ------- | ||
| shapes : list | ||
| Input shapes as a list of lists | ||
| """ | ||
|
|
||
| shapes = [] | ||
| # Split up string into comma separated sections ignoring commas in ()s | ||
| match = re.findall(r"(\(.*?\)|.+?),?", xs) | ||
| if match: | ||
| for inp in match: | ||
| # Test for and remove brackets | ||
| shape = re.match(r"\((.*)\)", inp) | ||
| if shape and shape.lastindex == 1: | ||
| # Remove white space and extract numbers | ||
| strshape = shape[1].replace(" ", "").split(",") | ||
|
Comment on lines
+165
to
+166
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be safer and easier to remove all spaces in |
||
| try: | ||
| shapes.append([int(i) for i in strshape]) | ||
| except ValueError: | ||
| raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider the following two input shapes:
Accordingly, I guess your intention is
Either way is fine for me, and please update the help message and make sure you have a unit test to cover corner cases. |
||
| else: | ||
| raise argparse.ArgumentTypeError( | ||
| f"missing brackets around shape '{inp}', example '(1,2,3)'" | ||
| ) | ||
| else: | ||
| raise argparse.ArgumentTypeError(f"unrecognized shapes '{xs}', example '(1,2,3),(1,4),...'") | ||
| return shapes | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,12 @@ def add_compile_parser(subparsers): | |
| default="", | ||
| help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ", | ||
| ) | ||
| parser.add_argument( | ||
| "--input-shape", | ||
| type=common.parse_input_shapes, | ||
| metavar="INPUT_SHAPE,[INPUT_SHAPE]...", | ||
| help="for PyTorch, e.g. '(1,3,224,224)'", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe clarify that it is in fact mandatory for PyTorch.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agree. It's confusing to see such a general option only for PyTorch. I would suggest the following changes:
|
||
| ) | ||
| parser.add_argument( | ||
| "--model-format", | ||
| choices=frontends.get_frontend_names(), | ||
|
|
@@ -108,6 +114,7 @@ def drive_compile(args): | |
| args.FILE, | ||
| args.target, | ||
| args.dump_code, | ||
| args.input_shape, | ||
| None, | ||
| args.model_format, | ||
| args.tuning_records, | ||
|
|
@@ -125,6 +132,7 @@ def compile_model( | |
| path, | ||
| target, | ||
| dump_code=None, | ||
| input_shape=None, | ||
| target_host=None, | ||
| model_format=None, | ||
| tuning_records=None, | ||
|
|
@@ -146,6 +154,8 @@ def compile_model( | |
| dump_code : list, optional | ||
| Dump the generated code for the specified source types, on | ||
| the requested target. | ||
| input_shape : list, optional | ||
| Shape of the input tensor for PyTorch models | ||
| target_host : str, optional | ||
| The target of the host machine if host-side code | ||
| needs to be generated. | ||
|
|
@@ -172,7 +182,7 @@ def compile_model( | |
|
|
||
| """ | ||
| dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None | ||
| mod, params = frontends.load_model(path, model_format) | ||
| mod, params = frontends.load_model(path, model_format, input_shape) | ||
|
|
||
| if alter_layout: | ||
| mod = common.convert_graph_layout(mod, alter_layout) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,13 +54,15 @@ def suffixes(): | |
| """File suffixes (extensions) used by this frontend""" | ||
|
|
||
| @abstractmethod | ||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| """Load a model from a given path. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| path: str | ||
| Path to a file | ||
| input_shape: list | ||
| Shape of the input tensor | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -99,10 +101,13 @@ def name(): | |
| def suffixes(): | ||
| return ["h5"] | ||
|
|
||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| # pylint: disable=C0103 | ||
| tf, keras = import_keras() | ||
|
|
||
| if input_shape: | ||
| raise TVMCException("--input-shape is not supported for {}".format(self.name())) | ||
|
|
||
|
Comment on lines
+108
to
+110
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is definitely too ad hoc |
||
| # tvm build currently imports keras directly instead of tensorflow.keras | ||
| try: | ||
| model = keras.models.load_model(path) | ||
|
|
@@ -154,10 +159,13 @@ def name(): | |
| def suffixes(): | ||
| return ["onnx"] | ||
|
|
||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| # pylint: disable=C0415 | ||
| import onnx | ||
|
|
||
| if input_shape: | ||
| raise TVMCException("--input-shape is not supported for {}".format(self.name())) | ||
|
|
||
| # pylint: disable=E1101 | ||
| model = onnx.load(path) | ||
|
|
||
|
|
@@ -175,11 +183,14 @@ def name(): | |
| def suffixes(): | ||
| return ["pb"] | ||
|
|
||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| # pylint: disable=C0415 | ||
| import tensorflow as tf | ||
| import tvm.relay.testing.tf as tf_testing | ||
|
|
||
| if input_shape: | ||
| raise TVMCException("--input-shape is not supported for {}".format(self.name())) | ||
|
|
||
| with tf.io.gfile.GFile(path, "rb") as tf_graph: | ||
| content = tf_graph.read() | ||
|
|
||
|
|
@@ -215,10 +226,13 @@ def name(): | |
| def suffixes(): | ||
| return ["tflite"] | ||
|
|
||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| # pylint: disable=C0415 | ||
| import tflite.Model as model | ||
|
|
||
| if input_shape: | ||
| raise TVMCException("--input-shape is not supported for {}".format(self.name())) | ||
|
|
||
| with open(path, "rb") as tf_graph: | ||
| content = tf_graph.read() | ||
|
|
||
|
|
@@ -285,17 +299,17 @@ def suffixes(): | |
| # Torch Script is a zip file, but can be named pth | ||
| return ["pth", "zip"] | ||
|
|
||
| def load(self, path): | ||
| def load(self, path, input_shape): | ||
| # pylint: disable=C0415 | ||
| import torch | ||
|
|
||
| traced_model = torch.jit.load(path) | ||
|
|
||
| inputs = list(traced_model.graph.inputs())[1:] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this approach not working at all? If it works for some cases, we should still use it first when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I looked into this and I didn't find a way to extract inputs from the model after it has been saved and loaded. I asked on the PyTorch forum as well (https://discuss.pytorch.org/t/input-size-disappears-between-torch-jit-save-and-torch-jit-load/108955) and since I received a grand total of zero responses, I suspect it is a deliberate design decision. If there was a way, it would be good to keep it, of course, but in that form it doesn't work any more. |
||
| input_shapes = [inp.type().sizes() for inp in inputs] | ||
| if not input_shape: | ||
| raise TVMCException("--input-shape must be specified for {}".format(self.name())) | ||
|
|
||
| traced_model = torch.jit.load(path) | ||
| traced_model.eval() # Switch to inference mode | ||
| input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] | ||
|
|
||
| input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(input_shape)] | ||
|
|
||
| logger.debug("parse Torch model and convert into Relay computation graph") | ||
| return relay.frontend.from_pytorch(traced_model, input_shapes) | ||
|
|
@@ -378,7 +392,7 @@ def guess_frontend(path): | |
| raise TVMCException("failed to infer the model format. Please specify --model-format") | ||
|
|
||
|
|
||
| def load_model(path, model_format=None): | ||
| def load_model(path, model_format=None, input_shape=None): | ||
| """Load a model from a supported framework and convert it | ||
| into an equivalent relay representation. | ||
|
|
||
|
|
@@ -389,6 +403,8 @@ def load_model(path, model_format=None): | |
| model_format : str, optional | ||
| The underlying framework used to create the model. | ||
| If not specified, this will be inferred from the file type. | ||
| input_shape : list, optional | ||
| The shape of input tensor for PyTorch models | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto. Make it general instead of only for PyTorch. |
||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -404,6 +420,6 @@ def load_model(path, model_format=None): | |
| else: | ||
| frontend = guess_frontend(path) | ||
|
|
||
| mod, params = frontend.load(path) | ||
| mod, params = frontend.load(path, input_shape) | ||
|
|
||
| return mod, params | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to have an example here, that describes the input format and expected output format, similar to what you have on
test_parse_input_shapes__turn_into_list.