// --- Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs ---

/// <summary>
/// Test signatures for APIs introduced in Spark 3.*.
/// </summary>
[SkipIfSparkVersionIsLessThan(Versions.V3_0_0)]
public void TestSignaturesV3_X_X()
{
    // Validate ToLocalIterator by round-tripping two rows through Spark.
    var rows = new List<GenericRow>
    {
        new GenericRow(new object[] { "Alice", 20 }),
        new GenericRow(new object[] { "Bob", 30 })
    };
    var schema = new StructType(new List<StructField>
    {
        new StructField("Name", new StringType()),
        new StructField("Age", new IntegerType())
    });

    DataFrame df = _spark.CreateDataFrame(rows, schema);

    Row[] actual = df.ToLocalIterator(true).ToArray();
    IEnumerable<Row> expected = rows.Select(r => new Row(r.Values, schema));
    Assert.Equal(expected, actual);
}

// --- Microsoft.Spark/Sql/DataFrame.cs ---

/// <summary>
/// Returns an iterator that contains all of the rows in this `DataFrame`.
/// The iterator will consume as much memory as the largest partition in this
/// `DataFrame`.
/// </summary>
/// <returns><see cref="Row"/> objects</returns>
public IEnumerable<Row> ToLocalIterator()
{
    Version version = SparkEnvironment.SparkVersion;
    switch (version.Major)
    {
        case 2:
            return GetRows("toPythonIterator");
        case 3:
            // Spark 3 supports prefetching; keep the pre-3.0 behavior of
            // not prefetching by default.
            return ToLocalIterator(false);
        default:
            throw new NotSupportedException($"Spark {version} not supported.");
    }
}

/// <summary>
/// Returns an iterator that contains all of the rows in this `DataFrame`.
/// The iterator will consume as much memory as the largest partition in this
/// `DataFrame`. With prefetch it may consume up to the memory of the 2 largest
/// partitions.
/// </summary>
/// <param name="prefetchPartitions">
/// If Spark should pre-fetch the next partition before it is needed.
/// </param>
/// <returns><see cref="Row"/> objects</returns>
[Since(Versions.V3_0_0)]
public IEnumerable<Row> ToLocalIterator(bool prefetchPartitions)
{
    object info = _jvmObject.Invoke("toPythonIterator", prefetchPartitions);
    (int port, string secret, JvmObjectReference server) =
        ParseConnectionInfo(info, true);

    using ISocketWrapper socket = SocketFactory.CreateSocket();
    socket.Connect(IPAddress.Loopback, port, secret);

    // Rows are streamed lazily, partition by partition, from the JVM.
    foreach (Row row in new RowCollector().Collect(socket, server))
    {
        yield return row;
    }
}

/// <summary>
/// Returns row objects based on the function (either "toPythonIterator" or
/// "collectToPython").
/// </summary>
/// <param name="funcName">
/// The name of the function to call, either "toPythonIterator" or "collectToPython".
/// </param>
/// <returns><see cref="Row"/> objects</returns>
private IEnumerable<Row> GetRows(string funcName)
{
    // The returned server reference is only needed by the Spark 3 iterator
    // protocol; it is unused here.
    (int port, string secret, _) = GetConnectionInfo(funcName);

    using ISocketWrapper socket = SocketFactory.CreateSocket();
    socket.Connect(IPAddress.Loopback, port, secret);

    foreach (Row collected in new RowCollector().Collect(socket))
    {
        yield return collected;
    }
}
/// <summary>
/// Returns a tuple of port number and secret string which are
/// used for connecting with Spark to receive rows for this `DataFrame`.
/// </summary>
/// <param name="funcName">The name of the JVM function to invoke.</param>
/// <returns>A tuple of port number, secret string, and JVM socket auth server.</returns>
private (int, string, JvmObjectReference) GetConnectionInfo(string funcName)
{
    object payload = _jvmObject.Invoke(funcName);
    Version version = SparkEnvironment.SparkVersion;

    // In spark 2.3.0, PythonFunction.serveIterator() returns just a port number.
    if ((version.Major, version.Minor, version.Build) == (2, 3, 0))
    {
        return ((int)payload, string.Empty, null);
    }

    // From spark >= 2.3.1 through 3.0.x, PythonFunction.serveIterator() returns
    // a pair where the first is a port number and the second is the secret
    // string to use for the authentication.
    bool supported =
        (version.Major == 2 && (version.Minor == 3 || version.Minor == 4)) ||
        (version.Major == 3 && version.Minor == 0);
    if (supported)
    {
        return ParseConnectionInfo(payload, false);
    }

    throw new NotSupportedException($"Spark {version} not supported.");
}

