From c494a4b48c081e4113b5b9586663469e0b7d27b3 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Wed, 4 Sep 2019 17:57:22 -0500 Subject: [PATCH 1/2] ARROW-6314: [C#] Implement IPC message format alignment changes, provide backwards compatibility and "legacy" option to emit old message format --- .../src/Apache.Arrow/Ipc/ArrowFileWriter.cs | 7 +- .../Ipc/ArrowMemoryReaderImplementation.cs | 30 +++++ .../Ipc/ArrowStreamReaderImplementation.cs | 122 ++++++++++++------ .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 26 +++- csharp/src/Apache.Arrow/Ipc/IpcOptions.cs | 32 +++++ .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 1 + .../ArrowStreamWriterTests.cs | 45 ++++++- 7 files changed, 218 insertions(+), 45 deletions(-) create mode 100644 csharp/src/Apache.Arrow/Ipc/IpcOptions.cs diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 265c31f75a7..7c3c698492f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -38,7 +38,12 @@ public ArrowFileWriter(Stream stream, Schema schema) } public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen) - : base(stream, schema, leaveOpen) + : this(stream, schema, leaveOpen, options: null) + { + } + + public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen, IpcOptions options) + : base(stream, schema, leaveOpen, options) { if (!stream.CanWrite) { diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index a39a87bf7f2..c265fa1b0f1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -17,6 +17,7 @@ using FlatBuffers; using System; using System.Buffers.Binary; +using System.IO; using System.Threading; using System.Threading.Tasks; @@ -57,6 +58,23 @@ public override RecordBatch ReadNextRecordBatch() //reached the end return null; } + else if (messageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + + if (messageLength == 0) + { + //reached the end + return null; + } + } Message message = Message.GetRootAsMessage( CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); @@ -80,6 +98,18 @@ private void ReadSchema() int schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); _bufferPosition += sizeof(int); + if (schemaMessageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + } + ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer)); _bufferPosition += schemaMessageLength; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index c2ac4bbfeef..1a6d728fd64 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -59,18 +59,8 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - int messageLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => - { - // Get Length of record batch for message header. - int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) - .ConfigureAwait(false); - - if (bytesRead == 4) - { - messageLength = BitUtility.ReadInt32(lengthBuffer); - } - }).ConfigureAwait(false); + int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) + .ConfigureAwait(false); if (messageLength == 0) { @@ -106,16 +96,7 @@ protected RecordBatch ReadRecordBatch() { ReadSchema(); - int messageLength = 0; - ArrayPool.Shared.RentReturn(4, lengthBuffer => - { - int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); - - if (bytesRead == 4) - { - messageLength = BitUtility.ReadInt32(lengthBuffer); - } - }); + int messageLength = ReadMessageLength(throwOnFullRead: false); if (messageLength == 0) { @@ -153,14 +134,8 @@ protected virtual async ValueTask ReadSchemaAsync() } // Figure out length of schema - int schemaMessageLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => - { - int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer).ConfigureAwait(false); - EnsureFullRead(lengthBuffer, bytesRead); - - schemaMessageLength = BitUtility.ReadInt32(lengthBuffer); - }).ConfigureAwait(false); + int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true) + .ConfigureAwait(false); await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) => { @@ -181,14 +156,7 @@ protected virtual void ReadSchema() } // Figure out length of schema - int schemaMessageLength = 0; - ArrayPool.Shared.RentReturn(4, lengthBuffer => - { - int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); - EnsureFullRead(lengthBuffer, bytesRead); - - schemaMessageLength = BitUtility.ReadInt32(lengthBuffer); - }); + int schemaMessageLength = ReadMessageLength(throwOnFullRead: true); ArrayPool.Shared.RentReturn(schemaMessageLength, buff => { @@ -200,6 +168,84 @@ protected virtual void ReadSchema() }); } + private async ValueTask ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default) + { + int messageLength = 0; + await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => + { + int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) + .ConfigureAwait(false); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (messageLength == MessageSerializer.IpcContinuationToken) + { + bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) + .ConfigureAwait(false); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + messageLength = 0; + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + } + }).ConfigureAwait(false); + + return messageLength; + } + + private int ReadMessageLength(bool throwOnFullRead) + { + int messageLength = 0; + ArrayPool.Shared.RentReturn(4, lengthBuffer => + { + int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (messageLength == MessageSerializer.IpcContinuationToken) + { + bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + messageLength = 0; + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + } + }); + + return messageLength; + } + /// /// Ensures the number of bytes read matches the buffer length /// and throws an exception it if doesn't. This ensures we have read diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index e1da4489ce4..5306ba70d7d 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -145,6 +145,7 @@ public void Visit(IArrowArray array) protected Schema Schema { get; } private readonly bool _leaveOpen; + private readonly IpcOptions _options; private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4; @@ -152,11 +153,17 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; - public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) + public ArrowStreamWriter(Stream baseStream, Schema schema) + : this(baseStream, schema, leaveOpen: false) { } public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen) + : this(baseStream, schema, leaveOpen, options: null) + { + } + + public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOptions options) { BaseStream = baseStream ?? throw new ArgumentNullException(nameof(baseStream)); Schema = schema ?? throw new ArgumentNullException(nameof(schema)); @@ -167,6 +174,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen) HasWrittenSchema = false; _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder); + _options = options ?? IpcOptions.Default; } private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, @@ -348,10 +356,20 @@ private async ValueTask WriteMessageAsync( var messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); var messagePaddingLength = CalculatePadding(messageData.Length); - await Buffers.RentReturnAsync(4, async (buffer) => + int prefixSize = _options.WriteLegacyIpcFormat ? 4 : 8; + + await Buffers.RentReturnAsync(prefixSize, async (buffer) => { + Memory currentBufferPosition = buffer; + if (!_options.WriteLegacyIpcFormat) + { + BinaryPrimitives.WriteInt32LittleEndian( + currentBufferPosition.Span, MessageSerializer.IpcContinuationToken); + currentBufferPosition = currentBufferPosition.Slice(4); + } + var metadataSize = messageData.Length + messagePaddingLength; - BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, metadataSize); + BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, metadataSize); await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); }).ConfigureAwait(false); @@ -360,7 +378,7 @@ await Buffers.RentReturnAsync(4, async (buffer) => checked { - return 4 + messageData.Length + messagePaddingLength; + return prefixSize + messageData.Length + messagePaddingLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs new file mode 100644 index 00000000000..7f37ab148f9 --- /dev/null +++ b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Ipc +{ + public class IpcOptions + { + internal static readonly IpcOptions Default = new IpcOptions(); + + /// + /// Write the pre-0.15.0 encapsulated IPC message format + /// consisting of a 4-byte prefix instead of 8 byte. + /// + public bool WriteLegacyIpcFormat { get; set; } + + public IpcOptions() + { + } + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 7478e00e133..612f14739c0 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -20,6 +20,7 @@ namespace Apache.Arrow.Ipc { internal class MessageSerializer { + public const int IpcContinuationToken = -1; public static Types.NumberType GetNumberType(int bitWidth, bool signed) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 06be8bd6504..170bea8ab1e 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -16,6 +16,7 @@ using Apache.Arrow.Ipc; using Apache.Arrow.Types; using System; +using System.Buffers.Binary; using System.IO; using System.Linq; using System.Net; @@ -122,11 +123,11 @@ public async Task WriteBatchWithNulls() await TestRoundTripRecordBatch(originalBatch); } - private static async Task TestRoundTripRecordBatch(RecordBatch originalBatch) + private static async Task TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { - using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true)) + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) { await writer.WriteRecordBatchAsync(originalBatch); } @@ -195,5 +196,45 @@ public async Task WriteBatchWithCorrectPadding() } } } + + [Fact] + public async Task LegacyIpcFormatRoundTrips() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + await TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; + + using (MemoryStream stream = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + { + await writer.WriteRecordBatchAsync(originalBatch); + } + + stream.Position = 0; + + // ensure the continuation is written correctly + byte[] buffer = stream.GetBuffer(); + int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer); + if (writeLegacyIpcFormat) + { + // the legacy IPC format doesn't have a continuation token at the start + Assert.NotEqual(-1, messageLength); + } + else + { + // the latest IPC format has a continuation token at the start + Assert.Equal(-1, messageLength); + } + } + } } } From 231e90cbe1710f7b6b2fb878ef27f84546cfe041 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Wed, 11 Sep 2019 16:32:34 -0500 Subject: [PATCH 2/2] Implement WriteEndAsync on ArrowStreamWriter to write the EOS signal. Remove WriteFooterAsync on ArrowFileWriter and instead use WriteEndAsync from the base class. This basically renames WriteFooterAsync to WriteEndAsync for the file writer. Fix a bug in ArrowFileWriter - we now write the EOS signal before the footer, which is specified in https://arrow.apache.org/docs/format/IPC.html. --- .../src/Apache.Arrow/Ipc/ArrowFileWriter.cs | 12 ++--- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 53 +++++++++++++------ csharp/src/Apache.Arrow/Ipc/IpcOptions.cs | 5 ++ .../ArrowFileReaderTests.cs | 6 +-- .../ArrowStreamReaderTests.cs | 50 ++++++++++++----- .../ArrowStreamWriterTests.cs | 35 +++++++++--- 6 files changed, 111 insertions(+), 50 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 7c3c698492f..26ba600da6f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -28,7 +28,6 @@ public class ArrowFileWriter: ArrowStreamWriter private long _currentRecordBatchOffset = -1; private bool HasWrittenHeader { get; set; } - private bool HasWrittenFooter { get; set; } private List RecordBatchBlocks { get; } @@ -58,7 +57,6 @@ public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen, IpcOptions } HasWrittenHeader = false; - HasWrittenFooter = false; RecordBatchBlocks = new List(); } @@ -106,15 +104,11 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long _currentRecordBatchOffset = -1; } - public async Task WriteFooterAsync(CancellationToken cancellationToken = default) + private protected override async ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) { - if (!HasWrittenFooter) - { - await WriteFooterAsync(Schema, cancellationToken).ConfigureAwait(false); - HasWrittenFooter = true; - } + await base.WriteEndInternalAsync(cancellationToken); - await BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + await WriteFooterAsync(Schema, cancellationToken); } private async Task WriteHeaderAsync(CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 5306ba70d7d..d429a55cc17 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -142,6 +142,8 @@ public void Visit(IArrowArray array) protected bool HasWrittenSchema { get; set; } + private bool HasWrittenEnd { get; set; } + protected Schema Schema { get; } private readonly bool _leaveOpen; @@ -270,6 +272,11 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); } + private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) + { + return WriteIpcMessageLengthAsync(length: 0, cancellationToken); + } + private protected virtual void StartingWritingRecordBatch() { } @@ -282,6 +289,15 @@ public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationT { return WriteRecordBatchInternalAsync(recordBatch, cancellationToken); } + + public async Task WriteEndAsync(CancellationToken cancellationToken = default) + { + if (!HasWrittenEnd) + { + await WriteEndInternalAsync(cancellationToken); + HasWrittenEnd = true; + } + } private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default) { @@ -356,29 +372,15 @@ private async ValueTask WriteMessageAsync( var messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); var messagePaddingLength = CalculatePadding(messageData.Length); - int prefixSize = _options.WriteLegacyIpcFormat ? 4 : 8; - - await Buffers.RentReturnAsync(prefixSize, async (buffer) => - { - Memory currentBufferPosition = buffer; - if (!_options.WriteLegacyIpcFormat) - { - BinaryPrimitives.WriteInt32LittleEndian( - currentBufferPosition.Span, MessageSerializer.IpcContinuationToken); - currentBufferPosition = currentBufferPosition.Slice(4); - } - - var metadataSize = messageData.Length + messagePaddingLength; - BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, metadataSize); - await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); - }).ConfigureAwait(false); + await WriteIpcMessageLengthAsync(messageData.Length + messagePaddingLength, cancellationToken) + .ConfigureAwait(false); await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false); await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false); checked { - return prefixSize + messageData.Length + messagePaddingLength; + return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength; } } @@ -389,6 +391,23 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false); } + private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken) + { + await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) => + { + Memory currentBufferPosition = buffer; + if (!_options.WriteLegacyIpcFormat) + { + BinaryPrimitives.WriteInt32LittleEndian( + currentBufferPosition.Span, MessageSerializer.IpcContinuationToken); + currentBufferPosition = currentBufferPosition.Slice(sizeof(int)); + } + + BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length); + await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); + }).ConfigureAwait(false); + } + protected int CalculatePadding(long offset, int alignment = 8) { long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset; diff --git a/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs index 7f37ab148f9..2f43d9800f6 100644 --- a/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs +++ b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs @@ -28,5 +28,10 @@ public class IpcOptions public IpcOptions() { } + + /// + /// Gets the number of bytes used in the IPC message prefix. + /// + internal int SizeOfIpcLength => WriteLegacyIpcFormat ? 4 : 8; } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs index 8051bf421a6..ec62ac2627c 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs @@ -59,7 +59,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) { ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; var memoryPool = new TestMemoryAllocator(); @@ -121,7 +121,7 @@ private static async Task TestReadRecordBatchHelper( { ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; ArrowFileReader reader = new ArrowFileReader(stream); @@ -140,7 +140,7 @@ public async Task TestReadMultipleRecordBatchAsync() ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch1.Schema); await writer.WriteRecordBatchAsync(originalBatch1); await writer.WriteRecordBatchAsync(originalBatch2); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; // the recordbatches by index are in reverse order - back to front. diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index 914d91e4ae6..a74a2794188 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -60,6 +60,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) { ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); stream.Position = 0; @@ -83,23 +84,29 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) } } - [Fact] - public async Task ReadRecordBatch_Memory() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatch_Memory(bool writeEnd) { await TestReaderFromMemory((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, writeEnd); } - [Fact] - public async Task ReadRecordBatchAsync_Memory() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_Memory(bool writeEnd) { - await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); } - private static async Task TestReaderFromMemory(Func verificationFunc) + private static async Task TestReaderFromMemory( + Func verificationFunc, + bool writeEnd) { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); @@ -108,6 +115,10 @@ private static async Task TestReaderFromMemory(Func { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, writeEnd); } - [Fact] - public async Task ReadRecordBatchAsync_Stream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_Stream(bool writeEnd) { - await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); } - private static async Task TestReaderFromStream(Func verificationFunc) + private static async Task TestReaderFromStream( + Func verificationFunc, + bool writeEnd) { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); @@ -139,6 +156,10 @@ private static async Task TestReaderFromStream(Func 8; i--) + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]); + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]); + for (int i = 22; i > 16; i--) { Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); } - Assert.Equal(value2, writtenBytes[writtenBytes.Length - 8]); - Assert.Equal(value2, writtenBytes[writtenBytes.Length - 7]); - for (int i = 6; i > 0; i--) + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]); + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]); + for (int i = 14; i > 8; i--) { Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); } + + // verify the EOS is written correctly + for (int i = 8; i > 4; i--) + { + Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]); + } + for (int i = 4; i > 0; i--) + { + Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]); + } } } @@ -217,23 +231,30 @@ public async Task WriteLegacyIpcFormat(bool writeLegacyIpcFormat) using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) { await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); } stream.Position = 0; // ensure the continuation is written correctly - byte[] buffer = stream.GetBuffer(); + byte[] buffer = stream.ToArray(); int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer); + int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8)); + int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4)); if (writeLegacyIpcFormat) { // the legacy IPC format doesn't have a continuation token at the start Assert.NotEqual(-1, messageLength); + Assert.NotEqual(-1, endOfBuffer1); } else { // the latest IPC format has a continuation token at the start Assert.Equal(-1, messageLength); + Assert.Equal(-1, endOfBuffer1); } + + Assert.Equal(0, endOfBuffer2); } } }