diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 265c31f75a7..26ba600da6f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -28,7 +28,6 @@ public class ArrowFileWriter: ArrowStreamWriter private long _currentRecordBatchOffset = -1; private bool HasWrittenHeader { get; set; } - private bool HasWrittenFooter { get; set; } private List RecordBatchBlocks { get; } @@ -38,7 +37,12 @@ public ArrowFileWriter(Stream stream, Schema schema) } public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen) - : base(stream, schema, leaveOpen) + : this(stream, schema, leaveOpen, options: null) + { + } + + public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen, IpcOptions options) + : base(stream, schema, leaveOpen, options) { if (!stream.CanWrite) { @@ -53,7 +57,6 @@ public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen) } HasWrittenHeader = false; - HasWrittenFooter = false; RecordBatchBlocks = new List(); } @@ -101,15 +104,11 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long _currentRecordBatchOffset = -1; } - public async Task WriteFooterAsync(CancellationToken cancellationToken = default) + private protected override async ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) { - if (!HasWrittenFooter) - { - await WriteFooterAsync(Schema, cancellationToken).ConfigureAwait(false); - HasWrittenFooter = true; - } + await base.WriteEndInternalAsync(cancellationToken); - await BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + await WriteFooterAsync(Schema, cancellationToken); } private async Task WriteHeaderAsync(CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index a39a87bf7f2..c265fa1b0f1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -17,6 +17,7 @@ using FlatBuffers; using System; using System.Buffers.Binary; +using System.IO; using System.Threading; using System.Threading.Tasks; @@ -57,6 +58,23 @@ public override RecordBatch ReadNextRecordBatch() //reached the end return null; } + else if (messageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + + if (messageLength == 0) + { + //reached the end + return null; + } + } Message message = Message.GetRootAsMessage( CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); @@ -80,6 +98,18 @@ private void ReadSchema() int schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); _bufferPosition += sizeof(int); + if (schemaMessageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + schemaMessageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + } + ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer)); _bufferPosition += schemaMessageLength; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index c2ac4bbfeef..1a6d728fd64 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -59,18 +59,8 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - int messageLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => - { - // Get Length of record batch for message header. - int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) - .ConfigureAwait(false); - - if (bytesRead == 4) - { - messageLength = BitUtility.ReadInt32(lengthBuffer); - } - }).ConfigureAwait(false); + int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) + .ConfigureAwait(false); if (messageLength == 0) { @@ -106,16 +96,7 @@ protected RecordBatch ReadRecordBatch() { ReadSchema(); - int messageLength = 0; - ArrayPool.Shared.RentReturn(4, lengthBuffer => - { - int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); - - if (bytesRead == 4) - { - messageLength = BitUtility.ReadInt32(lengthBuffer); - } - }); + int messageLength = ReadMessageLength(throwOnFullRead: false); if (messageLength == 0) { @@ -153,14 +134,8 @@ protected virtual async ValueTask ReadSchemaAsync() } // Figure out length of schema - int schemaMessageLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => - { - int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer).ConfigureAwait(false); - EnsureFullRead(lengthBuffer, bytesRead); - - schemaMessageLength = BitUtility.ReadInt32(lengthBuffer); - }).ConfigureAwait(false); + int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true) + .ConfigureAwait(false); await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) => { @@ -181,14 +156,7 @@ protected virtual void ReadSchema() } // Figure out length of schema - int schemaMessageLength = 0; - ArrayPool.Shared.RentReturn(4, lengthBuffer => - { - int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); - EnsureFullRead(lengthBuffer, bytesRead); - - schemaMessageLength = BitUtility.ReadInt32(lengthBuffer); - }); + int schemaMessageLength = ReadMessageLength(throwOnFullRead: true); ArrayPool.Shared.RentReturn(schemaMessageLength, buff => { @@ -200,6 +168,84 @@ protected virtual void ReadSchema() }); } + private async ValueTask ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default) + { + int messageLength = 0; + await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => + { + int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) + .ConfigureAwait(false); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (messageLength == MessageSerializer.IpcContinuationToken) + { + bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) + .ConfigureAwait(false); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + messageLength = 0; + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + } + }).ConfigureAwait(false); + + return messageLength; + } + + private int ReadMessageLength(bool throwOnFullRead) + { + int messageLength = 0; + ArrayPool.Shared.RentReturn(4, lengthBuffer => + { + int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (messageLength == MessageSerializer.IpcContinuationToken) + { + bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); + if (throwOnFullRead) + { + EnsureFullRead(lengthBuffer, bytesRead); + } + else if (bytesRead != 4) + { + messageLength = 0; + return; + } + + messageLength = BitUtility.ReadInt32(lengthBuffer); + } + }); + + return messageLength; + } + /// /// Ensures the number of bytes read matches the buffer length /// and throws an exception it if doesn't. This ensures we have read diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index e1da4489ce4..d429a55cc17 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -142,9 +142,12 @@ public void Visit(IArrowArray array) protected bool HasWrittenSchema { get; set; } + private bool HasWrittenEnd { get; set; } + protected Schema Schema { get; } private readonly bool _leaveOpen; + private readonly IpcOptions _options; private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4; @@ -152,11 +155,17 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; - public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) + public ArrowStreamWriter(Stream baseStream, Schema schema) + : this(baseStream, schema, leaveOpen: false) { } public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen) + : this(baseStream, schema, leaveOpen, options: null) + { + } + + public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOptions options) { BaseStream = baseStream ?? throw new ArgumentNullException(nameof(baseStream)); Schema = schema ?? throw new ArgumentNullException(nameof(schema)); @@ -167,6 +176,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen) HasWrittenSchema = false; _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder); + _options = options ?? IpcOptions.Default; } private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, @@ -262,6 +272,11 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); } + private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) + { + return WriteIpcMessageLengthAsync(length: 0, cancellationToken); + } + private protected virtual void StartingWritingRecordBatch() { } @@ -274,6 +289,15 @@ public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationT { return WriteRecordBatchInternalAsync(recordBatch, cancellationToken); } + + public async Task WriteEndAsync(CancellationToken cancellationToken = default) + { + if (!HasWrittenEnd) + { + await WriteEndInternalAsync(cancellationToken); + HasWrittenEnd = true; + } + } private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default) { @@ -348,19 +372,15 @@ private async ValueTask WriteMessageAsync( var messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); var messagePaddingLength = CalculatePadding(messageData.Length); - await Buffers.RentReturnAsync(4, async (buffer) => - { - var metadataSize = messageData.Length + messagePaddingLength; - BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, metadataSize); - await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); - }).ConfigureAwait(false); + await WriteIpcMessageLengthAsync(messageData.Length + messagePaddingLength, cancellationToken) + .ConfigureAwait(false); await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false); await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false); checked { - return 4 + messageData.Length + messagePaddingLength; + return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength; } } @@ -371,6 +391,23 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false); } + private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken) + { + await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) => + { + Memory currentBufferPosition = buffer; + if (!_options.WriteLegacyIpcFormat) + { + BinaryPrimitives.WriteInt32LittleEndian( + currentBufferPosition.Span, MessageSerializer.IpcContinuationToken); + currentBufferPosition = currentBufferPosition.Slice(sizeof(int)); + } + + BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length); + await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); + }).ConfigureAwait(false); + } + protected int CalculatePadding(long offset, int alignment = 8) { long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset; diff --git a/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs new file mode 100644 index 00000000000..2f43d9800f6 --- /dev/null +++ b/csharp/src/Apache.Arrow/Ipc/IpcOptions.cs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Ipc +{ + public class IpcOptions + { + internal static readonly IpcOptions Default = new IpcOptions(); + + /// + /// Write the pre-0.15.0 encapsulated IPC message format + /// consisting of a 4-byte prefix instead of 8 byte. + /// + public bool WriteLegacyIpcFormat { get; set; } + + public IpcOptions() + { + } + + /// + /// Gets the number of bytes used in the IPC message prefix. + /// + internal int SizeOfIpcLength => WriteLegacyIpcFormat ? 4 : 8; + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 7478e00e133..612f14739c0 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -20,6 +20,7 @@ namespace Apache.Arrow.Ipc { internal class MessageSerializer { + public const int IpcContinuationToken = -1; public static Types.NumberType GetNumberType(int bitWidth, bool signed) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs index 8051bf421a6..ec62ac2627c 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs @@ -59,7 +59,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) { ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; var memoryPool = new TestMemoryAllocator(); @@ -121,7 +121,7 @@ private static async Task TestReadRecordBatchHelper( { ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; ArrowFileReader reader = new ArrowFileReader(stream); @@ -140,7 +140,7 @@ public async Task TestReadMultipleRecordBatchAsync() ArrowFileWriter writer = new ArrowFileWriter(stream, originalBatch1.Schema); await writer.WriteRecordBatchAsync(originalBatch1); await writer.WriteRecordBatchAsync(originalBatch2); - await writer.WriteFooterAsync(); + await writer.WriteEndAsync(); stream.Position = 0; // the recordbatches by index are in reverse order - back to front. diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index 914d91e4ae6..a74a2794188 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -60,6 +60,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) { ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema); await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); stream.Position = 0; @@ -83,23 +84,29 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) } } - [Fact] - public async Task ReadRecordBatch_Memory() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatch_Memory(bool writeEnd) { await TestReaderFromMemory((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, writeEnd); } - [Fact] - public async Task ReadRecordBatchAsync_Memory() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_Memory(bool writeEnd) { - await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); } - private static async Task TestReaderFromMemory(Func verificationFunc) + private static async Task TestReaderFromMemory( + Func verificationFunc, + bool writeEnd) { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); @@ -108,6 +115,10 @@ private static async Task TestReaderFromMemory(Func { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, writeEnd); } - [Fact] - public async Task ReadRecordBatchAsync_Stream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_Stream(bool writeEnd) { - await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); } - private static async Task TestReaderFromStream(Func verificationFunc) + private static async Task TestReaderFromStream( + Func verificationFunc, + bool writeEnd) { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); @@ -139,6 +156,10 @@ private static async Task TestReaderFromStream(Func 8; i--) + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]); + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]); + for (int i = 22; i > 16; i--) { Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); } - Assert.Equal(value2, writtenBytes[writtenBytes.Length - 8]); - Assert.Equal(value2, writtenBytes[writtenBytes.Length - 7]); - for (int i = 6; i > 0; i--) + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]); + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]); + for (int i = 14; i > 8; i--) { Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); } + + // verify the EOS is written correctly + for (int i = 8; i > 4; i--) + { + Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]); + } + for (int i = 4; i > 0; i--) + { + Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]); + } + } + } + + [Fact] + public async Task LegacyIpcFormatRoundTrips() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + await TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; + + using (MemoryStream stream = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + { + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + } + + stream.Position = 0; + + // ensure the continuation is written correctly + byte[] buffer = stream.ToArray(); + int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer); + int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8)); + int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4)); + if (writeLegacyIpcFormat) + { + // the legacy IPC format doesn't have a continuation token at the start + Assert.NotEqual(-1, messageLength); + Assert.NotEqual(-1, endOfBuffer1); + } + else + { + // the latest IPC format has a continuation token at the start + Assert.Equal(-1, messageLength); + Assert.Equal(-1, endOfBuffer1); + } + + Assert.Equal(0, endOfBuffer2); } } }