diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 09ff6b7de5b5..13505fd0f738 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1475,6 +1475,27 @@ def _impl_v11(cls, inputs, attr, params): ) +class EyeLike(OnnxOpConverter): + """Operator converter for EyeLike.""" + + @classmethod + def _impl_v9(cls, inputs, attr, params): + in_checked_type = infer_type(inputs[0]).checked_type + in_dtype = in_checked_type.dtype + in_shape = list(get_const_tuple(in_checked_type.shape)) + dtype = attr.get("dtype", None) + if dtype is None: + dtype = in_dtype + else: + dtype = get_type(dtype) + zeros = _op.zeros(in_shape, dtype) + dim = in_shape[0] + indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32") + ones = _op.full(_op.const(1), (dim,), dtype=dtype) + k = _op.const(attr.get("k", 0), dtype="int32") + return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update") + + class Greater(OnnxOpConverter): """Operator logical greater.""" @@ -3158,6 +3179,7 @@ def _get_convert_map(opset): "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), + "EyeLike": EyeLike.get_converter(opset), "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 4ac7ff2a81f3..423f031e49a9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4129,6 +4129,7 @@ def verify_softplus(indata): verify_softplus(input_data) +@tvm.testing.uses_gpu def test_cumsum(): def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): cumsum_node = onnx.helper.make_node( @@ -4205,6 +4206,30 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") +@tvm.testing.uses_gpu +def test_eyelike(): + def verify_eyelike(indata): + node = helper.make_node( + "EyeLike", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [node], + "eyelike_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))], + ) + + model = helper.make_model(graph, producer_name="eyelike_test") + + verify_with_ort_with_inputs(model, [indata], dtype="float32", opset=9) + + input_data = np.zeros((5, 5), dtype=np.float32) + verify_eyelike(input_data) + + """ The following parameterized tests loads the tests that ONNX ships as serialized ONNX files, inputs, and outputs. The goal of this test @@ -4241,9 +4266,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_cumsum_2d_negative_axis/", "test_det_2d/", "test_det_nd/", - "test_eyelike_populate_off_main_diagonal/", - "test_eyelike_with_dtype/", - "test_eyelike_without_dtype/", "test_matmulinteger/", "test_maxpool_2d_same_lower/", "test_maxpool_2d_same_upper/", @@ -4680,4 +4702,5 @@ def repeat(N, D): test_wrong_input() test_aten() test_reverse_sequence() + test_eyelike() test_qlinearconv()