From 3e8d322e8cf9bfc7a7977856ef5ac5b3fde6975e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 8 Jul 2020 10:52:00 -0400 Subject: [PATCH 1/4] ARROW-9362: [Java] increment default metadata version to V5 --- .../arrow/vector/ipc/ArrowFileWriter.java | 6 + .../arrow/vector/ipc/message/IpcOption.java | 5 + .../vector/ipc/message/MessageSerializer.java | 47 +- .../arrow/vector/types/MetadataVersion.java | 63 ++ .../vector/ipc/MessageSerializerTest.java | 33 +- .../arrow/vector/ipc/TestArrowFile.java | 778 ------------------ .../arrow/vector/ipc/TestRoundTrip.java | 603 ++++++++++++++ 7 files changed, 746 insertions(+), 789 deletions(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java create mode 100644 java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index fb1ca000df5..6a572c6c17e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -63,6 +63,12 @@ public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, Writa super(root, provider, out, option); } + public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, + Map metaData, IpcOption option) { + super(root, provider, out, option); + this.metaData = metaData; + } + @Override protected void startInternal(WriteChannel out) throws IOException { ArrowMagic.writeMagic(out, true); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java index 81a0603fd93..c1a93dcdd63 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java +++ 
b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java @@ -17,6 +17,8 @@ package org.apache.arrow.vector.ipc.message; +import org.apache.arrow.vector.types.MetadataVersion; + /** * IPC options, now only use for write. */ @@ -25,4 +27,7 @@ public class IpcOption { // Write the pre-0.15.0 encapsulated IPC message format // consisting of a 4-byte prefix instead of 8 byte public boolean write_legacy_ipc_format = false; + + // The metadata version. Defaults to V5. + public MetadataVersion metadataVersion = MetadataVersion.V5; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java index 59000317e62..8679088b1f4 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java @@ -161,7 +161,7 @@ public static long serialize(WriteChannel out, Schema schema, IpcOption option) long start = out.getCurrentPosition(); Preconditions.checkArgument(start % 8 == 0, "out is not aligned"); - ByteBuffer serializedMessage = serializeMetadata(schema); + ByteBuffer serializedMessage = serializeMetadata(schema, option); int messageLength = serializedMessage.remaining(); @@ -173,10 +173,19 @@ public static long serialize(WriteChannel out, Schema schema, IpcOption option) /** * Returns the serialized flatbuffer bytes of the schema wrapped in a message table. */ + @Deprecated public static ByteBuffer serializeMetadata(Schema schema) { + return serializeMetadata(schema, new IpcOption()); + } + + /** + * Returns the serialized flatbuffer bytes of the schema wrapped in a message table. 
+ */ + public static ByteBuffer serializeMetadata(Schema schema, IpcOption writeOption) { FlatBufferBuilder builder = new FlatBufferBuilder(); int schemaOffset = schema.getSchema(builder); - return MessageSerializer.serializeMessage(builder, org.apache.arrow.flatbuf.MessageHeader.Schema, schemaOffset, 0); + return MessageSerializer.serializeMessage(builder, org.apache.arrow.flatbuf.MessageHeader.Schema, schemaOffset, 0, + writeOption); } /** @@ -241,7 +250,7 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch, Ipc long bodyLength = batch.computeBodyLength(); Preconditions.checkArgument(bodyLength % 8 == 0, "batch is not aligned"); - ByteBuffer serializedMessage = serializeMetadata(batch); + ByteBuffer serializedMessage = serializeMetadata(batch, option); int metadataLength = serializedMessage.remaining(); @@ -303,11 +312,19 @@ public static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) t /** * Returns the serialized form of {@link RecordBatch} wrapped in a {@link org.apache.arrow.flatbuf.Message}. */ + @Deprecated public static ByteBuffer serializeMetadata(ArrowMessage message) { + return serializeMetadata(message, new IpcOption()); + } + + /** + * Returns the serialized form of {@link RecordBatch} wrapped in a {@link org.apache.arrow.flatbuf.Message}. 
+ */ + public static ByteBuffer serializeMetadata(ArrowMessage message, IpcOption writeOption) { FlatBufferBuilder builder = new FlatBufferBuilder(); int batchOffset = message.writeTo(builder); return serializeMessage(builder, message.getMessageType(), batchOffset, - message.computeBodyLength()); + message.computeBodyLength(), writeOption); } /** @@ -446,7 +463,7 @@ public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch, long bodyLength = batch.computeBodyLength(); Preconditions.checkArgument(bodyLength % 8 == 0, "batch is not aligned"); - ByteBuffer serializedMessage = serializeMetadata(batch); + ByteBuffer serializedMessage = serializeMetadata(batch, option); int metadataLength = serializedMessage.remaining(); @@ -582,8 +599,9 @@ public static ArrowMessage deserializeMessageBatch(MessageChannelReader reader) throw new IOException("Cannot currently deserialize record batches over 2GB"); } - if (result.getMessage().version() != MetadataVersion.V4) { - throw new IOException("Received metadata with an incompatible version number"); + if (result.getMessage().version() != MetadataVersion.V4 && + result.getMessage().version() != MetadataVersion.V5) { + throw new IOException("Received metadata with an incompatible version number: " + result.getMessage().version()); } switch (result.getMessage().headerType()) { @@ -608,6 +626,15 @@ public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocat return deserializeMessageBatch(new MessageChannelReader(in, alloc)); } + @Deprecated + public static ByteBuffer serializeMessage( + FlatBufferBuilder builder, + byte headerType, + int headerOffset, + long bodyLength) { + return serializeMessage(builder, headerType, headerOffset, bodyLength, new IpcOption()); + } + /** * Serializes a message header. 
* @@ -615,17 +642,19 @@ public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocat * @param headerType headerType field * @param headerOffset header offset field * @param bodyLength body length field + * @param writeOption IPC write options * @return the corresponding ByteBuffer */ public static ByteBuffer serializeMessage( FlatBufferBuilder builder, byte headerType, int headerOffset, - long bodyLength) { + long bodyLength, + IpcOption writeOption) { Message.startMessage(builder); Message.addHeaderType(builder, headerType); Message.addHeader(builder, headerOffset); - Message.addVersion(builder, MetadataVersion.V4); + Message.addVersion(builder, writeOption.metadataVersion.toFlatbufID()); Message.addBodyLength(builder, bodyLength); builder.finish(Message.endMessage(builder)); return builder.dataBuffer(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java new file mode 100644 index 00000000000..9e1894052d0 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.vector.types; + +/** + * Metadata version for Arrow metadata. + */ +public enum MetadataVersion { + /// 0.1.0 + V1(org.apache.arrow.flatbuf.MetadataVersion.V1), + + /// 0.2.0 + V2(org.apache.arrow.flatbuf.MetadataVersion.V2), + + /// 0.3.0 to 0.7.1 + V3(org.apache.arrow.flatbuf.MetadataVersion.V3), + + /// 0.8.0 to 0.17.1 + V4(org.apache.arrow.flatbuf.MetadataVersion.V4), + + /// >= 1.0.0 + V5(org.apache.arrow.flatbuf.MetadataVersion.V5), + + ; + + private static final MetadataVersion[] valuesByFlatbufId = + new MetadataVersion[MetadataVersion.values().length]; + + static { + for (MetadataVersion v : MetadataVersion.values()) { + valuesByFlatbufId[v.flatbufID] = v; + } + } + + private final short flatbufID; + + MetadataVersion(short flatbufID) { + this.flatbufID = flatbufID; + } + + public short toFlatbufID() { + return flatbufID; + } + + public static MetadataVersion fromFlatbufID(short id) { + return valuesByFlatbufId[id]; + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java index 028e52e8c0b..a26a0ac62ad 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java @@ -38,7 +38,9 @@ import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowMessage; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; @@ -148,7 +150,7 @@ public void testSchemaDictionaryMessageSerialization() throws IOException { 
public ExpectedException expectedEx = ExpectedException.none(); @Test - public void testSerializeRecordBatch() throws IOException { + public void testSerializeRecordBatchV4() throws IOException { byte[] validity = new byte[] {(byte) 255, 0}; // second half is "undefined" byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; @@ -160,8 +162,35 @@ public void testSerializeRecordBatch() throws IOException { ArrowRecordBatch batch = new ArrowRecordBatch( 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb)); + IpcOption option = new IpcOption(); + option.metadataVersion = MetadataVersion.V4; ByteArrayOutputStream out = new ByteArrayOutputStream(); - MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch, option); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ReadChannel channel = new ReadChannel(Channels.newChannel(in)); + ArrowMessage deserialized = MessageSerializer.deserializeMessageBatch(channel, alloc); + assertEquals(ArrowRecordBatch.class, deserialized.getClass()); + verifyBatch((ArrowRecordBatch) deserialized, validity, values); + } + + @Test + public void testSerializeRecordBatchV5() throws IOException { + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + ArrowBuf validityb = buf(alloc, validity); + ArrowBuf valuesb = buf(alloc, values); + + ArrowRecordBatch batch = new ArrowRecordBatch( + 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb)); + + IpcOption option = new IpcOption(); + option.metadataVersion = MetadataVersion.V5; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch, option); ByteArrayInputStream 
in = new ByteArrayInputStream(out.toByteArray()); ReadChannel channel = new ReadChannel(Channels.newChannel(in)); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java index 958cde9c754..4fb58227860 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java @@ -19,51 +19,26 @@ import static java.nio.channels.Channels.newChannel; import static org.apache.arrow.vector.TestUtils.newVarCharVector; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.FixedSizeBinaryVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; -import org.apache.arrow.vector.ipc.message.ArrowBlock; -import org.apache.arrow.vector.ipc.message.ArrowBuffer; -import 
org.apache.arrow.vector.ipc.message.ArrowRecordBatch; -import org.apache.arrow.vector.types.FloatingPointPrecision; -import org.apache.arrow.vector.types.Types.MinorType; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeBinary; -import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList; -import org.apache.arrow.vector.types.pojo.ArrowType.Int; import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -97,759 +72,6 @@ public void testWriteComplex() throws IOException { } } - @Test - public void testWriteRead() throws IOException { - File file = new File("target/mytest.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int count = COUNT; - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { - writeData(count, parent); - write(parent.getChild("root"), file, stream); - } - - // read - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - VectorUnloader unloader = new VectorUnloader(root); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - arrowReader.loadRecordBatch(rbBlock); - assertEquals(count, root.getRowCount()); - ArrowRecordBatch batch = unloader.getRecordBatch(); - List buffersLayout = batch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : 
buffersLayout) { - assertEquals(0, arrowBuffer.getOffset() % 8); - } - validateContent(count, root); - batch.close(); - } - } - - // Read from stream. - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - VectorUnloader unloader = new VectorUnloader(root); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - ArrowRecordBatch batch = unloader.getRecordBatch(); - List buffersLayout = batch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : buffersLayout) { - assertEquals(0, arrowBuffer.getOffset() % 8); - } - batch.close(); - assertEquals(count, root.getRowCount()); - validateContent(count, root); - } - } - - @Test - public void testWriteReadComplex() throws IOException { - File file = new File("target/mytest_complex.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int count = COUNT; - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { - writeComplexData(count, parent); - write(parent.getChild("root"), file, stream); - } - - // read - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - - for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - 
arrowReader.loadRecordBatch(rbBlock); - assertEquals(count, root.getRowCount()); - validateComplexContent(count, root); - } - } - - // Read from stream. - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - assertEquals(count, root.getRowCount()); - validateComplexContent(count, root); - } - } - - @Test - public void testWriteReadMultipleRBs() throws IOException { - File file = new File("target/mytest_multiple.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int[] counts = {10, 5}; - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", originalVectorAllocator); - FileOutputStream fileOutputStream = new FileOutputStream(file)) { - writeData(counts[0], parent); - VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); - - try (ArrowFileWriter fileWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, stream)) { - fileWriter.start(); - streamWriter.start(); - - fileWriter.writeBatch(); - streamWriter.writeBatch(); - - parent.allocateNew(); - // if we write the same data we don't catch that the metadata is stored in the wrong order. 
- writeData(counts[1], parent); - root.setRowCount(counts[1]); - - fileWriter.writeBatch(); - streamWriter.writeBatch(); - - fileWriter.end(); - streamWriter.end(); - } - } - - // read file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - int i = 0; - List recordBatches = arrowReader.getRecordBlocks(); - assertEquals(2, recordBatches.size()); - long previousOffset = 0; - for (ArrowBlock rbBlock : recordBatches) { - Assert.assertTrue(rbBlock.getOffset() + " > " + previousOffset, rbBlock.getOffset() > previousOffset); - previousOffset = rbBlock.getOffset(); - arrowReader.loadRecordBatch(rbBlock); - assertEquals("RB #" + i, counts[i], root.getRowCount()); - validateContent(counts[i], root); - ++i; - } - } - - // read stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - int i = 0; - - for (int n = 0; n < 2; n++) { - Assert.assertTrue(arrowReader.loadNextBatch()); - assertEquals("RB #" + i, counts[i], root.getRowCount()); - validateContent(counts[i], root); - ++i; - } - Assert.assertFalse(arrowReader.loadNextBatch()); - } - } - - @Test - public void testWriteReadUnion() throws IOException { - File file = new File("target/mytest_write_union.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int count = COUNT; - - // write - try 
(BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", vectorAllocator)) { - writeUnionData(count, parent); - validateUnionData(count, new VectorSchemaRoot(parent.getChild("root"))); - write(parent.getChild("root"), file, stream); - } - - // read file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateUnionData(count, root); - } - - // Read from stream. - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateUnionData(count, root); - } - } - - @Test - public void testWriteReadTiny() throws IOException { - File file = new File("target/mytest_write_tiny.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(MessageSerializerTest.testSchema(), allocator)) { - root.getFieldVectors().get(0).allocateNew(); - TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0); - for (int i = 0; i < 16; i++) { - vector.set(i, i < 8 ? 
1 : 0, (byte) (i + 1)); - } - vector.setValueCount(16); - root.setRowCount(16); - - // write file - try (FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel())) { - LOGGER.debug("writing schema: " + root.getSchema()); - arrowWriter.start(); - arrowWriter.writeBatch(); - arrowWriter.end(); - } - // write stream - try (ArrowStreamWriter arrowWriter = new ArrowStreamWriter(root, null, stream)) { - arrowWriter.start(); - arrowWriter.writeBatch(); - arrowWriter.end(); - } - } - - // read file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("fileReader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateTinyData(root); - } - - // Read from stream. 
- try (BufferAllocator readerAllocator = allocator.newChildAllocator("streamReader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateTinyData(root); - } - } - - private void validateTinyData(VectorSchemaRoot root) { - assertEquals(16, root.getRowCount()); - TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0); - for (int i = 0; i < 16; i++) { - if (i < 8) { - assertEquals((byte) (i + 1), vector.get(i)); - } else { - Assert.assertTrue(vector.isNull(i)); - } - } - } - - @Test - public void testWriteReadMetadata() throws IOException { - File file = new File("target/mytest_metadata.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - List childFields = new ArrayList<>(); - childFields.add(new Field("varchar-child", new FieldType(true, ArrowType.Utf8.INSTANCE, null, metadata(1)), null)); - childFields.add(new Field("float-child", - new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null, metadata(2)), null)); - childFields.add(new Field("int-child", new FieldType(false, new ArrowType.Int(32, true), null, metadata(3)), null)); - childFields.add(new Field("list-child", new FieldType(true, ArrowType.List.INSTANCE, null, metadata(4)), - Collections2.asImmutableList(new Field("l1", FieldType.nullable(new ArrowType.Int(16, true)), null)))); - Field field = new Field("meta", new FieldType(true, ArrowType.Struct.INSTANCE, null, metadata(0)), childFields); - Map metadata = new HashMap<>(); - metadata.put("s1", "v1"); - metadata.put("s2", "v2"); - Schema originalSchema = new Schema(Collections2.asImmutableList(field), metadata); - assertEquals(metadata, 
originalSchema.getCustomMetadata()); - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector vector = (StructVector) field.createVector(originalVectorAllocator)) { - vector.allocateNewSafe(); - vector.setValueCount(0); - - List vectors = Collections2.asImmutableList(vector); - VectorSchemaRoot root = new VectorSchemaRoot(originalSchema, vectors, 0); - - try (FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter fileWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, stream)) { - LOGGER.debug("writing schema: " + root.getSchema()); - fileWriter.start(); - streamWriter.start(); - fileWriter.writeBatch(); - streamWriter.writeBatch(); - fileWriter.end(); - streamWriter.end(); - } - } - - // read from file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - assertEquals(originalSchema, schema); - assertEquals(originalSchema.getCustomMetadata(), schema.getCustomMetadata()); - Field top = schema.getFields().get(0); - assertEquals(metadata(0), top.getMetadata()); - for (int i = 0; i < 4; i++) { - assertEquals(metadata(i + 1), top.getChildren().get(i).getMetadata()); - } - } - - // Read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = 
arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - assertEquals(originalSchema, schema); - assertEquals(originalSchema.getCustomMetadata(), schema.getCustomMetadata()); - Field top = schema.getFields().get(0); - assertEquals(metadata(0), top.getMetadata()); - for (int i = 0; i < 4; i++) { - assertEquals(metadata(i + 1), top.getChildren().get(i).getMetadata()); - } - } - } - - private Map metadata(int i) { - Map map = new HashMap<>(); - map.put("k_" + i, "v_" + i); - map.put("k2_" + i, "v2_" + i); - return Collections.unmodifiableMap(map); - } - - @Test - public void testWriteReadDictionary() throws IOException { - File file = new File("target/mytest_dict.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int numDictionaryBlocksWritten = 0; - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE)) { - - MapDictionaryProvider provider = new MapDictionaryProvider(); - - try (VectorSchemaRoot root = writeFlatDictionaryData(originalVectorAllocator, provider); - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter fileWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, provider, stream)) { - LOGGER.debug("writing schema: " + root.getSchema()); - fileWriter.start(); - streamWriter.start(); - fileWriter.writeBatch(); - streamWriter.writeBatch(); - fileWriter.end(); - streamWriter.end(); - numDictionaryBlocksWritten = fileWriter.getDictionaryBlocks().size(); - } - - // Need to close dictionary vectors - for (long id : provider.getDictionaryIds()) { - provider.lookup(id).getVector().close(); - } - } - - // read from file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - 
ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateFlatDictionary(root, arrowReader); - assertEquals(numDictionaryBlocksWritten, arrowReader.getDictionaryBlocks().size()); - } - - // Read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateFlatDictionary(root, arrowReader); - } - } - - @Test - public void testWriteReadNestedDictionary() throws IOException { - File file = new File("target/mytest_dict_nested.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - // data being written: - // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]] - - // write - try ( - BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE) - ) { - MapDictionaryProvider provider = new MapDictionaryProvider(); - - try (VectorSchemaRoot root = writeNestedDictionaryData(vectorAllocator, provider); - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter fileWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, provider, stream)) { - - validateNestedDictionary(root, provider); - - LOGGER.debug("writing schema: " + root.getSchema()); - fileWriter.start(); - streamWriter.start(); - fileWriter.writeBatch(); - streamWriter.writeBatch(); - 
fileWriter.end(); - streamWriter.end(); - } - - // Need to close dictionary vectors - for (long id : provider.getDictionaryIds()) { - provider.lookup(id).getVector().close(); - } - } - - // read from file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateNestedDictionary(root, arrowReader); - } - - // Read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateNestedDictionary(root, arrowReader); - } - } - - @Test - public void testWriteReadFixedSizeBinary() throws IOException { - File file = new File("target/mytest_fixed_size_binary.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - final int numValues = 10; - final int typeWidth = 11; - byte[][] byteValues = new byte[numValues][typeWidth]; - for (int i = 0; i < numValues; i++) { - for (int j = 0; j < typeWidth; j++) { - byteValues[i][j] = ((byte) i); - } - } - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { - FixedSizeBinaryVector fixedSizeBinaryVector = parent.addOrGet("fixed-binary", - 
FieldType.nullable(new FixedSizeBinary(typeWidth)), FixedSizeBinaryVector.class); - parent.allocateNew(); - for (int i = 0; i < numValues; i++) { - fixedSizeBinaryVector.set(i, byteValues[i]); - } - parent.setValueCount(numValues); - write(parent, file, stream); - } - - // read - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - - for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - arrowReader.loadRecordBatch(rbBlock); - assertEquals(numValues, root.getRowCount()); - for (int i = 0; i < numValues; i++) { - Assert.assertArrayEquals(byteValues[i], ((byte[]) root.getVector("fixed-binary").getObject(i))); - } - } - } - - // read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); - assertEquals(numValues, root.getRowCount()); - for (int i = 0; i < numValues; i++) { - Assert.assertArrayEquals(byteValues[i], ((byte[]) root.getVector("fixed-binary").getObject(i))); - } - } - } - - @Test - public void testWriteReadFixedSizeList() throws IOException { - File file = new File("target/mytest_fixed_list.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int count = COUNT; - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - 
StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { - FixedSizeListVector tuples = parent.addOrGet("float-pairs", - FieldType.nullable(new FixedSizeList(2)), FixedSizeListVector.class); - Float4Vector floats = (Float4Vector) tuples.addOrGetVector(FieldType.nullable(MinorType.FLOAT4.getType())) - .getVector(); - IntVector ints = parent.addOrGet("ints", FieldType.nullable(new Int(32, true)), IntVector.class); - parent.allocateNew(); - - for (int i = 0; i < 10; i++) { - tuples.setNotNull(i); - floats.set(i * 2, i + 0.1f); - floats.set(i * 2 + 1, i + 10.1f); - ints.set(i, i); - } - - parent.setValueCount(10); - write(parent, file, stream); - } - - // read - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - - for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - arrowReader.loadRecordBatch(rbBlock); - assertEquals(count, root.getRowCount()); - for (int i = 0; i < 10; i++) { - assertEquals(Collections2.asImmutableList(i + 0.1f, i + 10.1f), root.getVector("float-pairs") - .getObject(i)); - assertEquals(i, root.getVector("ints").getObject(i)); - } - } - } - - // read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); - assertEquals(count, root.getRowCount()); - for (int i = 0; i < 10; i++) { - 
assertEquals(Collections2.asImmutableList(i + 0.1f, i + 10.1f), root.getVector("float-pairs") - .getObject(i)); - assertEquals(i, root.getVector("ints").getObject(i)); - } - } - } - - @Test - public void testWriteReadVarBin() throws IOException { - File file = new File("target/mytest_varbin.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - int count = COUNT; - - // write - try ( - BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - StructVector parent = StructVector.empty("parent", vectorAllocator)) { - writeVarBinaryData(count, parent); - VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); - validateVarBinary(count, root); - write(parent.getChild("root"), file, stream); - } - - // read from file - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateVarBinary(count, root); - } - - // read from stream - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateVarBinary(count, root); - } - } - - @Test - public void testReadWriteMultipleBatches() throws IOException { - File file = new File("target/mytest_nulls_multibatch.arrow"); - int numBlocksWritten = 0; - - 
try (IntVector vector = new IntVector("foo", allocator);) { - Schema schema = new Schema(Collections.singletonList(vector.getField())); - try (FileOutputStream fileOutputStream = new FileOutputStream(file); - VectorSchemaRoot root = - new VectorSchemaRoot(schema, Collections.singletonList((FieldVector) vector), vector.getValueCount()); - ArrowFileWriter writer = new ArrowFileWriter(root, null, fileOutputStream.getChannel());) { - writeBatchData(writer, vector, root); - numBlocksWritten = writer.getRecordBlocks().size(); - } - } - - try (FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { - IntVector vector = (IntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0); - validateBatchData(reader, vector); - assertEquals(numBlocksWritten, reader.getRecordBlocks().size()); - } - } - - @Test - public void testWriteReadMapVector() throws IOException { - File file = new File("target/mytest_map.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE)) { - - try (VectorSchemaRoot root = writeMapData(originalVectorAllocator); - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter fileWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, stream)) { - LOGGER.debug("writing schema: " + root.getSchema()); - fileWriter.start(); - streamWriter.start(); - fileWriter.writeBatch(); - streamWriter.writeBatch(); - fileWriter.end(); - streamWriter.end(); - } - } - - // read from file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new 
ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateMapData(root); - } - - // Read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateMapData(root); - } - } - - @Test - public void testWriteReadListAsMap() throws IOException { - File file = new File("target/mytest_list_as_map.arrow"); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - // write - try (BufferAllocator originalVectorAllocator = - allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE)) { - - try (VectorSchemaRoot root = writeListAsMapData(originalVectorAllocator); - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowFileWriter fileWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel()); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, stream)) { - LOGGER.debug("writing schema: " + root.getSchema()); - fileWriter.start(); - streamWriter.start(); - fileWriter.writeBatch(); - streamWriter.writeBatch(); - fileWriter.end(); - streamWriter.end(); - } - } - - // read from file - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { - VectorSchemaRoot root = 
arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateListAsMapData(root); - } - - // Read from stream - try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); - LOGGER.debug("reading schema: " + schema); - Assert.assertTrue(arrowReader.loadNextBatch()); - validateListAsMapData(root); - } - } - /** * Writes the contents of parents to file. If outStream is non-null, also writes it * to outStream in the streaming serialized format. diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java new file mode 100644 index 00000000000..3baf949f8b0 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java @@ -0,0 +1,603 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.vector.ipc; + +import static org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.Collections2; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowBlock; +import org.apache.arrow.vector.ipc.message.ArrowBuffer; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageMetadataResult; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import 
org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@RunWith(Parameterized.class) +public class TestRoundTrip extends BaseFileTest { + private static final Logger LOGGER = LoggerFactory.getLogger(TestRoundTrip.class); + private static BufferAllocator allocator; + private final String name; + private final IpcOption writeOption; + + public TestRoundTrip(String name, IpcOption writeOption) { + this.name = name; + this.writeOption = writeOption; + } + + @Parameterized.Parameters(name = "options = {0}") + public static Collection getWriteOption() { + final IpcOption legacy = new IpcOption(); + legacy.metadataVersion = MetadataVersion.V4; + legacy.write_legacy_ipc_format = true; + final IpcOption version4 = new IpcOption(); + version4.metadataVersion = MetadataVersion.V4; + return Arrays.asList( + new Object[] {"V4Legacy", legacy}, + new Object[] {"V4", version4}, + new Object[] {"V5", new IpcOption()} + ); + } + + @BeforeClass + public static void setUpClass() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterClass + public static void tearDownClass() { + allocator.close(); + } + + @Test + public void testStruct() throws Exception { + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeData(COUNT, parent); + roundTrip( + new VectorSchemaRoot(parent.getChild("root")), + /* dictionaryProvider 
*/null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {COUNT}, this::validateContent), + validateStreamBatches(new int[] {COUNT}, this::validateContent)); + } + } + + @Test + public void testComplex() throws Exception { + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeComplexData(COUNT, parent); + roundTrip( + new VectorSchemaRoot(parent.getChild("root")), + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {COUNT}, this::validateComplexContent), + validateStreamBatches(new int[] {COUNT}, this::validateComplexContent)); + } + } + + @Test + public void testMultipleRecordBatches() throws Exception { + int[] counts = {10, 5}; + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeData(counts[0], parent); + roundTrip( + new VectorSchemaRoot(parent.getChild("root")), + /* dictionaryProvider */null, + (root, writer) -> { + writer.start(); + parent.allocateNew(); + writeData(counts[0], parent); + root.setRowCount(counts[0]); + writer.writeBatch(); + + parent.allocateNew(); + // if we write the same data we don't catch that the metadata is stored in the wrong order. 
+ writeData(counts[1], parent); + root.setRowCount(counts[1]); + writer.writeBatch(); + + writer.end(); + }, + validateFileBatches(counts, this::validateContent), + validateStreamBatches(counts, this::validateContent)); + } + } + + @Test + public void testUnion() throws Exception { + Assume.assumeTrue(writeOption.metadataVersion == MetadataVersion.V5); + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeUnionData(COUNT, parent); + VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); + validateUnionData(COUNT, root); + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {COUNT}, this::validateUnionData), + validateStreamBatches(new int[] {COUNT}, this::validateUnionData)); + } + } + + @Test + public void testTiny() throws Exception { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(MessageSerializerTest.testSchema(), allocator)) { + root.getFieldVectors().get(0).allocateNew(); + int count = 16; + TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0); + for (int i = 0; i < count; i++) { + vector.set(i, i < 8 ? 
1 : 0, (byte) (i + 1)); + } + vector.setValueCount(count); + root.setRowCount(count); + + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {count}, this::validateTinyData), + validateStreamBatches(new int[] {count}, this::validateTinyData)); + } + } + + private void validateTinyData(int count, VectorSchemaRoot root) { + assertEquals(count, root.getRowCount()); + TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0); + for (int i = 0; i < count; i++) { + if (i < 8) { + assertEquals((byte) (i + 1), vector.get(i)); + } else { + assertTrue(vector.isNull(i)); + } + } + } + + @Test + public void testMetadata() throws Exception { + List childFields = new ArrayList<>(); + childFields.add(new Field("varchar-child", new FieldType(true, ArrowType.Utf8.INSTANCE, null, metadata(1)), null)); + childFields.add(new Field("float-child", + new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null, metadata(2)), null)); + childFields.add(new Field("int-child", new FieldType(false, new ArrowType.Int(32, true), null, metadata(3)), null)); + childFields.add(new Field("list-child", new FieldType(true, ArrowType.List.INSTANCE, null, metadata(4)), + Collections2.asImmutableList(new Field("l1", FieldType.nullable(new ArrowType.Int(16, true)), null)))); + Field field = new Field("meta", new FieldType(true, ArrowType.Struct.INSTANCE, null, metadata(0)), childFields); + Map metadata = new HashMap<>(); + metadata.put("s1", "v1"); + metadata.put("s2", "v2"); + Schema originalSchema = new Schema(Collections2.asImmutableList(field), metadata); + assertEquals(metadata, originalSchema.getCustomMetadata()); + + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector vector = (StructVector) field.createVector(originalVectorAllocator)) { + vector.allocateNewSafe(); + vector.setValueCount(0); + + List 
vectors = Collections2.asImmutableList(vector); + VectorSchemaRoot root = new VectorSchemaRoot(originalSchema, vectors, 0); + + BiConsumer validate = (count, readRoot) -> { + Schema schema = readRoot.getSchema(); + assertEquals(originalSchema, schema); + assertEquals(originalSchema.getCustomMetadata(), schema.getCustomMetadata()); + Field top = schema.getFields().get(0); + assertEquals(metadata(0), top.getMetadata()); + for (int i = 0; i < 4; i++) { + assertEquals(metadata(i + 1), top.getChildren().get(i).getMetadata()); + } + }; + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {0}, validate), + validateStreamBatches(new int[] {0}, validate)); + } + } + + private Map metadata(int i) { + Map map = new HashMap<>(); + map.put("k_" + i, "v_" + i); + map.put("k2_" + i, "v2_" + i); + return Collections.unmodifiableMap(map); + } + + @Test + public void testFlatDictionary() throws Exception { + AtomicInteger numDictionaryBlocksWritten = new AtomicInteger(); + MapDictionaryProvider provider = new MapDictionaryProvider(); + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final VectorSchemaRoot root = writeFlatDictionaryData(originalVectorAllocator, provider)) { + roundTrip( + root, + provider, + (ignored, writer) -> { + writer.start(); + writer.writeBatch(); + writer.end(); + if (writer instanceof ArrowFileWriter) { + numDictionaryBlocksWritten.set(((ArrowFileWriter) writer).getDictionaryBlocks().size()); + } + }, + (fileReader) -> { + VectorSchemaRoot readRoot = fileReader.getVectorSchemaRoot(); + Schema schema = readRoot.getSchema(); + LOGGER.debug("reading schema: " + schema); + assertTrue(fileReader.loadNextBatch()); + validateFlatDictionary(readRoot, fileReader); + assertEquals(numDictionaryBlocksWritten.get(), fileReader.getDictionaryBlocks().size()); + }, + (streamReader) -> { + VectorSchemaRoot readRoot = 
streamReader.getVectorSchemaRoot(); + Schema schema = readRoot.getSchema(); + LOGGER.debug("reading schema: " + schema); + assertTrue(streamReader.loadNextBatch()); + validateFlatDictionary(readRoot, streamReader); + }); + + // Need to close dictionary vectors + for (long id : provider.getDictionaryIds()) { + provider.lookup(id).getVector().close(); + } + } + } + + @Test + public void testNestedDictionary() throws Exception { + AtomicInteger numDictionaryBlocksWritten = new AtomicInteger(); + MapDictionaryProvider provider = new MapDictionaryProvider(); + // data being written: + // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]] + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final VectorSchemaRoot root = writeNestedDictionaryData(originalVectorAllocator, provider)) { + CheckedConsumer validateDictionary = (streamReader) -> { + VectorSchemaRoot readRoot = streamReader.getVectorSchemaRoot(); + Schema schema = readRoot.getSchema(); + LOGGER.debug("reading schema: " + schema); + assertTrue(streamReader.loadNextBatch()); + validateNestedDictionary(readRoot, streamReader); + }; + roundTrip( + root, + provider, + (ignored, writer) -> { + writer.start(); + writer.writeBatch(); + writer.end(); + if (writer instanceof ArrowFileWriter) { + numDictionaryBlocksWritten.set(((ArrowFileWriter) writer).getDictionaryBlocks().size()); + } + }, + validateDictionary, + validateDictionary); + + // Need to close dictionary vectors + for (long id : provider.getDictionaryIds()) { + provider.lookup(id).getVector().close(); + } + } + } + + @Test + public void testFixedSizeBinary() throws Exception { + final int count = 10; + final int typeWidth = 11; + byte[][] byteValues = new byte[count][typeWidth]; + for (int i = 0; i < count; i++) { + for (int j = 0; j < typeWidth; j++) { + byteValues[i][j] = ((byte) i); + } + } + + BiConsumer validator = (expectedCount, root) -> { + for (int i = 0; i < 
expectedCount; i++) { + assertArrayEquals(byteValues[i], ((byte[]) root.getVector("fixed-binary").getObject(i))); + } + }; + + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + FixedSizeBinaryVector fixedSizeBinaryVector = parent.addOrGet("fixed-binary", + FieldType.nullable(new ArrowType.FixedSizeBinary(typeWidth)), FixedSizeBinaryVector.class); + parent.allocateNew(); + for (int i = 0; i < count; i++) { + fixedSizeBinaryVector.set(i, byteValues[i]); + } + parent.setValueCount(count); + + roundTrip( + new VectorSchemaRoot(parent), + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {count}, validator), + validateStreamBatches(new int[] {count}, validator)); + } + } + + @Test + public void testFixedSizeList() throws Exception { + BiConsumer validator = (expectedCount, root) -> { + for (int i = 0; i < expectedCount; i++) { + assertEquals(Collections2.asImmutableList(i + 0.1f, i + 10.1f), root.getVector("float-pairs") + .getObject(i)); + assertEquals(i, root.getVector("ints").getObject(i)); + } + }; + + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + FixedSizeListVector tuples = parent.addOrGet("float-pairs", + FieldType.nullable(new ArrowType.FixedSizeList(2)), FixedSizeListVector.class); + Float4Vector floats = (Float4Vector) tuples.addOrGetVector(FieldType.nullable(Types.MinorType.FLOAT4.getType())) + .getVector(); + IntVector ints = parent.addOrGet("ints", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + parent.allocateNew(); + for (int i = 0; i < COUNT; i++) { + tuples.setNotNull(i); + floats.set(i * 2, i + 0.1f); + floats.set(i * 2 + 1, i + 10.1f); + ints.set(i, 
i); + } + parent.setValueCount(COUNT); + + roundTrip( + new VectorSchemaRoot(parent), + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[] {COUNT}, validator), + validateStreamBatches(new int[] {COUNT}, validator)); + } + } + + @Test + public void testVarBinary() throws Exception { + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeVarBinaryData(COUNT, parent); + VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); + validateVarBinary(COUNT, root); + + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[]{COUNT}, this::validateVarBinary), + validateStreamBatches(new int[]{COUNT}, this::validateVarBinary)); + } + } + + @Test + public void testReadWriteMultipleBatches() throws IOException { + File file = new File("target/mytest_nulls_multibatch.arrow"); + int numBlocksWritten = 0; + + try (IntVector vector = new IntVector("foo", allocator);) { + Schema schema = new Schema(Collections.singletonList(vector.getField())); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + VectorSchemaRoot root = + new VectorSchemaRoot(schema, Collections.singletonList((FieldVector) vector), vector.getValueCount()); + ArrowFileWriter writer = new ArrowFileWriter(root, null, fileOutputStream.getChannel(), writeOption)) { + writeBatchData(writer, vector, root); + numBlocksWritten = writer.getRecordBlocks().size(); + } + } + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + IntVector vector = (IntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0); + validateBatchData(reader, vector); + assertEquals(numBlocksWritten, reader.getRecordBlocks().size()); + } 
+ } + + @Test + public void testMap() throws Exception { + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final VectorSchemaRoot root = writeMapData(originalVectorAllocator)) { + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[]{root.getRowCount()}, (count, readRoot) -> validateMapData(readRoot)), + validateStreamBatches(new int[]{root.getRowCount()}, (count, readRoot) -> validateMapData(readRoot))); + } + } + + @Test + public void testListAsMap() throws Exception { + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final VectorSchemaRoot root = writeListAsMapData(originalVectorAllocator)) { + roundTrip( + root, + /* dictionaryProvider */null, + TestRoundTrip::writeSingleBatch, + validateFileBatches(new int[]{root.getRowCount()}, (count, readRoot) -> validateListAsMapData(readRoot)), + validateStreamBatches(new int[]{root.getRowCount()}, (count, readRoot) -> validateListAsMapData(readRoot))); + } + } + + // Generic test helpers + + private static void writeSingleBatch(VectorSchemaRoot root, ArrowWriter writer) throws IOException { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + private CheckedConsumer validateFileBatches( + int[] counts, BiConsumer validator) { + return (arrowReader) -> { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + VectorUnloader unloader = new VectorUnloader(root); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + int i = 0; + List recordBatches = arrowReader.getRecordBlocks(); + assertEquals(counts.length, recordBatches.size()); + long previousOffset = 0; + for (ArrowBlock rbBlock : recordBatches) { + assertTrue(rbBlock.getOffset() + " > " + previousOffset, rbBlock.getOffset() > previousOffset); + previousOffset = rbBlock.getOffset(); + 
arrowReader.loadRecordBatch(rbBlock); + assertEquals("RB #" + i, counts[i], root.getRowCount()); + validator.accept(counts[i], root); + try (final ArrowRecordBatch batch = unloader.getRecordBatch()) { + List buffersLayout = batch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + assertEquals(0, arrowBuffer.getOffset() % 8); + } + } + ++i; + } + }; + } + + private CheckedConsumer validateStreamBatches( + int[] counts, BiConsumer validator) { + return (arrowReader) -> { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + VectorUnloader unloader = new VectorUnloader(root); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + int i = 0; + + for (int n = 0; n < counts.length; n++) { + assertTrue(arrowReader.loadNextBatch()); + assertEquals("RB #" + i, counts[i], root.getRowCount()); + validator.accept(counts[i], root); + try (final ArrowRecordBatch batch = unloader.getRecordBatch()) { + final List buffersLayout = batch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + assertEquals(0, arrowBuffer.getOffset() % 8); + } + } + ++i; + } + assertFalse(arrowReader.loadNextBatch()); + }; + } + + @FunctionalInterface + interface CheckedConsumer { + void accept(T t) throws Exception; + } + + @FunctionalInterface + interface CheckedBiConsumer { + void accept(T t, U u) throws Exception; + } + + private void roundTrip(VectorSchemaRoot root, DictionaryProvider provider, + CheckedBiConsumer writer, + CheckedConsumer fileValidator, + CheckedConsumer streamValidator) throws Exception { + final File temp = File.createTempFile("arrow-test-" + name + "-", ".arrow"); + temp.deleteOnExit(); + final ByteArrayOutputStream memoryStream = new ByteArrayOutputStream(); + final Map metadata = new HashMap<>(); + metadata.put("foo", "bar"); + try (final FileOutputStream fileStream = new FileOutputStream(temp); + final ArrowFileWriter fileWriter = + new ArrowFileWriter(root, provider, fileStream.getChannel(), 
metadata, writeOption); + final ArrowStreamWriter streamWriter = + new ArrowStreamWriter(root, provider, Channels.newChannel(memoryStream), writeOption)) { + writer.accept(root, fileWriter); + writer.accept(root, streamWriter); + } + + MessageMetadataResult metadataResult = MessageSerializer.readMessage( + new ReadChannel(Channels.newChannel(new ByteArrayInputStream(memoryStream.toByteArray())))); + assertNotNull(metadataResult); + assertEquals(writeOption.metadataVersion.toFlatbufID(), metadataResult.getMessage().version()); + + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, allocator.getLimit()); + FileInputStream fileInputStream = new FileInputStream(temp); + ByteArrayInputStream inputStream = new ByteArrayInputStream(memoryStream.toByteArray()); + ArrowFileReader fileReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator); + ArrowStreamReader streamReader = new ArrowStreamReader(inputStream, readerAllocator)) { + fileValidator.accept(fileReader); + streamValidator.accept(streamReader); + assertEquals(metadata, fileReader.getMetaData()); + } + } +} From a930940ae1ae9e47eeff4a5204b776039f08d29e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 8 Jul 2020 12:57:15 -0400 Subject: [PATCH 2/4] ARROW-9362: [Java][FlightRPC] increment default metadata version to V5 --- java/flight/flight-core/pom.xml | 5 + .../org/apache/arrow/flight/ArrowMessage.java | 35 ++- .../apache/arrow/flight/DictionaryUtils.java | 10 +- .../org/apache/arrow/flight/FlightClient.java | 22 +- .../org/apache/arrow/flight/FlightInfo.java | 28 +- .../org/apache/arrow/flight/FlightStream.java | 20 +- .../arrow/flight/OutboundStreamListener.java | 18 +- .../flight/OutboundStreamListenerImpl.java | 15 +- .../org/apache/arrow/flight/SchemaResult.java | 10 +- .../arrow/flight/TestMetadataVersion.java | 277 ++++++++++++++++++ 10 files changed, 403 insertions(+), 37 deletions(-) create mode 100644 
java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index bced914c8b0..08f9f821033 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -30,6 +30,11 @@ + + org.apache.arrow + arrow-format + ${project.version} + org.apache.arrow arrow-vector diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 032107d85bf..917d8435e06 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -37,8 +37,10 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageMetadataResult; import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; import com.google.common.collect.ImmutableList; @@ -50,7 +52,6 @@ import com.google.protobuf.WireFormat; import io.grpc.Drainable; -import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ByteBuf; @@ -75,7 +76,8 @@ class ArrowMessage implements AutoCloseable { private static final int APP_METADATA_TAG = (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static Marshaller NO_BODY_MARSHALLER = ProtoUtils.marshaller(FlightData.getDefaultInstance()); + private static final Marshaller NO_BODY_MARSHALLER = + ProtoUtils.marshaller(FlightData.getDefaultInstance()); /** Get the application-specific metadata in this message. 
The ArrowMessage retains ownership of the buffer. */ public ArrowBuf getApplicationMetadata() { @@ -106,7 +108,7 @@ public static HeaderType getHeader(byte b) { } // Pre-allocated buffers for padding serialized ArrowMessages. - private static List PADDING_BUFFERS = Arrays.asList( + private static final List PADDING_BUFFERS = Arrays.asList( null, Unpooled.copiedBuffer(new byte[] { 0 }), Unpooled.copiedBuffer(new byte[] { 0, 0 }), @@ -117,13 +119,15 @@ public static HeaderType getHeader(byte b) { Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0, 0 }) ); + private final IpcOption writeOption; private final FlightDescriptor descriptor; private final MessageMetadataResult message; private final ArrowBuf appMetadata; private final List bufs; - public ArrowMessage(FlightDescriptor descriptor, Schema schema) { - ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema); + public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema, writeOption); this.message = MessageMetadataResult.create(serializedMessage.slice(), serializedMessage.remaining()); bufs = ImmutableList.of(); @@ -136,16 +140,18 @@ public ArrowMessage(FlightDescriptor descriptor, Schema schema) { * @param batch The record batch. * @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise. 
*/ - public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata) { - ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch); + public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption); this.message = MessageMetadataResult.create(serializedMessage.slice(), serializedMessage.remaining()); this.bufs = ImmutableList.copyOf(batch.getBuffers()); this.descriptor = null; this.appMetadata = appMetadata; } - public ArrowMessage(ArrowDictionaryBatch batch) { - ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch); + public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption); serializedMessage = serializedMessage.slice(); this.message = MessageMetadataResult.create(serializedMessage, serializedMessage.remaining()); // asInputStream will free the buffers implicitly, so increment the reference count @@ -160,6 +166,8 @@ public ArrowMessage(ArrowDictionaryBatch batch) { * @param appMetadata The application-provided metadata buffer. */ public ArrowMessage(ArrowBuf appMetadata) { + // No need to take IpcOption as it's not used to serialize this kind of message. + this.writeOption = new IpcOption(); this.message = null; this.bufs = ImmutableList.of(); this.descriptor = null; @@ -167,6 +175,8 @@ public ArrowMessage(ArrowBuf appMetadata) { } public ArrowMessage(FlightDescriptor descriptor) { + // No need to take IpcOption as it's not used to serialize this kind of message. 
+ this.writeOption = new IpcOption(); this.message = null; this.bufs = ImmutableList.of(); this.descriptor = descriptor; @@ -175,6 +185,11 @@ public ArrowMessage(FlightDescriptor descriptor) { private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata, ArrowBuf buf) { + // No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire. + this.writeOption = new IpcOption(); + if (message != null) { + this.writeOption.metadataVersion = MetadataVersion.fromFlatbufID(message.getMessage().version()); + } this.message = message; this.descriptor = descriptor; this.appMetadata = appMetadata; @@ -404,7 +419,7 @@ public static Marshaller createMarshaller(BufferAllocator allocato return new ArrowMessageHolderMarshaller(allocator); } - private static class ArrowMessageHolderMarshaller implements MethodDescriptor.Marshaller { + private static class ArrowMessageHolderMarshaller implements Marshaller { private final BufferAllocator allocator; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java index dd9bc516134..fa15bee4dce 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -25,6 +25,7 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -32,6 +33,7 @@ import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.types.pojo.Field; import 
org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.DictionaryUtility; @@ -51,11 +53,13 @@ private DictionaryUtils() { * @throws Exception if there was an error closing {@link ArrowMessage} objects. This is not generally expected. */ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor, - final DictionaryProvider provider, final Consumer messageCallback) throws Exception { + final DictionaryProvider provider, final IpcOption option, + final Consumer messageCallback) throws Exception { final Set dictionaryIds = new HashSet<>(); final Schema schema = generateSchema(originalSchema, provider, dictionaryIds); // Send the schema message - try (final ArrowMessage message = new ArrowMessage(descriptor == null ? null : descriptor.toProtocol(), schema)) { + final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol(); + try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) { messageCallback.accept(message); } // Create and write dictionary batches @@ -71,7 +75,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe final VectorUnloader unloader = new VectorUnloader(dictRoot); try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch( id, unloader.getRecordBatch()); - final ArrowMessage message = new ArrowMessage(dictionaryBatch)) { + final ArrowMessage message = new ArrowMessage(dictionaryBatch, option)) { messageCallback.accept(message); } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index fe9cfe23ae1..f477fa01c29 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -201,9 +201,24 @@ public ClientStreamListener 
startPut(FlightDescriptor descriptor, VectorSchemaRo */ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider, PutListener metadataListener, CallOption... options) { - Preconditions.checkNotNull(descriptor, "descriptor must not be null"); Preconditions.checkNotNull(root, "root must not be null"); Preconditions.checkNotNull(provider, "provider must not be null"); + final ClientStreamListener writer = startPut(descriptor, metadataListener, options); + writer.start(root, provider); + return writer; + } + + /** + * Create or append a descriptor with another stream. + * @param descriptor FlightDescriptor the descriptor for the data + * @param metadataListener A handler for metadata messages from the server. + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data. + * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will NOT already have been called. + */ + public ClientStreamListener startPut(FlightDescriptor descriptor, PutListener metadataListener, + CallOption... options) { + Preconditions.checkNotNull(descriptor, "descriptor must not be null"); Preconditions.checkNotNull(metadataListener, "metadataListener must not be null"); final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); @@ -212,11 +227,8 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo ClientCallStreamObserver observer = (ClientCallStreamObserver) ClientCalls.asyncBidiStreamingCall( interceptedChannel.newCall(doPutDescriptor, callOptions), resultObserver); - final ClientStreamListener writer = new PutObserver( + return new PutObserver( descriptor, observer, metadataListener::isCancelled, metadataListener::getResult); - // Send the schema to start. 
- writer.start(root, provider); - return writer; } catch (StatusRuntimeException sre) { throw StatusUtils.fromGrpcRuntimeException(sre); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index c15b0cdc727..7452ba83d18 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -30,6 +30,7 @@ import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,11 +42,12 @@ * A POJO representation of a FlightInfo, metadata associated with a set of data records. */ public class FlightInfo { - private Schema schema; - private FlightDescriptor descriptor; - private List endpoints; + private final Schema schema; + private final FlightDescriptor descriptor; + private final List endpoints; private final long bytes; private final long records; + private final IpcOption option; /** * Constructs a new instance. @@ -58,7 +60,21 @@ public class FlightInfo { */ public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, long records) { - super(); + this(schema, descriptor, endpoints, bytes, records, new IpcOption()); + } + + /** + * Constructs a new instance. + * + * @param schema The schema of the Flight + * @param descriptor An identifier for the Flight. + * @param endpoints A list of endpoints that have the flight available. + * @param bytes The number of bytes in the flight + * @param records The number of records in the flight. + * @param option IPC write options. 
+ */ + public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, + long records, IpcOption option) { Objects.requireNonNull(schema); Objects.requireNonNull(descriptor); Objects.requireNonNull(endpoints); @@ -67,6 +83,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, ListThis method must be called before all others, except {@link #putMetadata(ArrowBuf)}. */ - void start(VectorSchemaRoot root); + default void start(VectorSchemaRoot root) { + start(root, null, new IpcOption()); + } /** * Start sending data, using the schema of the given {@link VectorSchemaRoot}. * - *

This method must be called before all others. + *

This method must be called before all others, except {@link #putMetadata(ArrowBuf)}. + */ + default void start(VectorSchemaRoot root, DictionaryProvider dictionaries) { + start(root, dictionaries, new IpcOption()); + } + + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + *

This method must be called before all others, except {@link #putMetadata(ArrowBuf)}. */ - void start(VectorSchemaRoot root, DictionaryProvider dictionaries); + void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option); /** * Send the current contents of the associated {@link VectorSchemaRoot}. diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java index c826c8507f3..b9bd626c130 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java @@ -23,6 +23,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.IpcOption; import io.grpc.stub.CallStreamObserver; @@ -33,6 +34,7 @@ abstract class OutboundStreamListenerImpl implements OutboundStreamListener { private final FlightDescriptor descriptor; // nullable protected final CallStreamObserver responseObserver; protected volatile VectorUnloader unloader; // null until stream started + protected IpcOption option; // null until stream started OutboundStreamListenerImpl(FlightDescriptor descriptor, CallStreamObserver responseObserver) { Preconditions.checkNotNull(responseObserver, "responseObserver must be provided"); @@ -47,14 +49,11 @@ public boolean isReady() { } @Override - public void start(VectorSchemaRoot root) { - start(root, new DictionaryProvider.MapDictionaryProvider()); - } - - @Override - public void start(VectorSchemaRoot root, DictionaryProvider dictionaries) { + public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) { + this.option = option; try { - DictionaryUtils.generateSchemaMessages(root.getSchema(), 
descriptor, dictionaries, responseObserver::onNext); + DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, option, + responseObserver::onNext); } catch (Exception e) { // Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate // the exception @@ -86,7 +85,7 @@ public void putNext(ArrowBuf metadata) { // close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers // in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up // ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty. - try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata)) { + try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata, option)) { responseObserver.onNext(message); } catch (Exception e) { // This exception comes from ArrowMessage#close, not responseObserver#onNext. 
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java index 764f4c70f33..0ef3cbb789a 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java @@ -25,6 +25,7 @@ import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,11 +41,16 @@ public class SchemaResult { private final Schema schema; + private final IpcOption option; public SchemaResult(Schema schema) { - this.schema = schema; + this(schema, new IpcOption()); } + public SchemaResult(Schema schema, IpcOption option) { + this.schema = schema; + this.option = option; + } public Schema getSchema() { return schema; @@ -57,7 +63,7 @@ Flight.SchemaResult toProtocol() { // Encode schema in a Message payload ByteArrayOutputStream baos = new ByteArrayOutputStream(); try { - MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java new file mode 100644 index 00000000000..ad8fda65b82 --- /dev/null +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.Collections; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Test clients/servers with different metadata versions. 
+ */ +public class TestMetadataVersion { + private static BufferAllocator allocator; + private static Schema schema; + private static IpcOption optionV4; + private static IpcOption optionV5; + + @BeforeClass + public static void setUpClass() { + allocator = new RootAllocator(Integer.MAX_VALUE); + schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true)))); + optionV4 = new IpcOption(); + optionV4.metadataVersion = MetadataVersion.V4; + optionV5 = new IpcOption(); + } + + @AfterClass + public static void tearDownClass() { + allocator.close(); + } + + @Test + public void testGetFlightInfoV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server)) { + final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0])); + assertEquals(schema, result.getSchema()); + } + } + + @Test + public void testGetSchemaV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server)) { + final SchemaResult result = client.getSchema(FlightDescriptor.command(new byte[0])); + assertEquals(schema, result.getSchema()); + } + } + + @Test + public void testPutV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + generateData(root); + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + final SyncPutListener reader = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader); + listener.start(root, null, optionV4); + listener.putNext(); + listener.completed(); + listener.getResult(); + } + } + + @Test + public void testGetV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final FlightStream stream = 
client.getStream(new Ticket(new byte[0]))) { + assertTrue(stream.next()); + assertEquals(optionV4.metadataVersion, stream.metadataVersion); + validateRoot(stream.getRoot()); + assertFalse(stream.next()); + } + } + + @Test + public void testExchangeV4ToV5() throws Exception { + try (final FlightServer server = startServer(optionV5); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + stream.getWriter().start(root, null, optionV4); + generateData(root); + stream.getWriter().putNext(); + stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV5.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + + @Test + public void testExchangeV5ToV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + stream.getWriter().start(root, null, optionV5); + generateData(root); + stream.getWriter().putNext(); + stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + + @Test + public void testExchangeV4ToV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + 
stream.getWriter().start(root, null, optionV4); + generateData(root); + stream.getWriter().putNext(); + stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + + private static void generateData(VectorSchemaRoot root) { + assertEquals(schema, root.getSchema()); + final IntVector vector = (IntVector) root.getVector("foo"); + vector.setSafe(0, 0); + vector.setSafe(1, 1); + vector.setSafe(2, 4); + root.setRowCount(3); + } + + private static void validateRoot(VectorSchemaRoot root) { + assertEquals(schema, root.getSchema()); + assertEquals(3, root.getRowCount()); + final IntVector vector = (IntVector) root.getVector("foo"); + assertEquals(0, vector.get(0)); + assertEquals(1, vector.get(1)); + assertEquals(4, vector.get(2)); + } + + FlightServer startServer(IpcOption option) throws Exception { + Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0); + VersionFlightProducer producer = new VersionFlightProducer(allocator, option); + final FlightServer server = FlightServer.builder(allocator, location, producer).build(); + server.start(); + return server; + } + + FlightClient connect(FlightServer server) { + Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + return FlightClient.builder(allocator, location).build(); + } + + static final class VersionFlightProducer extends NoOpFlightProducer { + private final BufferAllocator allocator; + private final IpcOption option; + + VersionFlightProducer(BufferAllocator allocator, IpcOption option) { + this.allocator = allocator; + this.option = option; + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + return new FlightInfo(schema, descriptor, Collections.emptyList(), -1, -1, option); + } + + @Override + public SchemaResult 
getSchema(CallContext context, FlightDescriptor descriptor) { + return new SchemaResult(schema, option); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + listener.start(root, null, option); + generateData(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) { + return () -> { + try { + assertTrue(flightStream.next()); + assertEquals(option.metadataVersion, flightStream.metadataVersion); + validateRoot(flightStream.getRoot()); + } catch (AssertionError err) { + // gRPC doesn't propagate stack traces across the wire. + err.printStackTrace(); + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } catch (RuntimeException err) { + err.printStackTrace(); + ackStream.onError(CallStatus.INTERNAL + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } + ackStream.onCompleted(); + }; + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + try { + assertTrue(reader.next()); + validateRoot(reader.getRoot()); + assertFalse(reader.next()); + } catch (AssertionError err) { + // gRPC doesn't propagate stack traces across the wire. 
+ err.printStackTrace(); + writer.error(CallStatus.INVALID_ARGUMENT + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } catch (RuntimeException err) { + err.printStackTrace(); + writer.error(CallStatus.INTERNAL + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } + + writer.start(root, null, option); + generateData(root); + writer.putNext(); + writer.completed(); + } + } + } +} From 238a0e5d3c2878446d664a9be02297e0e69602ef Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 8 Jul 2020 13:47:50 -0400 Subject: [PATCH 3/4] ARROW-9362: [Java] check for union/metadata version match before serializing --- .../apache/arrow/flight/DictionaryUtils.java | 2 + .../org/apache/arrow/flight/FlightInfo.java | 2 + .../flight/OutboundStreamListenerImpl.java | 3 + .../org/apache/arrow/flight/SchemaResult.java | 5 ++ .../arrow/flight/TestMetadataVersion.java | 41 ++++++++++++ .../apache/arrow/vector/ipc/ArrowWriter.java | 2 + .../validate/MetadataV4UnionChecker.java | 66 +++++++++++++++++++ .../arrow/vector/ipc/TestRoundTrip.java | 29 +++++++- 8 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java index fa15bee4dce..b2256cd037d 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -37,6 +37,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.DictionaryUtility; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; /** * Utilities to work with dictionaries in 
Flight. @@ -57,6 +58,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe final Consumer messageCallback) throws Exception { final Set dictionaryIds = new HashSet<>(); final Schema schema = generateSchema(originalSchema, provider, dictionaryIds); + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option); // Send the schema message final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol(); try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index 7452ba83d18..e8e4b020e0f 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -33,6 +33,7 @@ import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; import com.google.common.collect.ImmutableList; @@ -78,6 +79,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List new SchemaResult(unionSchema, optionV4)); + assertThrows(IllegalArgumentException.class, () -> + new FlightInfo(unionSchema, FlightDescriptor.command(new byte[0]), Collections.emptyList(), -1, -1, optionV4)); + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final FlightStream stream = client.getStream(new Ticket("union".getBytes(StandardCharsets.UTF_8)))) { + final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::next); + assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with 
V4 metadata")); + } + + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + final SyncPutListener reader = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader); + final IllegalArgumentException err = assertThrows(IllegalArgumentException.class, + () -> listener.start(root, null, optionV4)); + assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata")); + } + } + @Test public void testPutV4() throws Exception { try (final FlightServer server = startServer(optionV4); @@ -208,6 +239,16 @@ public SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) @Override public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + if (Arrays.equals("union".getBytes(StandardCharsets.UTF_8), ticket.getBytes())) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) { + listener.start(root, null, option); + } catch (IllegalArgumentException e) { + listener.error(CallStatus.INTERNAL.withCause(e).withDescription(e.getMessage()).toRuntimeException()); + return; + } + listener.error(CallStatus.INTERNAL.withDescription("Expected exception not raised").toRuntimeException()); + return; + } try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { listener.start(root, null, option); generateData(root); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index b3ee0afa886..8b2e19e9bac 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -39,6 +39,7 @@ import 
org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.DictionaryUtility; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -83,6 +84,7 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab List fields = new ArrayList<>(root.getSchema().getFields().size()); Set dictionaryIdsUsed = new HashSet<>(); + MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option); // Convert fields with dictionaries to have dictionary type for (Field field : root.getSchema().getFields()) { fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed)); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java new file mode 100644 index 00000000000..330e83d54f7 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.vector.validate; + +import java.util.Iterator; + +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Given a field, checks that no Union fields are present. + * + * This is intended to be used to prevent unions from being read/written with V4 metadata. + */ +public final class MetadataV4UnionChecker { + static boolean isUnion(Field field) { + return field.getType().getTypeID() == ArrowType.ArrowTypeID.Union; + } + + static Field check(Field field) { + if (isUnion(field)) { + return field; + } + // Naive recursive DFS + for (final Field child : field.getChildren()) { + final Field result = check(child); + if (result != null) { + return result; + } + } + return null; + } + + /** + * Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata). + */ + public static void checkForUnion(Iterator fields, IpcOption option) { + if (option.metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) { + return; + } + while (fields.hasNext()) { + Field union = check(fields.next()); + if (union != null) { + throw new IllegalArgumentException( + "Cannot write union with V4 metadata version, use V5 instead. 
Found field: " + union); + } + } + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java index 3baf949f8b0..2aeefff3c4a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -173,7 +174,33 @@ public void testMultipleRecordBatches() throws Exception { } @Test - public void testUnion() throws Exception { + public void testUnionV4() throws Exception { + Assume.assumeTrue(writeOption.metadataVersion == MetadataVersion.V4); + final File temp = File.createTempFile("arrow-test-" + name + "-", ".arrow"); + temp.deleteOnExit(); + final ByteArrayOutputStream memoryStream = new ByteArrayOutputStream(); + + try (final BufferAllocator originalVectorAllocator = + allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); + final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) { + writeUnionData(COUNT, parent); + final VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { + try (final FileOutputStream fileStream = new FileOutputStream(temp)) { + new ArrowFileWriter(root, null, fileStream.getChannel(), writeOption); + new ArrowStreamWriter(root, null, Channels.newChannel(memoryStream), writeOption); + } + }); + assertTrue(e.getMessage(), e.getMessage().contains("Cannot write union with V4 metadata")); + e = assertThrows(IllegalArgumentException.class, () -> { + new ArrowStreamWriter(root, null, Channels.newChannel(memoryStream), 
writeOption); + }); + assertTrue(e.getMessage(), e.getMessage().contains("Cannot write union with V4 metadata")); + } + } + + @Test + public void testUnionV5() throws Exception { Assume.assumeTrue(writeOption.metadataVersion == MetadataVersion.V5); try (final BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, allocator.getLimit()); From 59a46f10c6f65533ee4fc21c143020fa703f8e07 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 8 Jul 2020 14:23:12 -0400 Subject: [PATCH 4/4] ARROW-9362: [Java] ensure FileWriter uses the correct metadata version --- .../apache/arrow/flight/DictionaryUtils.java | 2 +- .../org/apache/arrow/flight/FlightInfo.java | 2 +- .../org/apache/arrow/flight/FlightStream.java | 16 ++++++++-- .../org/apache/arrow/flight/SchemaResult.java | 2 +- .../arrow/vector/ipc/ArrowFileReader.java | 8 +++++ .../arrow/vector/ipc/ArrowFileWriter.java | 2 +- .../arrow/vector/ipc/ArrowStreamReader.java | 6 +++- .../apache/arrow/vector/ipc/ArrowWriter.java | 2 +- .../arrow/vector/ipc/message/ArrowFooter.java | 29 ++++++++++++++++++- .../arrow/vector/ipc/message/IpcOption.java | 4 +-- .../arrow/vector/types/MetadataVersion.java | 2 ++ .../validate/MetadataV4UnionChecker.java | 22 ++++++++++++-- .../arrow/vector/ipc/TestRoundTrip.java | 1 + 13 files changed, 83 insertions(+), 15 deletions(-) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java index b2256cd037d..516dab01d8a 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -58,7 +58,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe final Consumer messageCallback) throws Exception { final Set dictionaryIds = new HashSet<>(); final Schema schema = generateSchema(originalSchema, 
provider, dictionaryIds); - MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option); + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); // Send the schema message final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol(); try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index e8e4b020e0f..8eb456b0cc4 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -79,7 +79,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List fields = new ArrayList<>(root.getSchema().getFields().size()); Set dictionaryIdsUsed = new HashSet<>(); - MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option); + MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion); // Convert fields with dictionaries to have dictionary type for (Field field : root.getSchema().getFields()) { fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed)); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java index 77d3b1e98ff..567fabc1d43 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java @@ -28,6 +28,7 @@ import org.apache.arrow.flatbuf.Block; import org.apache.arrow.flatbuf.Footer; import org.apache.arrow.flatbuf.KeyValue; +import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; 
import com.google.flatbuffers.FlatBufferBuilder; @@ -43,6 +44,8 @@ public class ArrowFooter implements FBSerializable { private final Map metaData; + private final MetadataVersion metadataVersion; + public ArrowFooter(Schema schema, List dictionaries, List recordBatches) { this(schema, dictionaries, recordBatches, null); } @@ -60,11 +63,29 @@ public ArrowFooter( List dictionaries, List recordBatches, Map metaData) { + this(schema, dictionaries, recordBatches, metaData, MetadataVersion.DEFAULT); + } + /** + * Constructs a new instance. + * + * @param schema The schema for record batches in the file. + * @param dictionaries The dictionaries relevant to the file. + * @param recordBatches The recordBatches written to the file. + * @param metaData user-defined k-v meta data. + * @param metadataVersion The Arrow metadata version. + */ + public ArrowFooter( + Schema schema, + List dictionaries, + List recordBatches, + Map metaData, + MetadataVersion metadataVersion) { this.schema = schema; this.dictionaries = dictionaries; this.recordBatches = recordBatches; this.metaData = metaData; + this.metadataVersion = metadataVersion; } /** @@ -75,7 +96,8 @@ public ArrowFooter(Footer footer) { Schema.convertSchema(footer.schema()), dictionaries(footer), recordBatches(footer), - metaData(footer) + metaData(footer), + MetadataVersion.fromFlatbufID(footer.version()) ); } @@ -130,6 +152,10 @@ public Map getMetaData() { return metaData; } + public MetadataVersion getMetadataVersion() { + return metadataVersion; + } + @Override public int writeTo(FlatBufferBuilder builder) { int schemaIndex = schema.getSchema(builder); @@ -148,6 +174,7 @@ public int writeTo(FlatBufferBuilder builder) { Footer.addDictionaries(builder, dicsOffset); Footer.addRecordBatches(builder, rbsOffset); Footer.addCustomMetadata(builder, metaDataOffset); + Footer.addVersion(builder, metadataVersion.toFlatbufID()); return Footer.endFooter(builder); } diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java index c1a93dcdd63..b93c3b3da2f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java @@ -28,6 +28,6 @@ public class IpcOption { // consisting of a 4-byte prefix instead of 8 byte public boolean write_legacy_ipc_format = false; - // The metadata version. Defaults to V4. - public MetadataVersion metadataVersion = MetadataVersion.V5; + // The metadata version. Defaults to V5. + public MetadataVersion metadataVersion = MetadataVersion.DEFAULT; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java index 9e1894052d0..a0e281960f1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java @@ -38,6 +38,8 @@ public enum MetadataVersion { ; + public static final MetadataVersion DEFAULT = V5; + private static final MetadataVersion[] valuesByFlatbufId = new MetadataVersion[MetadataVersion.values().length]; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java index 330e83d54f7..2a706836567 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java @@ -17,12 +17,13 @@ package org.apache.arrow.vector.validate; +import java.io.IOException; import java.util.Iterator; -import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.types.MetadataVersion; import 
org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; /** * Given a field, checks that no Union fields are present. @@ -51,8 +52,8 @@ static Field check(Field field) { /** * Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata). */ - public static void checkForUnion(Iterator fields, IpcOption option) { - if (option.metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) { + public static void checkForUnion(Iterator fields, MetadataVersion metadataVersion) { + if (metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) { return; } while (fields.hasNext()) { @@ -63,4 +64,19 @@ public static void checkForUnion(Iterator fields, IpcOption option) { } } } + + /** + * Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata). + */ + public static void checkRead(Schema schema, MetadataVersion metadataVersion) throws IOException { + if (metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) { + return; + } + for (final Field field : schema.getFields()) { + Field union = check(field); + if (union != null) { + throw new IOException("Cannot read union with V4 metadata version. 
Found field: " + union); + } + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java index 2aeefff3c4a..971008e5cb5 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java @@ -624,6 +624,7 @@ private void roundTrip(VectorSchemaRoot root, DictionaryProvider provider, ArrowStreamReader streamReader = new ArrowStreamReader(inputStream, readerAllocator)) { fileValidator.accept(fileReader); streamValidator.accept(streamReader); + assertEquals(writeOption.metadataVersion, fileReader.getFooter().getMetadataVersion()); assertEquals(metadata, fileReader.getMetaData()); } }