diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index c0848bb1092c..9a302da72ae6 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1193,6 +1193,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList( \ tensor_value.tensor_shape)] + elif shape and node.name in shape: + # Give priority to user argument. + self._output_shapes[node.name] = [shape[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \