Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -691,5 +691,28 @@ public void TestSignaturesV2_4_X()
Assert.IsType<RelationalGroupedDataset>(df.Pivot(Col("age"), values));
}
}

/// <summary>
/// Test signatures for APIs introduced in Spark 3.*
/// </summary>
[SkipIfSparkVersionIsLessThan(Versions.V3_0_0)]
public void TestSignaturesV3_X_X()
{
    // Validate ToLocalIterator against a small two-row frame.
    var rows = new List<GenericRow>
    {
        new GenericRow(new object[] { "Alice", 20 }),
        new GenericRow(new object[] { "Bob", 30 })
    };
    var schema = new StructType(new List<StructField>
    {
        new StructField("Name", new StringType()),
        new StructField("Age", new IntegerType())
    });
    DataFrame df = _spark.CreateDataFrame(rows, schema);

    Row[] actual = df.ToLocalIterator(true).ToArray();
    IEnumerable<Row> expected = rows.Select(r => new Row(r.Values, schema));
    Assert.Equal(expected, actual);
}
}
}
72 changes: 53 additions & 19 deletions src/csharp/Microsoft.Spark/Sql/DataFrame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,37 @@ public IEnumerable<Row> Collect()
/// <returns>Row objects</returns>
public IEnumerable<Row> ToLocalIterator()
{
    // Spark 2.x serves all rows over a single connection, while Spark 3.x
    // uses the partition-by-partition local-iterator protocol. Default to
    // no prefetching on 3.x to match Spark's own default behavior.
    Version version = SparkEnvironment.SparkVersion;
    return version.Major switch
    {
        2 => GetRows("toPythonIterator"),
        3 => ToLocalIterator(false),
        _ => throw new NotSupportedException($"Spark {version} not supported.")
    };
}

/// <summary>
/// Returns an iterator that contains all of the rows in this `DataFrame`.
/// The iterator will consume as much memory as the largest partition in this `DataFrame`.
/// With prefetch it may consume up to the memory of the 2 largest partitions.
/// </summary>
/// <param name="prefetchPartitions">
/// If Spark should pre-fetch the next partition before it is needed.
/// </param>
/// <returns>Row objects</returns>
[Since(Versions.V3_0_0)]
public IEnumerable<Row> ToLocalIterator(bool prefetchPartitions)
{
    // Ask the JVM to start a socket auth server that streams the rows,
    // then connect and consume them partition by partition.
    object info = _jvmObject.Invoke("toPythonIterator", prefetchPartitions);
    (int port, string secret, JvmObjectReference server) =
        ParseConnectionInfo(info, true);
    using ISocketWrapper socket = SocketFactory.CreateSocket();
    socket.Connect(IPAddress.Loopback, port, secret);
    foreach (Row row in new RowCollector().Collect(socket, server))
    {
        yield return row;
    }
}

/// <summary>
Expand Down Expand Up @@ -902,48 +932,52 @@ public DataStreamWriter WriteStream() =>
/// <summary>
/// Returns row objects based on the function (either "toPythonIterator" or
/// "collectToPython").
/// </summary>
/// <param name="funcName">
/// The name of the function to call, either "toPythonIterator" or "collectToPython".
/// </param>
/// <returns><see cref="Row"/> objects</returns>
private IEnumerable<Row> GetRows(string funcName)
{
    // The JVM socket auth server reference is not needed for the 2.x
    // protocol, so it is discarded here.
    (int port, string secret, _) = GetConnectionInfo(funcName);
    using ISocketWrapper socket = SocketFactory.CreateSocket();
    socket.Connect(IPAddress.Loopback, port, secret);
    foreach (Row row in new RowCollector().Collect(socket))
    {
        yield return row;
    }
}

/// <summary>
/// Returns a tuple of port number and secret string which are
/// used for connecting with Spark to receive rows for this `DataFrame`.
/// </summary>
/// <param name="funcName">
/// The name of the JVM function to invoke, e.g. "toPythonIterator" or
/// "collectToPython".
/// </param>
/// <returns>A tuple of port number, secret string, and JVM socket auth server.</returns>
private (int, string, JvmObjectReference) GetConnectionInfo(string funcName)
{
    object result = _jvmObject.Invoke(funcName);
    Version version = SparkEnvironment.SparkVersion;
    return (version.Major, version.Minor, version.Build) switch
    {
        // In spark 2.3.0, PythonFunction.serveIterator() returns a port number.
        (2, 3, 0) => ((int)result, string.Empty, null),
        // From spark >= 2.3.1, PythonFunction.serveIterator() returns a pair
        // where the first is a port number and the second is the secret
        // string to use for the authentication.
        (2, 3, _) => ParseConnectionInfo(result, false),
        (2, 4, _) => ParseConnectionInfo(result, false),
        (3, 0, _) => ParseConnectionInfo(result, false),
        _ => throw new NotSupportedException($"Spark {version} not supported.")
    };
}

