diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8269ebfb198d..5f31724ec3ea 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1050,6 +1050,17 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt("take", extras={"axis": axis})(inputs, {}) +class GatherElements(OnnxOpConverter): + """Operator converter for GatherElements.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + indices = inputs[1] + axis = attr.get("axis", 0) + return _op.gather(data, axis, indices) + + class GatherND(OnnxOpConverter): """Operator converter for GatherND.""" @@ -2014,6 +2025,7 @@ def _get_convert_map(opset): "DepthToSpace": DepthToSpace.get_converter(opset), "SpaceToDepth": SpaceToDepth.get_converter(opset), "Gather": Gather.get_converter(opset), + "GatherElements": GatherElements.get_converter(opset), "GatherND": GatherND.get_converter(opset), "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6f63cbf0413a..894a6b6d40ce 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -435,6 +435,44 @@ def test_gather(): verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32") +def verify_gatherelements(in_shape, indices, axis): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + + y = helper.make_node("GatherElements", ["data", "indices"], ["output"], axis=axis) + graph = helper.make_graph( + [y], + "gather_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="gather_elements_test") + onnx_out = get_onnxruntime_output(model, [x, indices]) + + for target, ctx in tvm.testing.enabled_targets(): + tvm_out = get_tvm_output(model, [x, indices], target, ctx, onnx_out[0].shape) + tvm.testing.assert_allclose(onnx_out[0], tvm_out) + + +@tvm.testing.uses_gpu +def test_gatherelements(): + verify_gatherelements((4,), [3, 0, 2, 1], 0) + verify_gatherelements((2, 2), [[1, 0], [0, 1]], 0) + verify_gatherelements((2, 2), [[0, 0], [1, 0]], 1) + verify_gatherelements((2, 2), [[1, 0], [0, 1]], 1) + + indices = [ + [[1, 0, 0], [1, 0, 1], [0, 1, 1]], + [[1, 1, 1], [1, 2, 1], [1, 0, 1]], + [[1, 2, 1], [1, 2, 1], [1, 2, 1]], + ] + + verify_gatherelements((3, 3, 3), indices, 2) + + def verify_scatter(in_shape, indices, axis): x = np.random.uniform(size=in_shape).astype("float32") indices = np.array(indices, dtype="int32") @@ -3384,6 +3422,7 @@ def verify_roi_align( test_matmul() test_batch_matmul() test_gather() + test_gatherelements() test_gather_nd() test_scatter() test_lrn()