From a081d8155c9871be10e13f8a98fa70462de24aec Mon Sep 17 00:00:00 2001 From: dlexplorer Date: Tue, 8 Dec 2020 15:20:30 +0300 Subject: [PATCH] Add default shape initialization if no explicit ones from user If no explicit shapes are passed into relay.frontend.from_onnx, we can use shapes defined in the onnx model itself --- python/tvm/driver/tvmc/frontends.py | 11 +---------- python/tvm/relay/frontend/onnx.py | 6 ++++++ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 976159d0ff5b..bb54b82cceca 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -161,16 +161,7 @@ def load(self, path): # pylint: disable=E1101 model = onnx.load(path) - # pylint: disable=E1101 - name = model.graph.input[0].name - - # pylint: disable=E1101 - proto_shape = model.graph.input[0].type.tensor_type.shape.dim - shape = [d.dim_value for d in proto_shape] - - shape_dict = {name: shape} - - return relay.frontend.from_onnx(model, shape_dict) + return relay.frontend.from_onnx(model) class TensorflowFrontend(Frontend): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d65f5676fb33..377b739650ca 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2737,6 +2737,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals warnings.warn(str(e)) except ImportError: pass + + # if no explicit input's shape came from user, then initialize shape as it is defined in onnx model + if shape is None: + shape = {} + for i in model.graph.input: + shape[i.name] = [dim.dim_value for dim in i.type.tensor_type.shape.dim] g = GraphProto(shape, dtype) graph = model.graph if opset is None: