From f25afe4f9ec5401affc0ccc6e911ca2905eafad8 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Tue, 16 Feb 2021 15:18:29 -0700 Subject: [PATCH 01/17] WIP --- tests/python/frontend/onnx/test_forward.py | 43 +++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5a6216ac705d..a180a7f1983c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -49,7 +49,7 @@ def get_tvm_output_with_vm( if not isinstance(input_data, list): input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) - + print(shape_dict) mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) @@ -4089,6 +4089,47 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 0, 1, type="int32") verify_cumsum(data, 1, 1, 1, type="int32") +@tvm.testing.uses_gpu +def test_onnx_nodes(): + from onnx import numpy_helper + f = onnx.__file__ + import glob + tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) + print(len(tests)) + failures = 0 + for n, test in enumerate(tests): + print(n, test) + try: + onnx_model = onnx.load(test + "/model.onnx") + inputs = [] + outputs = [] + for dataset in glob.glob(test + "/*/"): + tensors = sorted(glob.glob(dataset + "/*.pb")) + for tensor in tensors: + new_tensor = onnx.TensorProto() + with open(tensor, 'rb') as f: + new_tensor.ParseFromString(f.read()) + if "input" in tensor: + inputs.append(numpy_helper.to_array(new_tensor)) + if "output" in tensor: + outputs.append(numpy_helper.to_array(new_tensor)) + tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) + if len(outputs) == 1: + tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + else: + for output, val in zip(outputs, tvm_val): + tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) + except Exception as e: + if "frontend ONNX" not in str(e): + print("--------------------") + print("Test Number", n) + print("Failure number", failures) + print(test) + print(inputs) + print(e) + failures += 1 + + raise def test_wrong_input(): node = helper.make_node( From bf37601166ccc08cf8208e11b449dca2c51dd103 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Wed, 17 Feb 2021 15:09:50 -0700 Subject: [PATCH 02/17] some fixes --- python/tvm/relay/frontend/onnx.py | 23 ++++++++++---- tests/python/frontend/onnx/test_forward.py | 37 ++++++++-------------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fab4ae889dd7..03dfdb4d0240 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -40,6 +40,8 @@ __all__ = ["from_onnx"] +class ONNXAttrError(Exception): + pass class onnx_input: """ Dual purpose list or dictionary access object.""" @@ -106,7 +108,8 @@ def get_type(elem_type): from onnx import TensorProto except ImportError as e: raise ImportError("Unable to import onnx which is required {}".format(e)) - return TensorProto.DataType.Name(elem_type).lower() + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + return str(TENSOR_TYPE_TO_NP_TYPE[elem_type]) def get_info(info_proto): @@ -157,7 +160,7 @@ def revert_caffe2_pad(pads): return pads -def get_pad_pair(input1d, kernel1d, stride1d): +def get_pad_pair(input1d, kernel1d, stride1d, mode): """infer pad size""" if input1d % stride1d == 0: pad = max(kernel1d - stride1d, 0) @@ -165,6 +168,8 @@ def get_pad_pair(input1d, kernel1d, stride1d): pad = max(kernel1d - (input1d % stride1d), 0) pad_before = pad // 2 pad_after = pad - pad_before + if "LOWER" in mode: + return [pad_after, pad_before] return [pad_before, pad_after] @@ -280,9 +285,9 @@ def _impl_v1(cls, inputs, attr, params): pad_tuple = [] for axis in range(len(input_shape) - 2): axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] + stride = attr.get("strides", [1] * ndim)[axis] kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) + pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"]) pad_tuple.append(pad) pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) attr["pads"] = pad_tuple @@ -1121,7 +1126,6 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): try: from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE - attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]]) except ImportError as e: raise ImportError("Unable to import onnx.mapping which is required {}".format(e)) @@ -1485,6 +1489,8 @@ class ArgMax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if "select_last_index" in attr: + raise ONNXAttrError("select_last_index not supported in ArgMax") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -1496,6 +1502,8 @@ class ArgMin(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if "select_last_index" in attr: + raise ONNXAttrError("select_last_index not supported in ArgMax") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -2128,7 +2136,8 @@ def _impl_v11(cls, inputs, attr, params): result = inputs[0] for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]): if i < len(inputs) - 1: - result = op(result, inputs[i + 1]) + if inputs[i + 1] is not None: + result = op(result, inputs[i + 1]) return result @@ -2958,6 +2967,8 @@ def from_onnx(self, graph, opset, get_output_expr=False): for i in node.input: if i != "": inputs[i] = self._nodes[self._renames.get(i, i)] + else: + inputs[i] = None i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a180a7f1983c..ea6ad752d259 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,12 +17,14 @@ import numpy as np import onnx from onnx import helper, TensorProto, mapping, numpy_helper +from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE import torch import torchvision import pytest import tvm.topi.testing import tvm from tvm import relay +from tvm.relay.frontend.onnx import ONNXAttrError from tvm.contrib import graph_runtime import scipy import tvm.testing @@ -49,14 +51,11 @@ def get_tvm_output_with_vm( if not isinstance(input_data, list): input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) - print(shape_dict) mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) - if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) - ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data, **params) if isinstance(result, tvm.runtime.NDArray): @@ -3479,13 +3478,7 @@ def verify_topk(input_dims, K, axis=-1): @tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align( - input_dims, - num_roi, - output_height, - output_width, - sampling_ratio=0, - spatial_scale=1.0, - mode="avg", + input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0 ): output_dims = [num_roi, input_dims[1], output_height, output_width] @@ -3493,7 +3486,7 @@ def verify_roi_align( "RoiAlign", inputs=["X", "rois", "batch_indicies"], outputs=["Y"], - mode=mode, + mode="avg", output_height=output_height, output_width=output_width, sampling_ratio=sampling_ratio, @@ -3538,8 +3531,6 @@ def verify_roi_align( verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) - # ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here. - # @tvm.testing.uses_gpu def test_non_max_suppression(): @@ -4095,10 +4086,12 @@ def test_onnx_nodes(): f = onnx.__file__ import glob tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) - print(len(tests)) failures = 0 for n, test in enumerate(tests): print(n, test) + if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: + print("FAILURE: SKIPPING due to segfault") + continue try: onnx_model = onnx.load(test + "/model.onnx") inputs = [] @@ -4111,8 +4104,11 @@ def test_onnx_nodes(): new_tensor.ParseFromString(f.read()) if "input" in tensor: inputs.append(numpy_helper.to_array(new_tensor)) - if "output" in tensor: + elif "output" in tensor: outputs.append(numpy_helper.to_array(new_tensor)) + else: + print(tensor) + raise tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) if len(outputs) == 1: tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) @@ -4120,15 +4116,8 @@ def test_onnx_nodes(): for output, val in zip(outputs, tvm_val): tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) except Exception as e: - if "frontend ONNX" not in str(e): - print("--------------------") - print("Test Number", n) - print("Failure number", failures) - print(test) - print(inputs) - print(e) - failures += 1 - + print("------------------TEST FAILURE--------------------") + print(e) raise def test_wrong_input(): From 9f3dad3a05031be3062b9127efd970fb419752bd Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 18 Feb 2021 14:13:32 -0700 Subject: [PATCH 03/17] more fixes --- python/tvm/relay/frontend/onnx.py | 14 ++++++-------- tests/python/frontend/onnx/test_forward.py | 9 ++++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 03dfdb4d0240..f74cef53134f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -533,13 +533,11 @@ def _impl_v1(cls, inputs, attr, params): if not transB: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) - if alpha != 1.0: inputs[0] *= _expr.const(alpha) out = _op.nn.dense(inputs[0], inputs[1], units=channels) - if len(inputs) == 3: - return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) + out = out + _expr.const(beta) * inputs[2] return out @@ -623,7 +621,7 @@ def _impl_v1(cls, inputs, attr, params): # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod. # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment. # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod - if attr["fmod"] == 0: + if attr.get("fmod", 0) == 0: op_name = "floor_mod" else: op_name = "mod" @@ -1322,8 +1320,8 @@ class Maximum(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] _max = inputs[0] for i in range(1, len(inputs)): _max = AttrCvt("maximum")([_max, inputs[i]], {}) @@ -1348,8 +1346,8 @@ class Mean(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] # avoid overflow concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0) return _op.mean(concat, axis=0, keepdims=False) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index ea6ad752d259..d366ec7310ad 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4088,6 +4088,8 @@ def test_onnx_nodes(): tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) failures = 0 for n, test in enumerate(tests): + #if "cumsum" not in test: + # continue print(n, test) if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: print("FAILURE: SKIPPING due to segfault") @@ -4102,9 +4104,9 @@ def test_onnx_nodes(): new_tensor = onnx.TensorProto() with open(tensor, 'rb') as f: new_tensor.ParseFromString(f.read()) - if "input" in tensor: + if "input" in tensor.split('/')[-1]: inputs.append(numpy_helper.to_array(new_tensor)) - elif "output" in tensor: + elif "output" in tensor.split('/')[-1]: outputs.append(numpy_helper.to_array(new_tensor)) else: print(tensor) @@ -4118,7 +4120,8 @@ def test_onnx_nodes(): except Exception as e: print("------------------TEST FAILURE--------------------") print(e) - raise + #raise e + #raise def test_wrong_input(): node = helper.make_node( From 357d97dbcc72ed173b263b587038763cab5af275 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 18 Feb 2021 14:29:14 -0700 Subject: [PATCH 04/17] fix some conv_transpose tests --- python/tvm/relay/frontend/onnx.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f74cef53134f..4654f19becf9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -40,9 +40,11 @@ __all__ = ["from_onnx"] + class ONNXAttrError(Exception): pass + class onnx_input: """ Dual purpose list or dictionary access object.""" @@ -109,6 +111,7 @@ def get_type(elem_type): except ImportError as e: raise ImportError("Unable to import onnx which is required {}".format(e)) from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + return str(TENSOR_TYPE_TO_NP_TYPE[elem_type]) @@ -449,9 +452,15 @@ class ConvTranspose(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): # get number of channels - channels = infer_channels(inputs[1], True) + out_type = infer_type(inputs[1]) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][1] attr["channels"] = channels groups = attr.get("group", 1) + + if "kernel_shape" not in attr: + attr["kernel_shape"] = out_shapes[0][2:] + attr["groups"] = groups # infer pads for auto_pad data = inputs[0] @@ -1124,6 +1133,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): try: from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]]) except ImportError as e: raise ImportError("Unable to import onnx.mapping which is required {}".format(e)) From a89fdc3fd679a91de7ec7c8ae30b9fa1af52468e Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 18 Feb 2021 16:04:18 -0700 Subject: [PATCH 05/17] fix out of bounds slice --- python/tvm/relay/op/transform.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 4129b610cb7c..df0ae767460a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): end = const(list(end)) if isinstance(strides, (tuple, list)): strides = const(list(strides)) - normalized_begin = _make.where( + begin = _make.where( begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin ) - return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode) + begin = _make.where( + begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin + ) + return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) From e27633026c7d382259c0061585deef61eb155ab0 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 18 Feb 2021 17:08:53 -0700 Subject: [PATCH 06/17] fix flatten import --- python/tvm/relay/frontend/onnx.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4654f19becf9..3e207080e27f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -861,12 +861,18 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) + ishape = _op.shape_of(inputs[0]) + ndim = infer_shape(ishape)[0] + if axis < 0: + axis = axis + ndim + if axis == 1: out = _op.nn.batch_flatten(inputs[0]) else: - newshape = [0] * (axis + 1) - newshape[axis] = -1 - out = _op.reshape(inputs[0], list(newshape)) + pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) + post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) + newshape = _op.concatenate([pre_shape, post_shape], axis=0) + out = _op.reshape(inputs[0], newshape) return out From e689b6911aaf835ba2b8f6b0486eb202c4b2e75e Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 19 Feb 2021 13:23:02 -0700 Subject: [PATCH 07/17] fix logsoftmax and softmax tests --- python/tvm/relay/frontend/onnx.py | 32 +++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3e207080e27f..6fe277d5c514 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1532,7 +1532,35 @@ def _impl_v1(cls, inputs, attr, params): # set default value when axis is not set in the model if "axis" not in attr: attr["axis"] = 1 - return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params) + axis = attr["axis"] + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, keepdims=True) + + +class LogSoftmax(OnnxOpConverter): + """Operator converter for Softmax.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # set default value when axis is not set in the model + if "axis" not in attr: + attr["axis"] = 1 + axis = attr["axis"] + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) class OneHot(OnnxOpConverter): @@ -2741,7 +2769,7 @@ def _get_convert_map(opset): "Softplus": Softplus.get_converter(opset), # softmax default axis is different in onnx "Softmax": Softmax.get_converter(opset), - "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}), + "LogSoftmax": LogSoftmax.get_converter(opset), "OneHot": OneHot.get_converter(opset), # 'Hardmax' "Softsign": Softsign.get_converter(opset), From a0cac3be3a3f7a63accfef421b3c49fedd11d225 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 19 Feb 2021 13:45:13 -0700 Subject: [PATCH 08/17] fix Error in Upsample --- python/tvm/relay/frontend/onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6fe277d5c514..66f967af1cb9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1054,7 +1054,7 @@ def _impl_v9(cls, inputs, attr, params): # in 3d case, we use the purely static op if dims == 5: - if isinstance(scales, _expr.Call): + if isinstance(scales, _expr.Expr): scale_h = _op.take(scales, _op.const(3)) scale_w = _op.take(scales, _op.const(4)) scale_d = _op.take(scales, _op.const(1)) @@ -1070,7 +1070,7 @@ def _impl_v9(cls, inputs, attr, params): ) # in 2d case, use dynamic op else: - if isinstance(scales, _expr.Call): + if isinstance(scales, _expr.Expr): scale_h = _op.take(scales, _op.const(3)) scale_w = _op.take(scales, _op.const(4)) else: From 713e35da92ef9bdf5553cf4868a4959972d702cc Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 19 Feb 2021 14:04:16 -0700 Subject: [PATCH 09/17] fix onehot --- python/tvm/relay/frontend/onnx.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 66f967af1cb9..5756b0192fb1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1570,14 +1570,24 @@ class OneHot(OnnxOpConverter): def _impl_v9(cls, inputs, attr, params): # Extract relay one_hot inputs. indices, depth, values = inputs + ndim = len(infer_shape(indices)) # Split onnx on off values into two separate expressions. off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) # Extract the datatype of the output from on_value. dtype = infer_type(on_value).checked_type.dtype + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + indices = _op.where( + indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices + ) # set default value when axis is not set in the model if "axis" not in attr: attr["axis"] = -1 - return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype) + axis = attr["axis"] + if axis < 0: + axis += ndim + 1 + + return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) class ConstantOfShape(OnnxOpConverter): From 9d4c8a62eec536e26d53b476ad8a932106897ec6 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 19 Feb 2021 15:37:51 -0700 Subject: [PATCH 10/17] normalize errors --- python/tvm/relay/frontend/onnx.py | 14 +++++++------- tests/python/frontend/onnx/test_forward.py | 6 ++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5756b0192fb1..cf58aa7b0257 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1349,8 +1349,8 @@ class Minimum(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] _min = inputs[0] for i in range(1, len(inputs)): _min = AttrCvt("minimum")([_min, inputs[i]], {}) @@ -1504,7 +1504,7 @@ class ArgMax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): if "select_last_index" in attr: - raise ONNXAttrError("select_last_index not supported in ArgMax") + raise NotImplementedError("select_last_index not supported in ArgMax") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -1517,7 +1517,7 @@ class ArgMin(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): if "select_last_index" in attr: - raise ONNXAttrError("select_last_index not supported in ArgMax") + raise NotImplementedError("select_last_index not supported in ArgMin") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -1612,7 +1612,7 @@ class Constant(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): if "value" not in attr: - raise "No Value in Constant" + raise tvm.errors.OpAttributeRequired("no value in Constant") np_value = get_numpy(attr.pop("value")) dtype = np_value.dtype.name value = _expr.const(np_value, dtype) @@ -2102,7 +2102,7 @@ def _impl_v1(cls, inputs, attr, params): largest = attr.get("largest", 1) if largest == 0: - raise ValueError("TVM only supports finding TopK largest elements") + raise NotImplementedError("TVM only supports finding TopK largest elements") return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64") @@ -2147,7 +2147,7 @@ def _impl_v1(cls, inputs, attr, params): batch_indices = inputs[2] mode = attr.get("mode", b"avg") if mode not in (b"avg", b"max"): - raise ValueError("RoiAlign in Relay only uses avg and max modes") + raise NotImplementedError("RoiAlign in Relay only uses avg and max modes") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d366ec7310ad..a6677ea16964 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4088,8 +4088,6 @@ def test_onnx_nodes(): tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) failures = 0 for n, test in enumerate(tests): - #if "cumsum" not in test: - # continue print(n, test) if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: print("FAILURE: SKIPPING due to segfault") @@ -4117,6 +4115,10 @@ def test_onnx_nodes(): else: for output, val in zip(outputs, tvm_val): tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) + except tvm.error.OpNotImplemented as e: + print("WARNING, missing Op:", e) + except NotImplementedError as e: + print("WARNING, missing implementation:", e) except Exception as e: print("------------------TEST FAILURE--------------------") print(e) From d6634937baad7796b204ca1af60bba9f23f8d3ea Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 19 Feb 2021 16:02:36 -0700 Subject: [PATCH 11/17] fix gather with negative indices --- python/tvm/relay/frontend/onnx.py | 17 +++++++++++++---- tests/python/frontend/onnx/test_forward.py | 5 +++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cf58aa7b0257..ccfa472e79e7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1265,7 +1265,13 @@ class Gather(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) - return AttrCvt("take", extras={"axis": axis})(inputs, {}) + data = inputs[0] + indices = inputs[1] + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis)) + indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices) + return _op.take(data, indices, axis) class GatherElements(OnnxOpConverter): @@ -1276,6 +1282,10 @@ def _impl_v1(cls, inputs, attr, params): data = inputs[0] indices = inputs[1] axis = attr.get("axis", 0) + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis)) + indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices) return _op.gather(data, axis, indices) @@ -2454,9 +2464,8 @@ def _impl_v10(cls, inputs, attr, params): dtype = infer_type(boxes).checked_type.dtype if "center_point_box" in attr: - assert ( - attr["center_point_box"] == 0 - ), "Only support center_point_box = 0 in onnx importer right now" + if attr["center_point_box"] != 0: + raise NotImplementedError("Only support center_point_box = 0 in ONNX NonMaxSuprresion") if iou_threshold is None: iou_threshold = _expr.const(0.0, dtype="float32") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a6677ea16964..f2cf3a4f4c66 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4088,6 +4088,8 @@ def test_onnx_nodes(): tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) failures = 0 for n, test in enumerate(tests): + #if "gather" not in test: + # continue print(n, test) if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: print("FAILURE: SKIPPING due to segfault") @@ -4118,12 +4120,11 @@ def test_onnx_nodes(): except tvm.error.OpNotImplemented as e: print("WARNING, missing Op:", e) except NotImplementedError as e: - print("WARNING, missing implementation:", e) + print("WARNING, missing implementation in Op:", e) except Exception as e: print("------------------TEST FAILURE--------------------") print(e) #raise e - #raise def test_wrong_input(): node = helper.make_node( From a5aee5f7a17b0485daf40f8f4d467431c20092b2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 11:03:08 -0600 Subject: [PATCH 12/17] parameterize test --- tests/python/frontend/onnx/test_forward.py | 86 +++++++++++----------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f2cf3a4f4c66..7656859d1a57 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4080,51 +4080,51 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 0, 1, type="int32") verify_cumsum(data, 1, 1, 1, type="int32") -@tvm.testing.uses_gpu -def test_onnx_nodes(): - from onnx import numpy_helper - f = onnx.__file__ - import glob - tests = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) +from onnx import numpy_helper +f = onnx.__file__ +import glob +onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) + +@pytest.mark.parametrize("test", onnx_test_folders) +def test_onnx_nodes(test): failures = 0 - for n, test in enumerate(tests): - #if "gather" not in test: - # continue - print(n, test) - if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: - print("FAILURE: SKIPPING due to segfault") - continue - try: - onnx_model = onnx.load(test + "/model.onnx") - inputs = [] - outputs = [] - for dataset in glob.glob(test + "/*/"): - tensors = sorted(glob.glob(dataset + "/*.pb")) - for tensor in tensors: - new_tensor = onnx.TensorProto() - with open(tensor, 'rb') as f: - new_tensor.ParseFromString(f.read()) - if "input" in tensor.split('/')[-1]: - inputs.append(numpy_helper.to_array(new_tensor)) - elif "output" in tensor.split('/')[-1]: - outputs.append(numpy_helper.to_array(new_tensor)) - else: - print(tensor) - raise - tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) - if len(outputs) == 1: - tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + #if "gather" not in test: + # continue + print(test) + if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: + print("FAILURE: SKIPPING due to segfault") + pytest.skip() + try: + onnx_model = onnx.load(test + "/model.onnx") + inputs = [] + outputs = [] + for dataset in glob.glob(test + "/*/"): + tensors = sorted(glob.glob(dataset + "/*.pb")) + for tensor in tensors: + new_tensor = onnx.TensorProto() + with open(tensor, 'rb') as f: + new_tensor.ParseFromString(f.read()) + if "input" in tensor.split('/')[-1]: + inputs.append(numpy_helper.to_array(new_tensor)) + elif "output" in tensor.split('/')[-1]: + outputs.append(numpy_helper.to_array(new_tensor)) else: - for output, val in zip(outputs, tvm_val): - tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) - except tvm.error.OpNotImplemented as e: - print("WARNING, missing Op:", e) - except NotImplementedError as e: - print("WARNING, missing implementation in Op:", e) - except Exception as e: - print("------------------TEST FAILURE--------------------") - print(e) - #raise e + print(tensor) + raise + tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) + if len(outputs) == 1: + tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + else: + for output, val in zip(outputs, tvm_val): + tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) + except tvm.error.OpNotImplemented as e: + print("WARNING, missing Op:", e) + pytest.skip() + except NotImplementedError as e: + print("WARNING, missing implementation in Op:", e) + pytest.skip() + except Exception as e: + raise e def test_wrong_input(): node = helper.make_node( From 928c12821c832ebdb0345fd71c5df7c1e801158f Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 12:12:41 -0600 Subject: [PATCH 13/17] skip unsupported tests --- tests/python/frontend/onnx/test_forward.py | 193 +++++++++++++++++---- 1 file changed, 156 insertions(+), 37 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7656859d1a57..4694b56cc504 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4080,51 +4080,170 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 0, 1, type="int32") verify_cumsum(data, 1, 1, 1, type="int32") + from onnx import numpy_helper + f = onnx.__file__ import glob + onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) +unsupported_onnx_tests = [ + "test_basic_convinteger/", + "test_bitshift_left_uint16/", + "test_bitshift_left_uint32/", + "test_bitshift_left_uint64/", + "test_bitshift_left_uint8/", + "test_bitshift_right_uint16/", + "test_bitshift_right_uint32/", + "test_bitshift_right_uint64/", + "test_bitshift_right_uint8/", + "test_cast_DOUBLE_to_FLOAT16/", + "test_cast_FLOAT16_to_DOUBLE/", + "test_cast_FLOAT16_to_FLOAT/", + "test_cast_FLOAT_to_FLOAT16/", + "test_cast_FLOAT_to_STRING/", + "test_cast_STRING_to_FLOAT/", + "test_compress_0/", + "test_compress_1/", + "test_compress_default_axis/", + "test_compress_negative_axis/", + "test_convinteger_with_padding/", + "test_convtranspose_dilations/", + "test_convtranspose_output_shape/", + "test_cumsum_1d/", + "test_cumsum_1d_exclusive/", + "test_cumsum_1d_reverse/", + "test_cumsum_1d_reverse_exclusive/", + "test_cumsum_2d_axis_0/", + "test_cumsum_2d_axis_1/", + "test_cumsum_2d_negative_axis/", + "test_dequantizelinear/", + "test_det_2d/", + "test_det_nd/", + "test_dynamicquantizelinear/", + "test_dynamicquantizelinear_expanded/", + "test_dynamicquantizelinear_max_adjusted/", + "test_dynamicquantizelinear_max_adjusted_expanded/", + "test_dynamicquantizelinear_min_adjusted/", + "test_dynamicquantizelinear_min_adjusted_expanded/", + "test_eyelike_populate_off_main_diagonal/", + "test_eyelike_with_dtype/", + "test_eyelike_without_dtype/", + "test_hardmax_axis_0/", + "test_hardmax_axis_1/", + "test_hardmax_axis_2/", + "test_hardmax_default_axis/", + "test_hardmax_example/", + "test_hardmax_negative_axis/", + "test_hardmax_one_hot/", + "test_isinf_negative/", + "test_isinf_positive/", + "test_lstm_defaults/", + "test_lstm_with_initial_bias/", + "test_lstm_with_peepholes/", + "test_matmulinteger/", + "test_maxpool_2d_dilations/", + "test_maxpool_2d_same_lower/", + "test_maxpool_2d_same_upper/", + "test_maxpool_with_argmax_2d_precomputed_pads/", + "test_maxpool_with_argmax_2d_precomputed_strides/", + "test_maxunpool_export_with_output_shape/", + "test_mvn/", + "test_nonmaxsuppression_center_point_box_format/", + "test_qlinearconv/", + "test_qlinearmatmul_2D/", + "test_qlinearmatmul_3D/", + "test_quantizelinear/", + "test_range_float_type_positive_delta_expanded/", + "test_range_int32_type_negative_delta_expanded/", + "test_resize_downsample_scales_cubic/", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", + "test_resize_downsample_scales_cubic_align_corners/", + "test_resize_downsample_scales_linear/", + "test_resize_downsample_scales_nearest/", + "test_resize_downsample_sizes_cubic/", + "test_resize_downsample_sizes_linear_pytorch_half_pixel/", + "test_resize_downsample_sizes_nearest/", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", + "test_resize_tf_crop_and_resize/", + "test_resize_upsample_scales_cubic/", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", + "test_resize_upsample_scales_cubic_align_corners/", + "test_resize_upsample_scales_cubic_asymmetric/", + "test_resize_upsample_scales_linear/", + "test_resize_upsample_sizes_cubic/", + "test_resize_upsample_sizes_nearest_ceil_half_pixel/", + "test_resize_upsample_sizes_nearest_floor_align_corners/", + "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", + "test_reversesequence_batch/", + "test_reversesequence_time/", + "test_rnn_seq_length/", + "test_roialign/", + "test_round/", + "test_scan9_sum/", + "test_scan_sum/", + "test_scatternd/", + "test_selu_default/", + "test_shrink_hard/", + "test_shrink_soft/", + "test_simple_rnn_defaults/", + "test_simple_rnn_with_initial_bias/", + "test_slice_neg_steps/", + "test_slice_start_out_of_bounds/", + "test_strnormalizer_export_monday_casesensintive_lower/", + "test_strnormalizer_export_monday_casesensintive_nochangecase/", + "test_strnormalizer_export_monday_casesensintive_upper/", + "test_strnormalizer_export_monday_empty_output/", + "test_strnormalizer_export_monday_insensintive_upper_twodim/", + "test_strnormalizer_nostopwords_nochangecase/", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/", + "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/", + "test_tfidfvectorizer_tf_only_bigrams_skip0/", + "test_tfidfvectorizer_tf_onlybigrams_levelempty/", + "test_tfidfvectorizer_tf_onlybigrams_skip5/", + "test_tfidfvectorizer_tf_uniandbigrams_skip5/", + "test_top_k_smallest/", + "test_unique_not_sorted_without_axis/", + "test_unique_sorted_with_axis/", + "test_unique_sorted_with_axis_3d/", + "test_unique_sorted_with_negative_axis/", + "test_unique_sorted_without_axis/", + "test_unsqueeze_unsorted_axes/", + "test_upsample_nearest/", +] + + @pytest.mark.parametrize("test", onnx_test_folders) def test_onnx_nodes(test): - failures = 0 - #if "gather" not in test: - # continue - print(test) - if ("cast" in test and "FLOAT16" in test) or "test_slice_start_out_of_bounds" in test: - print("FAILURE: SKIPPING due to segfault") - pytest.skip() - try: - onnx_model = onnx.load(test + "/model.onnx") - inputs = [] - outputs = [] - for dataset in glob.glob(test + "/*/"): - tensors = sorted(glob.glob(dataset + "/*.pb")) - for tensor in tensors: - new_tensor = onnx.TensorProto() - with open(tensor, 'rb') as f: - new_tensor.ParseFromString(f.read()) - if "input" in tensor.split('/')[-1]: - inputs.append(numpy_helper.to_array(new_tensor)) - elif "output" in tensor.split('/')[-1]: - outputs.append(numpy_helper.to_array(new_tensor)) - else: - print(tensor) - raise - tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) - if len(outputs) == 1: - tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + for failure in unsupported_onnx_tests: + if failure in test: + pytest.skip() + break + onnx_model = onnx.load(test + "/model.onnx") + inputs = [] + outputs = [] + for dataset in glob.glob(test + "/*/"): + tensors = sorted(glob.glob(dataset + "/*.pb")) + for tensor in tensors: + new_tensor = onnx.TensorProto() + with open(tensor, "rb") as f: + new_tensor.ParseFromString(f.read()) + if "input" in tensor.split("/")[-1]: + inputs.append(numpy_helper.to_array(new_tensor)) + elif "output" in tensor.split("/")[-1]: + outputs.append(numpy_helper.to_array(new_tensor)) else: - for output, val in zip(outputs, tvm_val): - tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) - except tvm.error.OpNotImplemented as e: - print("WARNING, missing Op:", e) - pytest.skip() - except NotImplementedError as e: - print("WARNING, missing implementation in Op:", e) - pytest.skip() - except Exception as e: - raise e + print(tensor) + raise + tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) + if len(outputs) == 1: + tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + else: + for output, val in zip(outputs, tvm_val): + tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) + def test_wrong_input(): node = helper.make_node( From c39884b67dec87425a90a517b66c1c3452ec62f2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 12:23:09 -0600 Subject: [PATCH 14/17] clean up --- python/tvm/relay/frontend/onnx.py | 10 ++++------ tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ccfa472e79e7..2b0da5a68dd8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -41,10 +41,6 @@ __all__ = ["from_onnx"] -class ONNXAttrError(Exception): - pass - - class onnx_input: """ Dual purpose list or dictionary access object.""" @@ -2464,8 +2460,10 @@ def _impl_v10(cls, inputs, attr, params): dtype = infer_type(boxes).checked_type.dtype if "center_point_box" in attr: - if attr["center_point_box"] != 0: - raise NotImplementedError("Only support center_point_box = 0 in ONNX NonMaxSuprresion") + if attr["center_point_box"] != 0: + raise NotImplementedError( + "Only support center_point_box = 0 in ONNX NonMaxSuprresion" + ) if iou_threshold is None: iou_threshold = _expr.const(0.0, dtype="float32") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 4694b56cc504..2692f718f642 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -24,7 +24,6 @@ import tvm.topi.testing import tvm from tvm import relay -from tvm.relay.frontend.onnx import ONNXAttrError from tvm.contrib import graph_runtime import scipy import tvm.testing From cd8ee5e21ee982baa9e5bde69de858a04fe594d2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 13:11:17 -0600 Subject: [PATCH 15/17] fix rebase --- tests/python/frontend/onnx/test_forward.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2692f718f642..471768e0cab7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,6 @@ import numpy as np import onnx from onnx import helper, TensorProto, mapping, numpy_helper -from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE import torch import torchvision import pytest @@ -50,11 +49,14 @@ def get_tvm_output_with_vm( if not isinstance(input_data, list): input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) + mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) + if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) + ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data, **params) if isinstance(result, tvm.runtime.NDArray): @@ -3477,7 +3479,13 @@ def verify_topk(input_dims, K, axis=-1): @tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align( - input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0 + input_dims, + num_roi, + output_height, + output_width, + sampling_ratio=0, + spatial_scale=1.0, + mode="avg", ): output_dims = [num_roi, input_dims[1], output_height, output_width] @@ -3485,7 +3493,7 @@ def verify_roi_align( "RoiAlign", inputs=["X", "rois", "batch_indicies"], outputs=["Y"], - mode="avg", + mode=mode, output_height=output_height, output_width=output_width, sampling_ratio=sampling_ratio, @@ -3530,6 +3538,8 @@ def verify_roi_align( verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) + # ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here. + # @tvm.testing.uses_gpu def test_non_max_suppression(): From 2d81e574f525a0b236d73cc543fffe7c43c32fd5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 14:56:46 -0600 Subject: [PATCH 16/17] fix lint --- python/tvm/relay/frontend/onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2b0da5a68dd8..d9fc2ff99a76 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -103,10 +103,9 @@ def get_numpy(tensor_proto): def get_type(elem_type): """Converts onnx integer datatype to numpy datatype""" try: - from onnx import TensorProto + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE except ImportError as e: raise ImportError("Unable to import onnx which is required {}".format(e)) - from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE return str(TENSOR_TYPE_TO_NP_TYPE[elem_type]) From 5b5493dcc39aea24d5feb6ff7a5e978e46fd93e0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 22 Mar 2021 15:18:02 -0600 Subject: [PATCH 17/17] add an error message when we find an un-identified tensor --- tests/python/frontend/onnx/test_forward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 471768e0cab7..ec89a3d844d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4244,8 +4244,7 @@ def test_onnx_nodes(test): elif "output" in tensor.split("/")[-1]: outputs.append(numpy_helper.to_array(new_tensor)) else: - print(tensor) - raise + raise ImportError(str(tensor) + " not labeled as an import or an output") tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) if len(outputs) == 1: tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5)