/// <summary>
/// Parses the connection payload returned by the JVM into its parts.
/// </summary>
/// <param name="info">
/// JVM array of [port, secret] or [port, secret, socket auth server].
/// </param>
/// <param name="parseServer">
/// True to also extract the JVM socket auth server (Spark 3 iterator protocol).
/// </param>
/// <returns>A tuple of port number, secret string, and JVM socket auth server.</returns>
private (int, string, JvmObjectReference) ParseConnectionInfo(
    object info,
    bool parseServer)
{
    var parts = (JvmObjectReference[])info;
    int port = (int)parts[0].Invoke("intValue");
    var secret = (string)parts[1].Invoke("toString");
    JvmObjectReference server = parseServer ? parts[2] : null;
    return (port, secret, server);
}

private DataFrame WrapAsDataFrame(object obj) => new DataFrame((JvmObjectReference)obj);
// --- Microsoft.Spark/Sql/RowCollector.cs ---
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Network;

/// <summary>
/// Collects pickled row objects from the given socket. Collects rows in partitions
/// by leveraging <see cref="LocalIteratorFromSocket"/>.
/// </summary>
/// <param name="socket">Socket to get the stream from.</param>
/// <param name="server">The JVM socket auth server.</param>
/// <returns>Collection of row objects.</returns>
public IEnumerable<Row> Collect(ISocketWrapper socket, JvmObjectReference server) =>
    new LocalIteratorFromSocket(socket, server);

/// <summary>
/// LocalIteratorFromSocket creates a synchronous local iterable over
/// a socket.
///
/// Note that the implementation mirrors _local_iterator_from_socket in
/// PySpark: spark/python/pyspark/rdd.py
/// </summary>
private class LocalIteratorFromSocket : IEnumerable<Row>
{
    private readonly ISocketWrapper _socket;
    private readonly JvmObjectReference _server;

    // Protocol status from the JVM: 1 = partition available,
    // 0 = fully consumed, -1 = error on the JVM side.
    private int _readStatus = 1;

    // Rows of the partition currently being streamed; null until the first
    // partition has been requested.
    private IEnumerable<Row> _currentPartitionRows = null;

    internal LocalIteratorFromSocket(ISocketWrapper socket, JvmObjectReference server)
    {
        _socket = socket;
        _server = server;
    }

    ~LocalIteratorFromSocket()
    {
        // If the iterator was not fully consumed, drain the in-flight
        // partition and tell the JVM to stop serving.
        // NOTE(review): touching managed streams from a finalizer is
        // unreliable — the socket/streams may already have been finalized
        // when this runs, which is why all errors are swallowed below.
        // Consider an IDisposable-based cleanup instead; kept as-is here to
        // mirror PySpark's _local_iterator_from_socket behavior.
        if ((_readStatus == 1) && (_currentPartitionRows != null))
        {
            try
            {
                // Finish consuming the partition data stream.
                foreach (Row _ in _currentPartitionRows)
                {
                }

                // Tell Java to stop sending data and close the connection.
                Stream outputStream = _socket.OutputStream;
                SerDe.Write(outputStream, 0);
                outputStream.Flush();
            }
            catch
            {
                // Ignore any errors, socket may be automatically closed
                // when garbage-collected.
            }
        }
    }

    public IEnumerator<Row> GetEnumerator()
    {
        Stream inputStream = _socket.InputStream;
        Stream outputStream = _socket.OutputStream;

        while (_readStatus == 1)
        {
            // Request next partition data from Java.
            SerDe.Write(outputStream, 1);
            outputStream.Flush();

            // If response is 1 then there is a partition to read, if 0 then
            // fully consumed, if -1 then an error occurred on the JVM side.
            _readStatus = SerDe.ReadInt32(inputStream);
            if (_readStatus == 1)
            {
                // Load the partition data from the stream and yield each row.
                _currentPartitionRows = new RowCollector().Collect(_socket);
                foreach (Row row in _currentPartitionRows)
                {
                    yield return row;
                }
            }
            else if (_readStatus == -1)
            {
                // An error occurred, join serving thread and raise any exceptions
                // from the JVM. The exception stack trace will appear in the
                // driver logs.
                _server.Invoke("getResult");
            }
            else
            {
                Debug.Assert(_readStatus == 0);
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}