diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 84a5fc3b8237..a5eba9421f01 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -5629,6 +5629,7 @@ def _parse_graph_initializers(self, graph): def _parse_graph_input(self, graph): for i in graph.input: + i_name_compatible = None # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' i_name, i_shape, d_type, i_shape_name = get_info(i) @@ -5642,9 +5643,11 @@ def _parse_graph_input(self, graph): continue else: self._num_input += 1 - self._input_names.append(i_name) - if i_name in self._shape: - i_shape = self._shape[i_name] + # cleanup input name by replacing `:` in name with `_` + i_name_compatible = i_name.replace(":", "_") + self._input_names.append(i_name_compatible) + if i_name_compatible in self._shape: + i_shape = self._shape[i_name_compatible] else: if "?" in str(i_shape): warning_msg = ( @@ -5654,10 +5657,20 @@ def _parse_graph_input(self, graph): ) warnings.warn(warning_msg) if isinstance(self._dtype, dict): - dtype = self._dtype[i_name] if i_name in self._dtype else d_type + dtype = ( + self._dtype[i_name_compatible] + if i_name_compatible in self._dtype + else d_type + ) else: dtype = d_type - self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) + self._nodes[i_name_compatible] = new_var( + i_name_compatible, shape=i_shape, dtype=dtype + ) + + if i_name_compatible: + self._renames[i_name] = i_name_compatible + i_name = i_name_compatible self._inputs[i_name] = self._nodes[i_name] def _check_user_inputs_in_outermost_graph_scope(self):