/// <summary>
/// Parses the connection info array returned from the JVM into its parts.
/// </summary>
/// <param name="info">
/// JVM object array containing the port number, the secret string, and —
/// when present — the socket auth server reference.
/// </param>
/// <param name="parseServer">
/// True to also extract the JVM socket auth server reference (third array
/// element, used by the Spark 3.x local-iterator protocol); false to
/// return null for it.
/// </param>
/// <returns>A tuple of port number, secret string, and JVM socket auth server.</returns>
private (int, string, JvmObjectReference) ParseConnectionInfo(
    object info,
    bool parseServer)
{
    var infos = (JvmObjectReference[])info;
    return ((int)infos[0].Invoke("intValue"),
        (string)infos[1].Invoke("toString"),
        parseServer ? infos[2] : null);
}

private DataFrame WrapAsDataFrame(object obj) => new DataFrame((JvmObjectReference)obj);
Expand Down
101 changes: 99 additions & 2 deletions src/csharp/Microsoft.Spark/Sql/RowCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Network;
Expand All @@ -18,8 +20,8 @@ internal sealed class RowCollector
/// <summary>
/// Collects pickled row objects from the given socket.
/// </summary>
/// <param name="socket">Socket to get the stream from</param>
/// <returns>Collection of row objects</returns>
/// <param name="socket">Socket to get the stream from.</param>
/// <returns>Collection of row objects.</returns>
public IEnumerable<Row> Collect(ISocketWrapper socket)
{
Stream inputStream = socket.InputStream;
Expand All @@ -37,5 +39,100 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)
}
}
}

/// <summary>
/// Collects pickled row objects from the given socket, one partition at a
/// time, delegating each partition to <see cref="Collect(ISocketWrapper)"/>.
/// </summary>
/// <param name="socket">Socket to get the stream from.</param>
/// <param name="server">The JVM socket auth server.</param>
/// <returns>Collection of row objects.</returns>
public IEnumerable<Row> Collect(ISocketWrapper socket, JvmObjectReference server)
{
    return new LocalIteratorFromSocket(socket, server);
}

/// <summary>
/// LocalIteratorFromSocket creates a synchronous local iterable over
/// a socket.
///
/// Note that the implementation mirrors _local_iterator_from_socket in
/// PySpark: spark/python/pyspark/rdd.py
/// </summary>
private class LocalIteratorFromSocket : IEnumerable<Row>
{
    // Socket connected to the JVM side that streams pickled row data,
    // one partition at a time.
    private readonly ISocketWrapper _socket;

    // JVM socket auth server; invoking "getResult" on it surfaces any
    // exception raised on the JVM serving thread.
    private readonly JvmObjectReference _server;

    // Protocol status of the last exchange with the JVM:
    //  1 -> more partitions may be available (initial state),
    //  0 -> stream fully consumed,
    // -1 -> the JVM reported an error.
    private int _readStatus = 1;

    // Rows of the partition currently being consumed; null until the
    // first partition has been requested.
    private IEnumerable<Row> _currentPartitionRows = null;

    internal LocalIteratorFromSocket(ISocketWrapper socket, JvmObjectReference server)
    {
        _socket = socket;
        _server = server;
    }

    ~LocalIteratorFromSocket()
    {
        // If iterator is not fully consumed.
        if ((_readStatus == 1) && (_currentPartitionRows != null))
        {
            try
            {
                // Finish consuming partition data stream.
                foreach (Row _ in _currentPartitionRows)
                {
                }

                // Tell Java to stop sending data and close connection.
                Stream outputStream = _socket.OutputStream;
                SerDe.Write(outputStream, 0);
                outputStream.Flush();
            }
            catch
            {
                // Ignore any errors, socket may be automatically closed
                // when garbage-collected.
            }
        }
    }

    public IEnumerator<Row> GetEnumerator()
    {
        // NOTE(review): _readStatus and _currentPartitionRows are shared
        // across GetEnumerator() calls, so this iterable appears to be
        // single-use (mirroring PySpark's one-shot local iterator) —
        // confirm before enumerating it more than once.
        Stream inputStream = _socket.InputStream;
        Stream outputStream = _socket.OutputStream;

        while (_readStatus == 1)
        {
            // Request next partition data from Java.
            SerDe.Write(outputStream, 1);
            outputStream.Flush();

            // If response is 1 then there is a partition to read, if 0 then
            // fully consumed.
            _readStatus = SerDe.ReadInt32(inputStream);
            if (_readStatus == 1)
            {
                // Load the partition data from stream and read each item.
                _currentPartitionRows = new RowCollector().Collect(_socket);
                foreach (Row row in _currentPartitionRows)
                {
                    yield return row;
                }
            }
            else if (_readStatus == -1)
            {
                // An error occurred, join serving thread and raise any exceptions from
                // the JVM. The exception stack trace will appear in the driver logs.
                _server.Invoke("getResult");
            }
            else
            {
                // Stream fully consumed; the loop condition terminates.
                Debug.Assert(_readStatus == 0);
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
}