From fd4b88418aca7af71d5ae3666d2434d9f01e7365 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 1 Apr 2021 00:43:49 +0000 Subject: [PATCH 1/3] Remove popping that interferes with nested loops. --- python/tvm/relay/frontend/onnx.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 624a61efee27..8222e2c968af 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2981,6 +2981,7 @@ def __init__(self, shape, dtype, freeze_params=False): self._num_input = 0 self._num_param = 0 self._shape = shape if shape else {} + self._input_names = [] self._dtype = dtype self.opset = None self._freeze_params = freeze_params @@ -3062,8 +3063,9 @@ def from_onnx(self, graph, opset, get_output_expr=False): continue else: self._num_input += 1 + self._input_names.append(i_name) if i_name in self._shape: - i_shape = self._shape.pop(i_name) + i_shape = self._shape[i_name] else: if "?" in str(i_shape): warning_msg = ( @@ -3078,8 +3080,8 @@ def from_onnx(self, graph, opset, get_output_expr=False): dtype = d_type self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name] - assert ( - len(self._shape) == 0 + assert all( + [name in self._input_names for name in self._shape.keys()] ), "User specified the shape for inputs that weren't found in the graph: " + str( self._shape ) From 289b6ca3ebc17612d762bb38c8f96bd52015c43b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 1 Apr 2021 17:14:10 +0000 Subject: [PATCH 2/3] Only check user inputs in the outer-most graph scope. --- python/tvm/relay/frontend/onnx.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8222e2c968af..307fda1c7056 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3080,11 +3080,13 @@ def from_onnx(self, graph, opset, get_output_expr=False): dtype = d_type self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name] - assert all( - [name in self._input_names for name in self._shape.keys()] - ), "User specified the shape for inputs that weren't found in the graph: " + str( - self._shape - ) + # Only check user inputs in the outer-most graph scope. + if self._old_manager == None: + assert all( + [name in self._input_names for name in self._shape.keys()] + ), "User specified the shape for inputs that weren't found in the graph: " + str( + self._shape + ) # get list of unsupported ops convert_map = _get_convert_map(opset) unsupported_ops = set() From a765b70943743221ba673be7bbcb4ba5ee2b3949 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 1 Apr 2021 19:25:41 +0000 Subject: [PATCH 3/3] Fix style. --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 307fda1c7056..669eab8cc250 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3081,7 +3081,7 @@ def from_onnx(self, graph, opset, get_output_expr=False): self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name] # Only check user inputs in the outer-most graph scope. - if self._old_manager == None: + if self._old_manager is None: assert all( [name in self._input_names for name in self._shape.keys()] ), "User specified the shape for inputs that weren't found in the graph: " + str(