Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def add_tune_parser(subparsers):
type=int,
help="minimum number of trials before early stopping",
)

parser.add_argument(
"--input-shape",
type=common.parse_input_shapes,
metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
help="for PyTorch, e.g. '(1,3,224,224)'",
)
# There is some extra processing required to define the actual default value
# for --min-repeat-ms. This is done in `drive_tune`.
parser.add_argument(
Expand Down Expand Up @@ -235,7 +240,7 @@ def drive_tune(args):
)

target = common.target_from_cli(args.target)
mod, params = frontends.load_model(args.FILE, args.model_format)
mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shape)

# min_repeat_ms should be:
# a. the value provided by the user, if any, or
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"""
import logging
import os.path
import re
import argparse

from urllib.parse import urlparse

Expand Down Expand Up @@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
logger.info("RPC tracker port: %s", rpc_port)

return rpc_hostname, rpc_port


def parse_input_shapes(xs):
    """Turn the string from --input-shape into a list of shapes.

    Each shape is a bracketed, comma separated list of integers and
    several shapes are separated by commas. Whitespace anywhere in the
    string is ignored, and a single trailing comma inside a shape is
    accepted, so both "(8)" and the Python tuple spelling "(8,)" describe
    a one-dimensional shape.

    Parameters
    ----------
    xs : str
        The input shapes, in a form "(1,2,3),(1,4),..."

    Returns
    -------
    shapes : list
        Input shapes as a list of lists of int

    Raises
    ------
    argparse.ArgumentTypeError
        If no shape can be found in the string, if a shape is not wrapped
        in round brackets, or if a shape entry is not an integer.

    Examples
    --------
    >>> parse_input_shapes("(1, 3, 224, 224), (32,)")
    [[1, 3, 224, 224], [32]]
    """
    # Drop every whitespace character up front so that spacing between
    # shapes ("(1,2), (3,4)") cannot end up as a stray token below.
    compact = "".join(xs.split())

    # Split up string into comma separated sections ignoring commas in ()s.
    # Anything outside of brackets is returned one character at a time and
    # is rejected by the bracket test in the loop.
    sections = re.findall(r"(\(.*?\)|.+?),?", compact)
    if not sections:
        raise argparse.ArgumentTypeError(f"unrecognized shapes '{xs}', example '(1,2,3),(1,4),...'")

    shapes = []
    for inp in sections:
        # Test for and remove brackets
        bracketed = re.match(r"\((.*)\)", inp)
        if bracketed is None:
            raise argparse.ArgumentTypeError(
                f"missing brackets around shape '{inp}', example '(1,2,3)'"
            )

        inner = bracketed[1]
        entries = inner.split(",")
        # Accept the one-element tuple spelling "(8,)": a single trailing
        # comma leaves one empty entry at the end, which is dropped. An
        # entirely empty shape "()" keeps its empty entry and fails below.
        if len(entries) > 1 and entries[-1] == "":
            entries.pop()

        try:
            shapes.append([int(entry) for entry in entries])
        except ValueError as err:
            raise argparse.ArgumentTypeError(f"expected numbers in shape '{inner}'") from err

    return shapes
12 changes: 11 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def add_compile_parser(subparsers):
default="",
help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ",
)
parser.add_argument(
"--input-shape",
type=common.parse_input_shapes,
metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
help="for PyTorch, e.g. '(1,3,224,224)'",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe clarify that it is in fact mandatory for PyTorch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. It's confusing to see such a general option only for PyTorch. I would suggest the following changes:

  1. Make --input-shape as a general option for all frontends. If present, we skip the input shape inference.
  2. --input-shape is optional by default. However, if users want to process a PyTorch model but don't specify --input-shape, we throw out an error in the PyTorch frontend.

)
parser.add_argument(
"--model-format",
choices=frontends.get_frontend_names(),
Expand Down Expand Up @@ -108,6 +114,7 @@ def drive_compile(args):
args.FILE,
args.target,
args.dump_code,
args.input_shape,
None,
args.model_format,
args.tuning_records,
Expand All @@ -125,6 +132,7 @@ def compile_model(
path,
target,
dump_code=None,
input_shape=None,
target_host=None,
model_format=None,
tuning_records=None,
Expand All @@ -146,6 +154,8 @@ def compile_model(
dump_code : list, optional
Dump the generated code for the specified source types, on
the requested target.
input_shape : list, optional
Shape of the input tensor for PyTorch models
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
Expand All @@ -172,7 +182,7 @@ def compile_model(

"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format)
mod, params = frontends.load_model(path, model_format, input_shape)

if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)
Expand Down
42 changes: 29 additions & 13 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def suffixes():
"""File suffixes (extensions) used by this frontend"""

@abstractmethod
def load(self, path):
def load(self, path, input_shape):
"""Load a model from a given path.

Parameters
----------
path: str
Path to a file
input_shape: list
Shape of the input tensor

Returns
-------
Expand Down Expand Up @@ -99,10 +101,13 @@ def name():
def suffixes():
return ["h5"]

def load(self, path):
def load(self, path, input_shape):
# pylint: disable=C0103
tf, keras = import_keras()

if input_shape:
raise TVMCException("--input-shape is not supported for {}".format(self.name()))

Comment on lines +108 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely too ad hoc

# tvm build currently imports keras directly instead of tensorflow.keras
try:
model = keras.models.load_model(path)
Expand Down Expand Up @@ -154,10 +159,13 @@ def name():
def suffixes():
return ["onnx"]

def load(self, path):
def load(self, path, input_shape):
# pylint: disable=C0415
import onnx

if input_shape:
raise TVMCException("--input-shape is not supported for {}".format(self.name()))

# pylint: disable=E1101
model = onnx.load(path)

Expand All @@ -175,11 +183,14 @@ def name():
def suffixes():
return ["pb"]

def load(self, path):
def load(self, path, input_shape):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing

if input_shape:
raise TVMCException("--input-shape is not supported for {}".format(self.name()))

with tf.io.gfile.GFile(path, "rb") as tf_graph:
content = tf_graph.read()

Expand Down Expand Up @@ -215,10 +226,13 @@ def name():
def suffixes():
return ["tflite"]

def load(self, path):
def load(self, path, input_shape):
# pylint: disable=C0415
import tflite.Model as model

if input_shape:
raise TVMCException("--input-shape is not supported for {}".format(self.name()))

with open(path, "rb") as tf_graph:
content = tf_graph.read()

Expand Down Expand Up @@ -285,17 +299,17 @@ def suffixes():
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]

def load(self, path):
def load(self, path, input_shape):
# pylint: disable=C0415
import torch

traced_model = torch.jit.load(path)

inputs = list(traced_model.graph.inputs())[1:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this approach not working at all? If it works for some cases, we should still use it first when --input-shape is missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked into this and I didn't find a way to extract inputs from the model after it has been saved and loaded. I asked on the PyTorch forum as well (https://discuss.pytorch.org/t/input-size-disappears-between-torch-jit-save-and-torch-jit-load/108955) and since I received a grand total of zero responses, I suspect it is a deliberate design decision. If there was a way, it would be good to keep it, of course, but in that form it doesn't work any more.

input_shapes = [inp.type().sizes() for inp in inputs]
if not input_shape:
raise TVMCException("--input-shape must be specified for {}".format(self.name()))

traced_model = torch.jit.load(path)
traced_model.eval() # Switch to inference mode
input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)]

input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(input_shape)]

logger.debug("parse Torch model and convert into Relay computation graph")
return relay.frontend.from_pytorch(traced_model, input_shapes)
Expand Down Expand Up @@ -378,7 +392,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")


def load_model(path, model_format=None):
def load_model(path, model_format=None, input_shape=None):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.

Expand All @@ -389,6 +403,8 @@ def load_model(path, model_format=None):
model_format : str, optional
The underlying framework used to create the model.
If not specified, this will be inferred from the file type.
input_shape : list, optional
The shape of input tensor for PyTorch models
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. make it general instead of only for PyTorch.


Returns
-------
Expand All @@ -404,6 +420,6 @@ def load_model(path, model_format=None):
else:
frontend = guess_frontend(path)

mod, params = frontend.load(path)
mod, params = frontend.load(path, input_shape)

return mod, params
28 changes: 28 additions & 0 deletions tests/python/driver/tvmc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,31 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():

assert expected_host == actual_host
assert expected_port == actual_port


def test_parse_input_shapes__no_numbers():
    """A bracketed shape with non-numeric entries raises ArgumentTypeError."""
    with pytest.raises(argparse.ArgumentTypeError):
        tvmc.common.parse_input_shapes("(a,b,c,d)")


def test_parse_input_shapes__no_brackets():
    """A shape given without surrounding brackets raises ArgumentTypeError."""
    with pytest.raises(argparse.ArgumentTypeError):
        tvmc.common.parse_input_shapes("1, 3, 224, 224")


def test_parse_input_shapes__bad_input():
    """A shape wrapped in square brackets raises ArgumentTypeError."""
    with pytest.raises(argparse.ArgumentTypeError):
        tvmc.common.parse_input_shapes("[1, 3, 224, 224]")


def test_parse_input_shapes__turn_into_list():
    """Two bracketed shapes are returned as a list of integer lists."""
    parsed = tvmc.common.parse_input_shapes("(1, 3, 224, 224),(1,4)")

    assert parsed == [[1, 3, 224, 224], [1, 4]]
22 changes: 21 additions & 1 deletion tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,24 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
pytest.importorskip("torch")

with pytest.raises(RuntimeError) as e:
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch")
tvmc.frontends.load_model(
tflite_mobilenet_v1_1_quant, model_format="pytorch", input_shape=[[1, 3, 224, 224]]
)


def test_load_model__pytorch__no_inputs():
    """Loading with the pytorch format and no input shape raises TVMCException."""
    # torch is not installed in every CI environment; skip when it is absent
    pytest.importorskip("torch")

    load_kwargs = {"model_format": "pytorch", "input_shape": None}
    with pytest.raises(TVMCException):
        tvmc.frontends.load_model("a_model.pth", **load_kwargs)


def test_load_model__tflite__with_inputs():
    """Loading with the tflite format and an input shape raises TVMCException."""
    # tflite is not installed in every CI environment; skip when it is absent
    pytest.importorskip("tflite")

    load_kwargs = {"model_format": "tflite", "input_shape": [[1, 3, 224, 224]]}
    with pytest.raises(TVMCException):
        tvmc.frontends.load_model("a_model.tflite", **load_kwargs)