From 2b30b176e16647a3f4e32a43f4afea472b175c62 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 7 Oct 2022 19:21:29 +0000 Subject: [PATCH] replace `:` in input name --- python/tvm/relay/frontend/onnx.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 84a5fc3b8237..a5eba9421f01 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -5629,6 +5629,7 @@ def _parse_graph_initializers(self, graph): def _parse_graph_input(self, graph): for i in graph.input: + i_name_compatible = None # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' i_name, i_shape, d_type, i_shape_name = get_info(i) @@ -5642,9 +5643,11 @@ def _parse_graph_input(self, graph): continue else: self._num_input += 1 - self._input_names.append(i_name) - if i_name in self._shape: - i_shape = self._shape[i_name] + # cleanup input name by replacing `:` in name with `_` + i_name_compatible = i_name.replace(":", "_") + self._input_names.append(i_name_compatible) + if i_name_compatible in self._shape: + i_shape = self._shape[i_name_compatible] else: if "?" in str(i_shape): warning_msg = ( @@ -5654,10 +5657,20 @@ def _parse_graph_input(self, graph): ) warnings.warn(warning_msg) if isinstance(self._dtype, dict): - dtype = self._dtype[i_name] if i_name in self._dtype else d_type + dtype = ( + self._dtype[i_name_compatible] + if i_name_compatible in self._dtype + else d_type + ) else: dtype = d_type - self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) + self._nodes[i_name_compatible] = new_var( + i_name_compatible, shape=i_shape, dtype=dtype + ) + + if i_name_compatible: + self._renames[i_name] = i_name_compatible + i_name = i_name_compatible self._inputs[i_name] = self._nodes[i_name] def _check_user_inputs_in_outermost_graph_scope(self):