Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public class ArrowFileWriter: ArrowStreamWriter
private long _currentRecordBatchOffset = -1;

private bool HasWrittenHeader { get; set; }
private bool HasWrittenFooter { get; set; }

private List<Block> RecordBatchBlocks { get; }

Expand All @@ -38,7 +37,12 @@ public ArrowFileWriter(Stream stream, Schema schema)
}

public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen)
: base(stream, schema, leaveOpen)
: this(stream, schema, leaveOpen, options: null)
{
}

public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen, IpcOptions options)
: base(stream, schema, leaveOpen, options)
{
if (!stream.CanWrite)
{
Expand All @@ -53,7 +57,6 @@ public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen)
}

HasWrittenHeader = false;
HasWrittenFooter = false;

RecordBatchBlocks = new List<Block>();
}
Expand Down Expand Up @@ -101,15 +104,11 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long
_currentRecordBatchOffset = -1;
}

public async Task WriteFooterAsync(CancellationToken cancellationToken = default)
private protected override async ValueTask WriteEndInternalAsync(CancellationToken cancellationToken)
{
if (!HasWrittenFooter)
{
await WriteFooterAsync(Schema, cancellationToken).ConfigureAwait(false);
HasWrittenFooter = true;
}
await base.WriteEndInternalAsync(cancellationToken);

await BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
await WriteFooterAsync(Schema, cancellationToken);
}

private async Task WriteHeaderAsync(CancellationToken cancellationToken)
Expand Down
30 changes: 30 additions & 0 deletions csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using FlatBuffers;
using System;
using System.Buffers.Binary;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -57,6 +58,23 @@ public override RecordBatch ReadNextRecordBatch()
//reached the end
return null;
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313: if the first 4 bytes are the continuation token, read the next 4 bytes for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
Expand All @@ -80,6 +98,18 @@ private void ReadSchema()
int schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (schemaMessageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313: if the first 4 bytes are the continuation token, read the next 4 bytes for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}

schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);
}

ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition));
Schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemaBuffer));
_bufferPosition += schemaMessageLength;
Expand Down
122 changes: 84 additions & 38 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,8 @@ protected async ValueTask<RecordBatch> ReadRecordBatchAsync(CancellationToken ca
{
await ReadSchemaAsync().ConfigureAwait(false);

int messageLength = 0;
await ArrayPool<byte>.Shared.RentReturnAsync(4, async (lengthBuffer) =>
{
// Get Length of record batch for message header.
int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
.ConfigureAwait(false);

if (bytesRead == 4)
{
messageLength = BitUtility.ReadInt32(lengthBuffer);
}
}).ConfigureAwait(false);
int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken)
.ConfigureAwait(false);

if (messageLength == 0)
{
Expand Down Expand Up @@ -106,16 +96,7 @@ protected RecordBatch ReadRecordBatch()
{
ReadSchema();

int messageLength = 0;
ArrayPool<byte>.Shared.RentReturn(4, lengthBuffer =>
{
int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);

if (bytesRead == 4)
{
messageLength = BitUtility.ReadInt32(lengthBuffer);
}
});
int messageLength = ReadMessageLength(throwOnFullRead: false);

if (messageLength == 0)
{
Expand Down Expand Up @@ -153,14 +134,8 @@ protected virtual async ValueTask ReadSchemaAsync()
}

// Figure out length of schema
int schemaMessageLength = 0;
await ArrayPool<byte>.Shared.RentReturnAsync(4, async (lengthBuffer) =>
{
int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer).ConfigureAwait(false);
EnsureFullRead(lengthBuffer, bytesRead);

schemaMessageLength = BitUtility.ReadInt32(lengthBuffer);
}).ConfigureAwait(false);
int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true)
.ConfigureAwait(false);

await ArrayPool<byte>.Shared.RentReturnAsync(schemaMessageLength, async (buff) =>
{
Expand All @@ -181,14 +156,7 @@ protected virtual void ReadSchema()
}

// Figure out length of schema
int schemaMessageLength = 0;
ArrayPool<byte>.Shared.RentReturn(4, lengthBuffer =>
{
int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);
EnsureFullRead(lengthBuffer, bytesRead);

schemaMessageLength = BitUtility.ReadInt32(lengthBuffer);
});
int schemaMessageLength = ReadMessageLength(throwOnFullRead: true);

ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, buff =>
{
Expand All @@ -200,6 +168,84 @@ protected virtual void ReadSchema()
});
}

private async ValueTask<int> ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default)
{
int messageLength = 0;
await ArrayPool<byte>.Shared.RentReturnAsync(4, async (lengthBuffer) =>
{
int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
.ConfigureAwait(false);
if (throwOnFullRead)
{
EnsureFullRead(lengthBuffer, bytesRead);
}
else if (bytesRead != 4)
{
return;
}

messageLength = BitUtility.ReadInt32(lengthBuffer);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to perform an unaligned read from the specified buffer assuming a native byte ordering; shouldn't this be using BinaryPrimitives.ReadInt32LittleEndian?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was refactored from the original code, which you can see here:

// Get Length of record batch for message header.
lengthBuffer = Buffers.Rent(4);
bytesRead += await BaseStream.ReadAsync(lengthBuffer, 0, 4, cancellationToken);
var messageLength = BitConverter.ToInt32(lengthBuffer, 0);

Originally, we always had a byte[] and would call BitConverter.ToInt32. However, with the changes to allow for Memory and Span, I needed to make the same call, only with a Span instead of byte[]. This API exists in .NET, but it is not available in netstandard. So I needed to copy the little bit of code out into the BitUtility class.

https://source.dot.net/#System.Private.CoreLib/shared/System/BitConverter.cs,269

You can see the BitConverter.ToInt32(byte[]) does the same operation.

return Unsafe.ReadUnaligned<int>(ref value[startIndex]);

From what I can tell, the C++ implementation does the same thing:

(master branch)

int32_t flatbuffer_size = *reinterpret_cast<const int32_t*>(buffer->data());

(ARROW-6313-flatbuffer-alignment branch)

inline typename std::enable_if<std::is_integral<T>::value, T>::type SafeLoadAs(
const uint8_t* unaligned) {
typename std::remove_const<T>::type ret;
std::memcpy(&ret, unaligned, sizeof(T));
return ret;

I was never sure on this, and the spec doesn't 100% specify if these length numbers are big or little endian, or machine dependent. So that's why I've never changed this code, and left it doing what it has always been doing.

https://arrow.apache.org/docs/format/Layout.html#byte-order-endianness

The Arrow format is little endian by default. The Schema metadata has an endianness field indicating endianness of RecordBatches. Typically this is the endianness of the system where the RecordBatch was generated.

Having the endianness inside of the schema doesn't help when you need to know what endian the schema length is in, in order to read the schema itself.

I see we are always writing little-endian numbers for these lengths, so maybe changing it here can be justified that way.

Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe since this issue has existed in this code since its inception, it would be best to open a JIRA issue for this.

https://issues.apache.org/jira/browse/ARROW-6553 - "[C#] Decide how to read message lengths - little-endian or machine dependent"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eerhardt I'll continue the discussion in that JIRA issue; I interpreted the "little-endian by default" section to mean that the IPC protocol is always little-endian, but that array primitives have a byte order corresponding to the (optional) schema metadata value. If the protocol specification does not specify byte ordering or a mechanism for determining byte ordering, I would think to view that as an oversight; however, it can also just mean the C++ code is presently non-compliant or does not support such endian-awareness.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The C++ implementation is not big-endian compliant. Even finding environments to do big endian testing nowadays is a major challenge.


// ARROW-6313: if the first 4 bytes are the continuation token, read the next 4 bytes for the length
if (messageLength == MessageSerializer.IpcContinuationToken)
{
bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
.ConfigureAwait(false);
if (throwOnFullRead)
{
EnsureFullRead(lengthBuffer, bytesRead);
}
else if (bytesRead != 4)
{
messageLength = 0;
return;
}

messageLength = BitUtility.ReadInt32(lengthBuffer);
}
}).ConfigureAwait(false);

return messageLength;
}

/// <summary>
/// Reads the 4-byte length prefix of the next IPC message from the stream.
/// Handles the ARROW-6313 continuation token: when the first 4 bytes equal
/// the token, the following 4 bytes carry the actual length.
/// </summary>
/// <param name="throwOnFullRead">
/// When true, a short read raises an exception via EnsureFullRead;
/// when false, a short read is treated as end-of-stream and 0 is returned.
/// </param>
/// <returns>The message length in bytes, or 0 at end of stream.</returns>
private int ReadMessageLength(bool throwOnFullRead)
{
    int result = 0;
    ArrayPool<byte>.Shared.RentReturn(4, buffer =>
    {
        int read = BaseStream.ReadFullBuffer(buffer);
        if (throwOnFullRead)
        {
            EnsureFullRead(buffer, read);
        }
        else if (read != 4)
        {
            // Truncated prefix: leave result at 0 (end of stream).
            return;
        }

        int length = BitUtility.ReadInt32(buffer);

        // ARROW-6313: a continuation token precedes the real length in the
        // non-legacy IPC format; read the following 4 bytes for the length.
        if (length == MessageSerializer.IpcContinuationToken)
        {
            read = BaseStream.ReadFullBuffer(buffer);
            if (throwOnFullRead)
            {
                EnsureFullRead(buffer, read);
            }
            else if (read != 4)
            {
                // Truncated after the token: report end of stream.
                return;
            }

            length = BitUtility.ReadInt32(buffer);
        }

        result = length;
    });

    return result;
}

/// <summary>
/// Ensures the number of bytes read matches the buffer length
/// and throws an exception if it doesn't. This ensures we have read
Expand Down
53 changes: 45 additions & 8 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,30 @@ public void Visit(IArrowArray array)

protected bool HasWrittenSchema { get; set; }

private bool HasWrittenEnd { get; set; }

protected Schema Schema { get; }

private readonly bool _leaveOpen;
private readonly IpcOptions _options;

private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4;

private static readonly byte[] Padding = new byte[64];

private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder;

public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false)
public ArrowStreamWriter(Stream baseStream, Schema schema)
: this(baseStream, schema, leaveOpen: false)
{
}

public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)
: this(baseStream, schema, leaveOpen, options: null)
{
}

public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOptions options)
{
BaseStream = baseStream ?? throw new ArgumentNullException(nameof(baseStream));
Schema = schema ?? throw new ArgumentNullException(nameof(schema));
Expand All @@ -167,6 +176,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)
HasWrittenSchema = false;

_fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder);
_options = options ?? IpcOptions.Default;
}

private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
Expand Down Expand Up @@ -262,6 +272,11 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat
FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength);
}

/// <summary>
/// Writes the stream terminator — an IPC message length of zero.
/// Derived writers override this to append trailing data (e.g. a file footer).
/// </summary>
private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) =>
    WriteIpcMessageLengthAsync(length: 0, cancellationToken);

/// <summary>
/// Extension point invoked just before a record batch is written.
/// No-op in the base stream writer; derived writers may override it
/// (e.g. to capture the batch's starting offset for block bookkeeping).
/// </summary>
private protected virtual void StartingWritingRecordBatch()
{
}
Expand All @@ -274,6 +289,15 @@ public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationT
{
return WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
}

/// <summary>
/// Writes the end-of-stream marker. Idempotent: subsequent calls after the
/// first successful write are no-ops (tracked via HasWrittenEnd).
/// </summary>
/// <param name="cancellationToken">Token used to cancel the write.</param>
public async Task WriteEndAsync(CancellationToken cancellationToken = default)
{
    if (!HasWrittenEnd)
    {
        // ConfigureAwait(false): library code — matches the convention used
        // by every other await in this writer.
        await WriteEndInternalAsync(cancellationToken).ConfigureAwait(false);
        HasWrittenEnd = true;
    }
}

private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -348,19 +372,15 @@ private async ValueTask<long> WriteMessageAsync<T>(
var messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
var messagePaddingLength = CalculatePadding(messageData.Length);

await Buffers.RentReturnAsync(4, async (buffer) =>
{
var metadataSize = messageData.Length + messagePaddingLength;
BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, metadataSize);
await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}).ConfigureAwait(false);
await WriteIpcMessageLengthAsync(messageData.Length + messagePaddingLength, cancellationToken)
.ConfigureAwait(false);

await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false);
await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false);

checked
{
return 4 + messageData.Length + messagePaddingLength;
return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength;
}
}

Expand All @@ -371,6 +391,23 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell
await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Writes an IPC message length prefix to the underlying stream as
/// little-endian. In the non-legacy format the 4-byte continuation token is
/// emitted ahead of the 4-byte length (ARROW-6313); the buffer size comes
/// from the configured options.
/// </summary>
/// <param name="length">Message length, in bytes, to encode.</param>
/// <param name="cancellationToken">Token used to cancel the write.</param>
private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken)
{
    await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) =>
    {
        Memory<byte> destination = buffer;

        if (!_options.WriteLegacyIpcFormat)
        {
            // Newer IPC format: continuation token first, then the length.
            BinaryPrimitives.WriteInt32LittleEndian(destination.Span, MessageSerializer.IpcContinuationToken);
            destination = destination.Slice(sizeof(int));
        }

        BinaryPrimitives.WriteInt32LittleEndian(destination.Span, length);

        await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
    }).ConfigureAwait(false);
}

protected int CalculatePadding(long offset, int alignment = 8)
{
long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
Expand Down
Loading