This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
*/
- void start(VectorSchemaRoot root);
+ default void start(VectorSchemaRoot root) {
+ start(root, null, new IpcOption());
+ }
/**
* Start sending data, using the schema of the given {@link VectorSchemaRoot}.
*
- * This method must be called before all others.
+ *
This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
+ */
+ default void start(VectorSchemaRoot root, DictionaryProvider dictionaries) {
+ start(root, dictionaries, new IpcOption());
+ }
+
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ *
This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
*/
- void start(VectorSchemaRoot root, DictionaryProvider dictionaries);
+ void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option);
/**
* Send the current contents of the associated {@link VectorSchemaRoot}.
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java
index c826c8507f3..b9bd626c130 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java
@@ -23,6 +23,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.IpcOption;
import io.grpc.stub.CallStreamObserver;
@@ -33,6 +34,7 @@ abstract class OutboundStreamListenerImpl implements OutboundStreamListener {
private final FlightDescriptor descriptor; // nullable
protected final CallStreamObserver responseObserver;
protected volatile VectorUnloader unloader; // null until stream started
+ protected IpcOption option; // null until stream started
OutboundStreamListenerImpl(FlightDescriptor descriptor, CallStreamObserver responseObserver) {
Preconditions.checkNotNull(responseObserver, "responseObserver must be provided");
@@ -47,14 +49,11 @@ public boolean isReady() {
}
@Override
- public void start(VectorSchemaRoot root) {
- start(root, new DictionaryProvider.MapDictionaryProvider());
- }
-
- @Override
- public void start(VectorSchemaRoot root, DictionaryProvider dictionaries) {
+ public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {
+ this.option = option;
try {
- DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, responseObserver::onNext);
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, option,
+ responseObserver::onNext);
} catch (Exception e) {
// Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate
// the exception
@@ -86,7 +85,7 @@ public void putNext(ArrowBuf metadata) {
// close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers
// in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up
// ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty.
- try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata)) {
+ try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata, option)) {
responseObserver.onNext(message);
} catch (Exception e) {
// This exception comes from ArrowMessage#close, not responseObserver#onNext.
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java
index 764f4c70f33..0ef3cbb789a 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java
@@ -25,6 +25,7 @@
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -40,11 +41,16 @@
public class SchemaResult {
private final Schema schema;
+ private final IpcOption option;
public SchemaResult(Schema schema) {
- this.schema = schema;
+ this(schema, new IpcOption());
}
+ public SchemaResult(Schema schema, IpcOption option) {
+ this.schema = schema;
+ this.option = option;
+ }
public Schema getSchema() {
return schema;
@@ -57,7 +63,7 @@ Flight.SchemaResult toProtocol() {
// Encode schema in a Message payload
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
- MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema);
+ MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option);
} catch (IOException e) {
throw new RuntimeException(e);
}
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java
new file mode 100644
index 00000000000..ad8fda65b82
--- /dev/null
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Collections;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.types.MetadataVersion;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Test clients/servers with different metadata versions.
+ */
+public class TestMetadataVersion {
+ private static BufferAllocator allocator;
+ private static Schema schema;
+ private static IpcOption optionV4;
+ private static IpcOption optionV5;
+
+ @BeforeClass
+ public static void setUpClass() {
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+ optionV4 = new IpcOption();
+ optionV4.metadataVersion = MetadataVersion.V4;
+ optionV5 = new IpcOption();
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ allocator.close();
+ }
+
+ @Test
+ public void testGetFlightInfoV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server)) {
+ final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0]));
+ assertEquals(schema, result.getSchema());
+ }
+ }
+
+ @Test
+ public void testGetSchemaV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server)) {
+ final SchemaResult result = client.getSchema(FlightDescriptor.command(new byte[0]));
+ assertEquals(schema, result.getSchema());
+ }
+ }
+
+ @Test
+ public void testPutV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ generateData(root);
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ final SyncPutListener reader = new SyncPutListener();
+ final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader);
+ listener.start(root, null, optionV4);
+ listener.putNext();
+ listener.completed();
+ listener.getResult();
+ }
+ }
+
+ @Test
+ public void testGetV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ assertTrue(stream.next());
+ assertEquals(optionV4.metadataVersion, stream.metadataVersion);
+ validateRoot(stream.getRoot());
+ assertFalse(stream.next());
+ }
+ }
+
+ @Test
+ public void testExchangeV4ToV5() throws Exception {
+ try (final FlightServer server = startServer(optionV5);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV4);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV5.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ @Test
+ public void testExchangeV5ToV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV5);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ @Test
+ public void testExchangeV4ToV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV4);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ private static void generateData(VectorSchemaRoot root) {
+ assertEquals(schema, root.getSchema());
+ final IntVector vector = (IntVector) root.getVector("foo");
+ vector.setSafe(0, 0);
+ vector.setSafe(1, 1);
+ vector.setSafe(2, 4);
+ root.setRowCount(3);
+ }
+
+ private static void validateRoot(VectorSchemaRoot root) {
+ assertEquals(schema, root.getSchema());
+ assertEquals(3, root.getRowCount());
+ final IntVector vector = (IntVector) root.getVector("foo");
+ assertEquals(0, vector.get(0));
+ assertEquals(1, vector.get(1));
+ assertEquals(4, vector.get(2));
+ }
+
+ FlightServer startServer(IpcOption option) throws Exception {
+ Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0);
+ VersionFlightProducer producer = new VersionFlightProducer(allocator, option);
+ final FlightServer server = FlightServer.builder(allocator, location, producer).build();
+ server.start();
+ return server;
+ }
+
+ FlightClient connect(FlightServer server) {
+ Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ return FlightClient.builder(allocator, location).build();
+ }
+
+ static final class VersionFlightProducer extends NoOpFlightProducer {
+ private final BufferAllocator allocator;
+ private final IpcOption option;
+
+ VersionFlightProducer(BufferAllocator allocator, IpcOption option) {
+ this.allocator = allocator;
+ this.option = option;
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ return new FlightInfo(schema, descriptor, Collections.emptyList(), -1, -1, option);
+ }
+
+ @Override
+ public SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) {
+ return new SchemaResult(schema, option);
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ listener.start(root, null, option);
+ generateData(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) {
+ return () -> {
+ try {
+ assertTrue(flightStream.next());
+ assertEquals(option.metadataVersion, flightStream.metadataVersion);
+ validateRoot(flightStream.getRoot());
+ } catch (AssertionError err) {
+ // gRPC doesn't propagate stack traces across the wire.
+ err.printStackTrace();
+ ackStream.onError(CallStatus.INVALID_ARGUMENT
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ } catch (RuntimeException err) {
+ err.printStackTrace();
+ ackStream.onError(CallStatus.INTERNAL
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ }
+ ackStream.onCompleted();
+ };
+ }
+
+ @Override
+ public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try {
+ assertTrue(reader.next());
+ validateRoot(reader.getRoot());
+ assertFalse(reader.next());
+ } catch (AssertionError err) {
+ // gRPC doesn't propagate stack traces across the wire.
+ err.printStackTrace();
+ writer.error(CallStatus.INVALID_ARGUMENT
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ } catch (RuntimeException err) {
+ err.printStackTrace();
+ writer.error(CallStatus.INTERNAL
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ }
+
+ writer.start(root, null, option);
+ generateData(root);
+ writer.putNext();
+ writer.completed();
+ }
+ }
+ }
+}
From 238a0e5d3c2878446d664a9be02297e0e69602ef Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 8 Jul 2020 13:47:50 -0400
Subject: [PATCH 3/4] ARROW-9362: [Java] check for union/metadata version match
before serializing
---
.../apache/arrow/flight/DictionaryUtils.java | 2 +
.../org/apache/arrow/flight/FlightInfo.java | 2 +
.../flight/OutboundStreamListenerImpl.java | 3 +
.../org/apache/arrow/flight/SchemaResult.java | 5 ++
.../arrow/flight/TestMetadataVersion.java | 41 ++++++++++++
.../apache/arrow/vector/ipc/ArrowWriter.java | 2 +
.../validate/MetadataV4UnionChecker.java | 66 +++++++++++++++++++
.../arrow/vector/ipc/TestRoundTrip.java | 29 +++++++-
8 files changed, 149 insertions(+), 1 deletion(-)
create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
index fa15bee4dce..b2256cd037d 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
@@ -37,6 +37,7 @@
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
/**
* Utilities to work with dictionaries in Flight.
@@ -57,6 +58,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe
final Consumer messageCallback) throws Exception {
final Set dictionaryIds = new HashSet<>();
final Schema schema = generateSchema(originalSchema, provider, dictionaryIds);
+ MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option);
// Send the schema message
final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol();
try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) {
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
index 7452ba83d18..e8e4b020e0f 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
@@ -33,6 +33,7 @@
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream;
import com.google.common.collect.ImmutableList;
@@ -78,6 +79,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List new SchemaResult(unionSchema, optionV4));
+ assertThrows(IllegalArgumentException.class, () ->
+ new FlightInfo(unionSchema, FlightDescriptor.command(new byte[0]), Collections.emptyList(), -1, -1, optionV4));
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final FlightStream stream = client.getStream(new Ticket("union".getBytes(StandardCharsets.UTF_8)))) {
+ final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::next);
+ assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata"));
+ }
+
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ final SyncPutListener reader = new SyncPutListener();
+ final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader);
+ final IllegalArgumentException err = assertThrows(IllegalArgumentException.class,
+ () -> listener.start(root, null, optionV4));
+ assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata"));
+ }
+ }
+
@Test
public void testPutV4() throws Exception {
try (final FlightServer server = startServer(optionV4);
@@ -208,6 +239,16 @@ public SchemaResult getSchema(CallContext context, FlightDescriptor descriptor)
@Override
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ if (Arrays.equals("union".getBytes(StandardCharsets.UTF_8), ticket.getBytes())) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) {
+ listener.start(root, null, option);
+ } catch (IllegalArgumentException e) {
+ listener.error(CallStatus.INTERNAL.withCause(e).withDescription(e.getMessage()).toRuntimeException());
+ return;
+ }
+ listener.error(CallStatus.INTERNAL.withDescription("Expected exception not raised").toRuntimeException());
+ return;
+ }
try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
listener.start(root, null, option);
generateData(root);
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index b3ee0afa886..8b2e19e9bac 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -39,6 +39,7 @@
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -83,6 +84,7 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
List fields = new ArrayList<>(root.getSchema().getFields().size());
Set dictionaryIdsUsed = new HashSet<>();
+ MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option);
// Convert fields with dictionaries to have dictionary type
for (Field field : root.getSchema().getFields()) {
fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
new file mode 100644
index 00000000000..330e83d54f7
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.validate;
+
+import java.util.Iterator;
+
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.types.MetadataVersion;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+
+/**
+ * Given a field, checks that no Union fields are present.
+ *
+ * This is intended to be used to prevent unions from being read/written with V4 metadata.
+ */
+public final class MetadataV4UnionChecker {
+ static boolean isUnion(Field field) {
+ return field.getType().getTypeID() == ArrowType.ArrowTypeID.Union;
+ }
+
+ static Field check(Field field) {
+ if (isUnion(field)) {
+ return field;
+ }
+ // Naive recursive DFS
+ for (final Field child : field.getChildren()) {
+ final Field result = check(child);
+ if (result != null) {
+ return result;
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata).
+ */
+ public static void checkForUnion(Iterator fields, IpcOption option) {
+ if (option.metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) {
+ return;
+ }
+ while (fields.hasNext()) {
+ Field union = check(fields.next());
+ if (union != null) {
+ throw new IllegalArgumentException(
+ "Cannot write union with V4 metadata version, use V5 instead. Found field: " + union);
+ }
+ }
+ }
+}
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
index 3baf949f8b0..2aeefff3c4a 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
@@ -23,6 +23,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -173,7 +174,33 @@ public void testMultipleRecordBatches() throws Exception {
}
@Test
- public void testUnion() throws Exception {
+ public void testUnionV4() throws Exception {
+ Assume.assumeTrue(writeOption.metadataVersion == MetadataVersion.V4);
+ final File temp = File.createTempFile("arrow-test-" + name + "-", ".arrow");
+ temp.deleteOnExit();
+ final ByteArrayOutputStream memoryStream = new ByteArrayOutputStream();
+
+ try (final BufferAllocator originalVectorAllocator =
+ allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
+ final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) {
+ writeUnionData(COUNT, parent);
+ final VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root"));
+ IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> {
+ try (final FileOutputStream fileStream = new FileOutputStream(temp)) {
+ new ArrowFileWriter(root, null, fileStream.getChannel(), writeOption);
+ new ArrowStreamWriter(root, null, Channels.newChannel(memoryStream), writeOption);
+ }
+ });
+ assertTrue(e.getMessage(), e.getMessage().contains("Cannot write union with V4 metadata"));
+ e = assertThrows(IllegalArgumentException.class, () -> {
+ new ArrowStreamWriter(root, null, Channels.newChannel(memoryStream), writeOption);
+ });
+ assertTrue(e.getMessage(), e.getMessage().contains("Cannot write union with V4 metadata"));
+ }
+ }
+
+ @Test
+ public void testUnionV5() throws Exception {
Assume.assumeTrue(writeOption.metadataVersion == MetadataVersion.V5);
try (final BufferAllocator originalVectorAllocator =
allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
From 59a46f10c6f65533ee4fc21c143020fa703f8e07 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 8 Jul 2020 14:23:12 -0400
Subject: [PATCH 4/4] ARROW-9362: [Java] ensure FileWriter uses the correct
metadata version
---
.../apache/arrow/flight/DictionaryUtils.java | 2 +-
.../org/apache/arrow/flight/FlightInfo.java | 2 +-
.../org/apache/arrow/flight/FlightStream.java | 16 ++++++++--
.../org/apache/arrow/flight/SchemaResult.java | 2 +-
.../arrow/vector/ipc/ArrowFileReader.java | 8 +++++
.../arrow/vector/ipc/ArrowFileWriter.java | 2 +-
.../arrow/vector/ipc/ArrowStreamReader.java | 6 +++-
.../apache/arrow/vector/ipc/ArrowWriter.java | 2 +-
.../arrow/vector/ipc/message/ArrowFooter.java | 29 ++++++++++++++++++-
.../arrow/vector/ipc/message/IpcOption.java | 4 +--
.../arrow/vector/types/MetadataVersion.java | 2 ++
.../validate/MetadataV4UnionChecker.java | 22 ++++++++++++--
.../arrow/vector/ipc/TestRoundTrip.java | 1 +
13 files changed, 83 insertions(+), 15 deletions(-)
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
index b2256cd037d..516dab01d8a 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
@@ -58,7 +58,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe
final Consumer messageCallback) throws Exception {
final Set dictionaryIds = new HashSet<>();
final Schema schema = generateSchema(originalSchema, provider, dictionaryIds);
- MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option);
+ MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
// Send the schema message
final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol();
try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) {
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
index e8e4b020e0f..8eb456b0cc4 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
@@ -79,7 +79,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List fields = new ArrayList<>(root.getSchema().getFields().size());
Set dictionaryIdsUsed = new HashSet<>();
- MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option);
+ MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
// Convert fields with dictionaries to have dictionary type
for (Field field : root.getSchema().getFields()) {
fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java
index 77d3b1e98ff..567fabc1d43 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowFooter.java
@@ -28,6 +28,7 @@
import org.apache.arrow.flatbuf.Block;
import org.apache.arrow.flatbuf.Footer;
import org.apache.arrow.flatbuf.KeyValue;
+import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Schema;
import com.google.flatbuffers.FlatBufferBuilder;
@@ -43,6 +44,8 @@ public class ArrowFooter implements FBSerializable {
private final Map metaData;
+ private final MetadataVersion metadataVersion;
+
public ArrowFooter(Schema schema, List dictionaries, List recordBatches) {
this(schema, dictionaries, recordBatches, null);
}
@@ -60,11 +63,29 @@ public ArrowFooter(
List dictionaries,
List recordBatches,
Map metaData) {
+ this(schema, dictionaries, recordBatches, metaData, MetadataVersion.DEFAULT);
+ }
+ /**
+ * Constructs a new instance.
+ *
+ * @param schema The schema for record batches in the file.
+ * @param dictionaries The dictionaries relevant to the file.
+ * @param recordBatches The recordBatches written to the file.
+ * @param metaData user-defined k-v meta data.
+ * @param metadataVersion The Arrow metadata version.
+ */
+ public ArrowFooter(
+ Schema schema,
+ List dictionaries,
+ List recordBatches,
+ Map metaData,
+ MetadataVersion metadataVersion) {
this.schema = schema;
this.dictionaries = dictionaries;
this.recordBatches = recordBatches;
this.metaData = metaData;
+ this.metadataVersion = metadataVersion;
}
/**
@@ -75,7 +96,8 @@ public ArrowFooter(Footer footer) {
Schema.convertSchema(footer.schema()),
dictionaries(footer),
recordBatches(footer),
- metaData(footer)
+ metaData(footer),
+ MetadataVersion.fromFlatbufID(footer.version())
);
}
@@ -130,6 +152,10 @@ public Map getMetaData() {
return metaData;
}
+ public MetadataVersion getMetadataVersion() {
+ return metadataVersion;
+ }
+
@Override
public int writeTo(FlatBufferBuilder builder) {
int schemaIndex = schema.getSchema(builder);
@@ -148,6 +174,7 @@ public int writeTo(FlatBufferBuilder builder) {
Footer.addDictionaries(builder, dicsOffset);
Footer.addRecordBatches(builder, rbsOffset);
Footer.addCustomMetadata(builder, metaDataOffset);
+ Footer.addVersion(builder, metadataVersion.toFlatbufID());
return Footer.endFooter(builder);
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java
index c1a93dcdd63..b93c3b3da2f 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java
@@ -28,6 +28,6 @@ public class IpcOption {
// consisting of a 4-byte prefix instead of 8 byte
public boolean write_legacy_ipc_format = false;
- // The metadata version. Defaults to V4.
- public MetadataVersion metadataVersion = MetadataVersion.V5;
+ // The metadata version. Defaults to V5.
+ public MetadataVersion metadataVersion = MetadataVersion.DEFAULT;
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java
index 9e1894052d0..a0e281960f1 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/MetadataVersion.java
@@ -38,6 +38,8 @@ public enum MetadataVersion {
;
+ public static final MetadataVersion DEFAULT = V5;
+
private static final MetadataVersion[] valuesByFlatbufId =
new MetadataVersion[MetadataVersion.values().length];
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
index 330e83d54f7..2a706836567 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/MetadataV4UnionChecker.java
@@ -17,12 +17,13 @@
package org.apache.arrow.vector.validate;
+import java.io.IOException;
import java.util.Iterator;
-import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
/**
* Given a field, checks that no Union fields are present.
@@ -51,8 +52,8 @@ static Field check(Field field) {
/**
* Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata).
*/
- public static void checkForUnion(Iterator fields, IpcOption option) {
- if (option.metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) {
+ public static void checkForUnion(Iterator fields, MetadataVersion metadataVersion) {
+ if (metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) {
return;
}
while (fields.hasNext()) {
@@ -63,4 +64,19 @@ public static void checkForUnion(Iterator fields, IpcOption option) {
}
}
}
+
+ /**
+ * Check the schema, raising an error if an unsupported feature is used (e.g. unions with < V5 metadata).
+ */
+ public static void checkRead(Schema schema, MetadataVersion metadataVersion) throws IOException {
+ if (metadataVersion.toFlatbufID() >= MetadataVersion.V5.toFlatbufID()) {
+ return;
+ }
+ for (final Field field : schema.getFields()) {
+ Field union = check(field);
+ if (union != null) {
+ throw new IOException("Cannot read union with V4 metadata version. Found field: " + union);
+ }
+ }
+ }
}
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
index 2aeefff3c4a..971008e5cb5 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestRoundTrip.java
@@ -624,6 +624,7 @@ private void roundTrip(VectorSchemaRoot root, DictionaryProvider provider,
ArrowStreamReader streamReader = new ArrowStreamReader(inputStream, readerAllocator)) {
fileValidator.accept(fileReader);
streamValidator.accept(streamReader);
+ assertEquals(writeOption.metadataVersion, fileReader.getFooter().getMetadataVersion());
assertEquals(metadata, fileReader.getMetaData());
}
}