diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 976159d0ff5b..bb54b82cceca 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -161,16 +161,7 @@ def load(self, path): # pylint: disable=E1101 model = onnx.load(path) - # pylint: disable=E1101 - name = model.graph.input[0].name - - # pylint: disable=E1101 - proto_shape = model.graph.input[0].type.tensor_type.shape.dim - shape = [d.dim_value for d in proto_shape] - - shape_dict = {name: shape} - - return relay.frontend.from_onnx(model, shape_dict) + return relay.frontend.from_onnx(model) class TensorflowFrontend(Frontend): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d65f5676fb33..377b739650ca 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2737,6 +2737,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals warnings.warn(str(e)) except ImportError: pass + + # if no explicit input's shape came from user, then initialize shape as it is defined in onnx model + if shape is None: + shape = {} + for i in model.graph.input: + shape[i.name] = [dim.dim_value for dim in i.type.tensor_type.shape.dim] g = GraphProto(shape, dtype) graph = model.graph if opset is None: