diff --git a/csharp/src/Apache.Arrow/Apache.Arrow.csproj b/csharp/src/Apache.Arrow/Apache.Arrow.csproj index a7d2b30772e3..62574029f97c 100644 --- a/csharp/src/Apache.Arrow/Apache.Arrow.csproj +++ b/csharp/src/Apache.Arrow/Apache.Arrow.csproj @@ -1,4 +1,4 @@ - + netstandard1.3;netcoreapp2.1 @@ -37,5 +37,6 @@ + diff --git a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs index f51dc53b0979..efcacdc844d9 100644 --- a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs +++ b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs @@ -25,5 +25,10 @@ public static int Read(this Stream stream, Memory buffer) { return stream.Read(buffer.Span); } + + public static void Write(this Stream stream, ReadOnlyMemory buffer) + { + stream.Write(buffer.Span); + } } } diff --git a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs index ce23bd1eb7bd..b983be0fd0da 100644 --- a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs +++ b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs @@ -74,6 +74,27 @@ async ValueTask FinishReadAsync(Task readTask, byte[] localBuffer, Mem } } + public static void Write(this Stream stream, ReadOnlyMemory buffer) + { + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment array)) + { + stream.Write(array.Array, array.Offset, array.Count); + } + else + { + byte[] sharedBuffer = ArrayPool.Shared.Rent(buffer.Length); + try + { + buffer.Span.CopyTo(sharedBuffer); + stream.Write(sharedBuffer, 0, buffer.Length); + } + finally + { + ArrayPool.Shared.Return(sharedBuffer); + } + } + } + public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { if (MemoryMarshal.TryGetArray(buffer, out ArraySegment array)) diff --git a/csharp/src/Apache.Arrow/Extensions/TupleExtensions.netstandard.cs b/csharp/src/Apache.Arrow/Extensions/TupleExtensions.netstandard.cs new file mode 100644 index 000000000000..fe42075f14f7 --- /dev/null +++ b/csharp/src/Apache.Arrow/Extensions/TupleExtensions.netstandard.cs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; + +namespace Apache.Arrow +{ + // Helpers to Deconstruct Tuples on netstandard + internal static partial class TupleExtensions + { + public static void Deconstruct(this Tuple value, out T1 item1, out T2 item2) + { + item1 = value.Item1; + item2 = value.Item2; + } + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 756ebfa3c35c..a1da2925f340 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -61,6 +61,19 @@ public ArrowFileWriter(Stream stream, Schema schema, bool leaveOpen, IpcOptions RecordBatchBlocks = new List(); } + public override void WriteRecordBatch(RecordBatch recordBatch) + { + // TODO: Compare record batch schema + + if (!HasWrittenHeader) + { + WriteHeader(); + HasWrittenHeader = true; + } + + WriteRecordBatchInternal(recordBatch); + } + public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) { // TODO: Compare record batch schema @@ -104,6 +117,13 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long _currentRecordBatchOffset = -1; } + private protected override void WriteEndInternal() + { + base.WriteEndInternal(); + + WriteFooter(Schema); + } + private protected override async ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) { await base.WriteEndInternalAsync(cancellationToken); @@ -111,6 +131,14 @@ private protected override async ValueTask WriteEndInternalAsync(CancellationTok await WriteFooterAsync(Schema, cancellationToken); } + private void WriteHeader() + { + // Write magic number and empty padding up to the 8-byte boundary + + WriteMagic(); + WritePadding(CalculatePadding(ArrowFileConstants.Magic.Length)); + } + private async Task WriteHeaderAsync(CancellationToken cancellationToken) { // Write magic number and empty padding up to the 8-byte boundary @@ -120,6 +148,64 @@ await WritePaddingAsync(CalculatePadding(ArrowFileConstants.Magic.Length)) .ConfigureAwait(false); } + private void WriteFooter(Schema schema) + { + Builder.Clear(); + + long offset = BaseStream.Position; + + // Serialize the schema + + FlatBuffers.Offset schemaOffset = SerializeSchema(schema); + + // Serialize all record batches + + Flatbuf.Footer.StartRecordBatchesVector(Builder, RecordBatchBlocks.Count); + + foreach (Block recordBatch in RecordBatchBlocks) + { + Flatbuf.Block.CreateBlock( + Builder, recordBatch.Offset, recordBatch.MetadataLength, recordBatch.BodyLength); + } + + FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector(); + + // Serialize all dictionaries + // NOTE: Currently unsupported. + + Flatbuf.Footer.StartDictionariesVector(Builder, 0); + + FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector(); + + // Serialize and write the footer flatbuffer + + FlatBuffers.Offset footerOffset = Flatbuf.Footer.CreateFooter(Builder, CurrentMetadataVersion, + schemaOffset, dictionaryBatchesOffset, recordBatchesVectorOffset); + + Builder.Finish(footerOffset.Value); + + WriteFlatBuffer(); + + // Write footer length + + Buffers.RentReturn(4, (buffer) => + { + int footerLength; + checked + { + footerLength = (int)(BaseStream.Position - offset); + } + + BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, footerLength); + + BaseStream.Write(buffer); + }); + + // Write magic + + WriteMagic(); + } + private async Task WriteFooterAsync(Schema schema, CancellationToken cancellationToken) { Builder.Clear(); @@ -182,6 +268,11 @@ await Buffers.RentReturnAsync(4, async (buffer) => await WriteMagicAsync(cancellationToken).ConfigureAwait(false); } + private void WriteMagic() + { + BaseStream.Write(ArrowFileConstants.Magic); + } + private ValueTask WriteMagicAsync(CancellationToken cancellationToken) { return BaseStream.WriteAsync(ArrowFileConstants.Magic, cancellationToken); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index cb2b920d61c2..9d81a8ef8bbb 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -212,53 +212,79 @@ private void CountSelfAndChildrenNodes(IArrowType type, ref int count) count++; } - private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, - CancellationToken cancellationToken = default) + private protected void WriteRecordBatchInternal(RecordBatch recordBatch) { // TODO: Truncate buffers with extraneous padding / unused capacity if (!HasWrittenSchema) { - await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false); + WriteSchema(Schema); HasWrittenSchema = true; } - Builder.Clear(); + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = + PreparingWritingRecordBatch(recordBatch); - // Serialize field nodes + VectorOffset buffersVectorOffset = Builder.EndVector(); - int fieldCount = Schema.Fields.Count; + // Serialize record batch - Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes()); + StartingWritingRecordBatch(); - // flatbuffer struct vectors have to be created in reverse order - for (int i = fieldCount - 1; i >= 0; i--) - { - CreateSelfAndChildrenFieldNodes(recordBatch.Column(i).Data); - } + Offset recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(Builder, recordBatch.Length, + fieldNodesVectorOffset, + buffersVectorOffset); - VectorOffset fieldNodesVectorOffset = Builder.EndVector(); + long metadataLength = WriteMessage(Flatbuf.MessageHeader.RecordBatch, + recordBatchOffset, recordBatchBuilder.TotalLength); - // Serialize buffers + // Write buffer data - var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder(); - for (int i = 0; i < fieldCount; i++) + IReadOnlyList buffers = recordBatchBuilder.Buffers; + + long bodyLength = 0; + + for (int i = 0; i < buffers.Count; i++) { - IArrowArray fieldArray = recordBatch.Column(i); - fieldArray.Accept(recordBatchBuilder); + ArrowBuffer buffer = buffers[i].DataBuffer; + if (buffer.IsEmpty) + continue; + + WriteBuffer(buffer); + + int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); + int padding = paddedLength - buffer.Length; + if (padding > 0) + { + WritePadding(padding); + } + + bodyLength += paddedLength; } - IReadOnlyList buffers = recordBatchBuilder.Buffers; + // Write padding so the record batch message body length is a multiple of 8 bytes - Flatbuf.RecordBatch.StartBuffersVector(Builder, buffers.Count); + int bodyPaddingLength = CalculatePadding(bodyLength); - // flatbuffer struct vectors have to be created in reverse order - for (int i = buffers.Count - 1; i >= 0; i--) + WritePadding(bodyPaddingLength); + + FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); + } + + private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, + CancellationToken cancellationToken = default) + { + // TODO: Truncate buffers with extraneous padding / unused capacity + + if (!HasWrittenSchema) { - Flatbuf.Buffer.CreateBuffer(Builder, - buffers[i].Offset, buffers[i].DataBuffer.Length); + await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false); + HasWrittenSchema = true; } + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = + PreparingWritingRecordBatch(recordBatch); + VectorOffset buffersVectorOffset = Builder.EndVector(); // Serialize record batch @@ -275,6 +301,8 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat // Write buffer data + IReadOnlyList buffers = recordBatchBuilder.Buffers; + long bodyLength = 0; for (int i = 0; i < buffers.Count; i++) @@ -304,6 +332,52 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); } + private Tuple PreparingWritingRecordBatch(RecordBatch recordBatch) + { + Builder.Clear(); + + // Serialize field nodes + + int fieldCount = Schema.Fields.Count; + + Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes()); + + // flatbuffer struct vectors have to be created in reverse order + for (int i = fieldCount - 1; i >= 0; i--) + { + CreateSelfAndChildrenFieldNodes(recordBatch.Column(i).Data); + } + + VectorOffset fieldNodesVectorOffset = Builder.EndVector(); + + // Serialize buffers + + var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder(); + for (int i = 0; i < fieldCount; i++) + { + IArrowArray fieldArray = recordBatch.Column(i); + fieldArray.Accept(recordBatchBuilder); + } + + IReadOnlyList buffers = recordBatchBuilder.Buffers; + + Flatbuf.RecordBatch.StartBuffersVector(Builder, buffers.Count); + + // flatbuffer struct vectors have to be created in reverse order + for (int i = buffers.Count - 1; i >= 0; i--) + { + Flatbuf.Buffer.CreateBuffer(Builder, + buffers[i].Offset, buffers[i].DataBuffer.Length); + } + + return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset); + } + + private protected virtual void WriteEndInternal() + { + WriteIpcMessageLength(length: 0); + } + private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken) { return WriteIpcMessageLengthAsync(length: 0, cancellationToken); @@ -317,11 +391,25 @@ private protected virtual void FinishedWritingRecordBatch(long bodyLength, long { } + public virtual void WriteRecordBatch(RecordBatch recordBatch) + { + WriteRecordBatchInternal(recordBatch); + } + public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) { return WriteRecordBatchInternalAsync(recordBatch, cancellationToken); } + public void WriteEnd() + { + if (!HasWrittenEnd) + { + WriteEndInternal(); + HasWrittenEnd = true; + } + } + public async Task WriteEndAsync(CancellationToken cancellationToken = default) { if (!HasWrittenEnd) @@ -331,6 +419,11 @@ public async Task WriteEndAsync(CancellationToken cancellationToken = default) } } + private void WriteBuffer(ArrowBuffer arrowBuffer) + { + BaseStream.Write(arrowBuffer.Memory); + } + private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default) { return BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken); @@ -391,6 +484,21 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca return children; } + private Offset WriteSchema(Schema schema) + { + Builder.Clear(); + + // Build schema + + Offset schemaOffset = SerializeSchema(schema); + + // Build message + + WriteMessage(Flatbuf.MessageHeader.Schema, schemaOffset, 0); + + return schemaOffset; + } + private async ValueTask> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken) { Builder.Clear(); @@ -407,6 +515,36 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat return schemaOffset; } + /// + /// Writes the message to the . + /// + /// + /// The number of bytes written to the stream. + /// + private long WriteMessage( + Flatbuf.MessageHeader headerType, Offset headerOffset, int bodyLength) + where T : struct + { + Offset messageOffset = Flatbuf.Message.CreateMessage( + Builder, CurrentMetadataVersion, headerType, headerOffset.Value, + bodyLength); + + Builder.Finish(messageOffset.Value); + + ReadOnlyMemory messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); + int messagePaddingLength = CalculatePadding(_options.SizeOfIpcLength + messageData.Length); + + WriteIpcMessageLength(messageData.Length + messagePaddingLength); + + BaseStream.Write(messageData); + WritePadding(messagePaddingLength); + + checked + { + return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength; + } + } + /// /// Writes the message to the . /// @@ -439,6 +577,13 @@ await WriteIpcMessageLengthAsync(messageData.Length + messagePaddingLength, canc } } + private protected void WriteFlatBuffer() + { + ReadOnlyMemory segment = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); + + BaseStream.Write(segment); + } + private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancellationToken = default) { ReadOnlyMemory segment = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset); @@ -446,6 +591,23 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false); } + private void WriteIpcMessageLength(int length) + { + Buffers.RentReturn(_options.SizeOfIpcLength, (buffer) => + { + Memory currentBufferPosition = buffer; + if (!_options.WriteLegacyIpcFormat) + { + BinaryPrimitives.WriteInt32LittleEndian( + currentBufferPosition.Span, MessageSerializer.IpcContinuationToken); + currentBufferPosition = currentBufferPosition.Slice(sizeof(int)); + } + + BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length); + BaseStream.Write(buffer); + }); + } + private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken) { await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) => @@ -472,6 +634,14 @@ protected int CalculatePadding(long offset, int alignment = 8) } } + private protected void WritePadding(int length) + { + if (length > 0) + { + BaseStream.Write(s_padding.AsMemory(0, Math.Min(s_padding.Length, length))); + } + } + private protected ValueTask WritePaddingAsync(int length) { if (length > 0) diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs index a8d6f3f95f96..49ea6f9429be 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs @@ -61,6 +61,34 @@ public async Task WritesFooterAlignedMulitpleOf8() { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var stream = new MemoryStream(); + var writer = new ArrowFileWriter( + stream, + originalBatch.Schema, + leaveOpen: true, + // use WriteLegacyIpcFormat, which only uses a 4-byte length prefix + // which causes the length prefix to not be 8-byte aligned by default + new IpcOptions() { WriteLegacyIpcFormat = true }); + + writer.WriteRecordBatch(originalBatch); + writer.WriteEnd(); + + stream.Position = 0; + + await ValidateRecordBatchFile(stream, originalBatch); + } + + /// + /// Tests that writing an arrow file will always align the Block lengths + /// to 8 bytes. There are asserts in both the reader and writer which will fail + /// if this isn't the case. + /// + /// + [Fact] + public async Task WritesFooterAlignedMulitpleOf8Async() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var stream = new MemoryStream(); var writer = new ArrowFileWriter( stream, @@ -75,11 +103,16 @@ public async Task WritesFooterAlignedMulitpleOf8() stream.Position = 0; + await ValidateRecordBatchFile(stream, originalBatch); + } + + private async Task ValidateRecordBatchFile(Stream stream, RecordBatch recordBatch) + { var reader = new ArrowFileReader(stream); int count = await reader.RecordBatchCountAsync(); Assert.Equal(1, count); RecordBatch readBatch = await reader.ReadRecordBatchAsync(0); - ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + ArrowReaderVerifier.CompareBatches(recordBatch, readBatch); } } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 3387289e97ba..b65d7353186f 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -56,7 +56,41 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() } [Fact] - public async Task CanWriteToNetworkStream() + public void CanWriteToNetworkStream() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + + const int port = 32153; + TcpListener listener = new TcpListener(IPAddress.Loopback, port); + listener.Start(); + + using (TcpClient sender = new TcpClient()) + { + sender.Connect(IPAddress.Loopback, port); + NetworkStream stream = sender.GetStream(); + + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema)) + { + writer.WriteRecordBatch(originalBatch); + writer.WriteEnd(); + + stream.Flush(); + } + } + + using (TcpClient receiver = listener.AcceptTcpClient()) + { + NetworkStream stream = receiver.GetStream(); + using (var reader = new ArrowStreamReader(stream)) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } + } + } + + [Fact] + public async Task CanWriteToNetworkStreamAsync() { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); @@ -90,15 +124,23 @@ public async Task CanWriteToNetworkStream() } [Fact] - public async Task WriteEmptyBatch() + public void WriteEmptyBatch() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); + + TestRoundTripRecordBatch(originalBatch); + } + + [Fact] + public async Task WriteEmptyBatchAsync() { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); - await TestRoundTripRecordBatch(originalBatch); + await TestRoundTripRecordBatchAsync(originalBatch); } [Fact] - public async Task WriteBatchWithNulls() + public void WriteBatchWithNulls() { RecordBatch originalBatch = new RecordBatch.Builder() .Append("Column1", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(0, 10)))) @@ -122,10 +164,59 @@ public async Task WriteBatchWithNulls() offset: 0)) .Build(); - await TestRoundTripRecordBatch(originalBatch); + TestRoundTripRecordBatch(originalBatch); } - private static async Task TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + [Fact] + public async Task WriteBatchWithNullsAsync() + { + RecordBatch originalBatch = new RecordBatch.Builder() + .Append("Column1", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(0, 10)))) + .Append("Column2", true, new Int32Array( + valueBuffer: new ArrowBuffer.Builder().AppendRange(Enumerable.Range(0, 10)).Build(), + nullBitmapBuffer: new ArrowBuffer.Builder().Append(0xfd).Append(0xff).Build(), + length: 10, + nullCount: 2, + offset: 0)) + .Append("Column3", true, new Int32Array( + valueBuffer: new ArrowBuffer.Builder().AppendRange(Enumerable.Range(0, 10)).Build(), + nullBitmapBuffer: new ArrowBuffer.Builder().Append(0x00).Append(0x00).Build(), + length: 10, + nullCount: 10, + offset: 0)) + .Append("NullableBooleanColumn", true, new BooleanArray( + valueBuffer: new ArrowBuffer.Builder().Append(0xfd).Append(0xff).Build(), + nullBitmapBuffer: new ArrowBuffer.Builder().Append(0xed).Append(0xff).Build(), + length: 10, + nullCount: 3, + offset: 0)) + .Build(); + + await TestRoundTripRecordBatchAsync(originalBatch); + } + + private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + { + using (MemoryStream stream = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + { + writer.WriteRecordBatch(originalBatch); + writer.WriteEnd(); + } + + stream.Position = 0; + + using (var reader = new ArrowStreamReader(stream)) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } + } + } + + + private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { @@ -146,7 +237,7 @@ private static async Task TestRoundTripRecordBatch(RecordBatch originalBatch, Ip } [Fact] - public async Task WriteBatchWithCorrectPadding() + public void WriteBatchWithCorrectPadding() { byte value1 = 0x04; byte value2 = 0x14; @@ -172,7 +263,73 @@ public async Task WriteBatchWithCorrectPadding() }, length: 1); - await TestRoundTripRecordBatch(batch); + TestRoundTripRecordBatch(batch); + + using (MemoryStream stream = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch); + writer.WriteEnd(); + } + + byte[] writtenBytes = stream.ToArray(); + + // ensure that the data buffers at the end are 8-byte aligned + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]); + Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]); + for (int i = 22; i > 16; i--) + { + Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); + } + + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]); + Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]); + for (int i = 14; i > 8; i--) + { + Assert.Equal(0, writtenBytes[writtenBytes.Length - i]); + } + + // verify the EOS is written correctly + for (int i = 8; i > 4; i--) + { + Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]); + } + for (int i = 4; i > 0; i--) + { + Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]); + } + } + } + + [Fact] + public async Task WriteBatchWithCorrectPaddingAsync() + { + byte value1 = 0x04; + byte value2 = 0x14; + var batch = new RecordBatch( + new Schema.Builder() + .Field(f => f.Name("age").DataType(Int32Type.Default)) + .Field(f => f.Name("characterCount").DataType(Int32Type.Default)) + .Build(), + new IArrowArray[] + { + new Int32Array( + new ArrowBuffer(new byte[] { value1, value1, 0x00, 0x00 }), + ArrowBuffer.Empty, + length: 1, + nullCount: 0, + offset: 0), + new Int32Array( + new ArrowBuffer(new byte[] { value2, value2, 0x00, 0x00 }), + ArrowBuffer.Empty, + length: 1, + nullCount: 0, + offset: 0) + }, + length: 1); + + await TestRoundTripRecordBatchAsync(batch); using (MemoryStream stream = new MemoryStream()) { @@ -212,16 +369,64 @@ public async Task WriteBatchWithCorrectPadding() } [Fact] - public async Task LegacyIpcFormatRoundTrips() + public void LegacyIpcFormatRoundTrips() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); + } + + + [Fact] + public async Task LegacyIpcFormatRoundTripsAsync() { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); - await TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); + await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; + + using (MemoryStream stream = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + { + writer.WriteRecordBatch(originalBatch); + writer.WriteEnd(); + } + + stream.Position = 0; + + // ensure the continuation is written correctly + byte[] buffer = stream.ToArray(); + int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer); + int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8)); + int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4)); + if (writeLegacyIpcFormat) + { + // the legacy IPC format doesn't have a continuation token at the start + Assert.NotEqual(-1, messageLength); + Assert.NotEqual(-1, endOfBuffer1); + } + else + { + // the latest IPC format has a continuation token at the start + Assert.Equal(-1, messageLength); + Assert.Equal(-1, endOfBuffer1); + } + + Assert.Equal(0, endOfBuffer2); + } } [Theory] [InlineData(true)] [InlineData(false)] - public async Task WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat) { RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat };