Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions java/flight/flight-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
</properties>

<dependencies>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-format</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Schema;

import com.google.common.collect.ImmutableList;
Expand All @@ -50,7 +52,6 @@
import com.google.protobuf.WireFormat;

import io.grpc.Drainable;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.protobuf.ProtoUtils;
import io.netty.buffer.ByteBuf;
Expand All @@ -75,7 +76,8 @@ class ArrowMessage implements AutoCloseable {
private static final int APP_METADATA_TAG =
(FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;

private static Marshaller<FlightData> NO_BODY_MARSHALLER = ProtoUtils.marshaller(FlightData.getDefaultInstance());
private static final Marshaller<FlightData> NO_BODY_MARSHALLER =
ProtoUtils.marshaller(FlightData.getDefaultInstance());

/** Get the application-specific metadata in this message. The ArrowMessage retains ownership of the buffer. */
public ArrowBuf getApplicationMetadata() {
Expand Down Expand Up @@ -106,7 +108,7 @@ public static HeaderType getHeader(byte b) {
}

// Pre-allocated buffers for padding serialized ArrowMessages.
private static List<ByteBuf> PADDING_BUFFERS = Arrays.asList(
private static final List<ByteBuf> PADDING_BUFFERS = Arrays.asList(
null,
Unpooled.copiedBuffer(new byte[] { 0 }),
Unpooled.copiedBuffer(new byte[] { 0, 0 }),
Expand All @@ -117,13 +119,15 @@ public static HeaderType getHeader(byte b) {
Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0, 0 })
);

private final IpcOption writeOption;
private final FlightDescriptor descriptor;
private final MessageMetadataResult message;
private final ArrowBuf appMetadata;
private final List<ArrowBuf> bufs;

public ArrowMessage(FlightDescriptor descriptor, Schema schema) {
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema);
public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) {
this.writeOption = option;
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema, writeOption);
this.message = MessageMetadataResult.create(serializedMessage.slice(),
serializedMessage.remaining());
bufs = ImmutableList.of();
Expand All @@ -136,16 +140,18 @@ public ArrowMessage(FlightDescriptor descriptor, Schema schema) {
* @param batch The record batch.
* @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise.
*/
public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata) {
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch);
public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata, IpcOption option) {
this.writeOption = option;
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption);
this.message = MessageMetadataResult.create(serializedMessage.slice(), serializedMessage.remaining());
this.bufs = ImmutableList.copyOf(batch.getBuffers());
this.descriptor = null;
this.appMetadata = appMetadata;
}

public ArrowMessage(ArrowDictionaryBatch batch) {
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch);
public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) {
this.writeOption = new IpcOption();
ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption);
serializedMessage = serializedMessage.slice();
this.message = MessageMetadataResult.create(serializedMessage, serializedMessage.remaining());
// asInputStream will free the buffers implicitly, so increment the reference count
Expand All @@ -160,13 +166,17 @@ public ArrowMessage(ArrowDictionaryBatch batch) {
* @param appMetadata The application-provided metadata buffer.
*/
public ArrowMessage(ArrowBuf appMetadata) {
  // This message carries only application metadata: no Arrow message, no descriptor, no body buffers.
  this.appMetadata = appMetadata;
  this.descriptor = null;
  this.message = null;
  this.bufs = ImmutableList.of();
  // A default IpcOption suffices, since it is not used to serialize this kind of message.
  this.writeOption = new IpcOption();
}

public ArrowMessage(FlightDescriptor descriptor) {
// No need to take IpcOption as it's not used to serialize this kind of message.
this.writeOption = new IpcOption();
this.message = null;
this.bufs = ImmutableList.of();
this.descriptor = descriptor;
Expand All @@ -175,6 +185,11 @@ public ArrowMessage(FlightDescriptor descriptor) {

private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata,
ArrowBuf buf) {
// No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire.
this.writeOption = new IpcOption();
if (message != null) {
this.writeOption.metadataVersion = MetadataVersion.fromFlatbufID(message.getMessage().version());
}
this.message = message;
this.descriptor = descriptor;
this.appMetadata = appMetadata;
Expand Down Expand Up @@ -404,7 +419,7 @@ public static Marshaller<ArrowMessage> createMarshaller(BufferAllocator allocato
return new ArrowMessageHolderMarshaller(allocator);
}

private static class ArrowMessageHolderMarshaller implements MethodDescriptor.Marshaller<ArrowMessage> {
private static class ArrowMessageHolderMarshaller implements Marshaller<ArrowMessage> {

private final BufferAllocator allocator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;

/**
* Utilities to work with dictionaries in Flight.
Expand All @@ -51,11 +54,14 @@ private DictionaryUtils() {
* @throws Exception if there was an error closing {@link ArrowMessage} objects. This is not generally expected.
*/
static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor,
final DictionaryProvider provider, final Consumer<ArrowMessage> messageCallback) throws Exception {
final DictionaryProvider provider, final IpcOption option,
final Consumer<ArrowMessage> messageCallback) throws Exception {
final Set<Long> dictionaryIds = new HashSet<>();
final Schema schema = generateSchema(originalSchema, provider, dictionaryIds);
MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
// Send the schema message
try (final ArrowMessage message = new ArrowMessage(descriptor == null ? null : descriptor.toProtocol(), schema)) {
final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol();
try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) {
messageCallback.accept(message);
}
// Create and write dictionary batches
Expand All @@ -71,7 +77,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe
final VectorUnloader unloader = new VectorUnloader(dictRoot);
try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(
id, unloader.getRecordBatch());
final ArrowMessage message = new ArrowMessage(dictionaryBatch)) {
final ArrowMessage message = new ArrowMessage(dictionaryBatch, option)) {
messageCallback.accept(message);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,24 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo
*/
public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider,
PutListener metadataListener, CallOption... options) {
Preconditions.checkNotNull(descriptor, "descriptor must not be null");
Preconditions.checkNotNull(root, "root must not be null");
Preconditions.checkNotNull(provider, "provider must not be null");
final ClientStreamListener writer = startPut(descriptor, metadataListener, options);
writer.start(root, provider);
return writer;
}

/**
* Create or append a descriptor with another stream.
* @param descriptor FlightDescriptor the descriptor for the data
* @param metadataListener A handler for metadata messages from the server.
* @param options RPC-layer hints for this call.
* @return ClientStreamListener an interface to control uploading data.
* {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will NOT already have been called.
*/
public ClientStreamListener startPut(FlightDescriptor descriptor, PutListener metadataListener,
CallOption... options) {
Preconditions.checkNotNull(descriptor, "descriptor must not be null");
Preconditions.checkNotNull(metadataListener, "metadataListener must not be null");
final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();

Expand All @@ -212,11 +227,8 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo
ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
ClientCalls.asyncBidiStreamingCall(
interceptedChannel.newCall(doPutDescriptor, callOptions), resultObserver);
final ClientStreamListener writer = new PutObserver(
return new PutObserver(
descriptor, observer, metadataListener::isCancelled, metadataListener::getResult);
// Send the schema to start.
writer.start(root, provider);
return writer;
} catch (StatusRuntimeException sre) {
throw StatusUtils.fromGrpcRuntimeException(sre);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;

import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream;
import com.google.common.collect.ImmutableList;
Expand All @@ -41,11 +43,12 @@
* A POJO representation of a FlightInfo, metadata associated with a set of data records.
*/
public class FlightInfo {
private Schema schema;
private FlightDescriptor descriptor;
private List<FlightEndpoint> endpoints;
private final Schema schema;
private final FlightDescriptor descriptor;
private final List<FlightEndpoint> endpoints;
private final long bytes;
private final long records;
private final IpcOption option;

/**
* Constructs a new instance.
Expand All @@ -58,15 +61,31 @@ public class FlightInfo {
*/
public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes,
long records) {
super();
this(schema, descriptor, endpoints, bytes, records, new IpcOption());
}

/**
 * Constructs a new instance.
 *
 * @param schema The schema of the Flight
 * @param descriptor An identifier for the Flight.
 * @param endpoints A list of endpoints that have the flight available.
 * @param bytes The number of bytes in the flight
 * @param records The number of records in the flight.
 * @param option IPC write options; its metadata version is checked against the schema here.
 * @throws NullPointerException if schema, descriptor, endpoints, or option is null.
 */
public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes,
    long records, IpcOption option) {
  Objects.requireNonNull(schema, "schema must not be null");
  Objects.requireNonNull(descriptor, "descriptor must not be null");
  Objects.requireNonNull(endpoints, "endpoints must not be null");
  // Fail with a clear message instead of an anonymous NPE when option.metadataVersion is read below.
  Objects.requireNonNull(option, "option must not be null");
  // NOTE(review): presumably rejects union fields that the chosen metadata version cannot represent;
  // confirm against MetadataV4UnionChecker.
  MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
  this.schema = schema;
  this.descriptor = descriptor;
  this.endpoints = endpoints;
  this.bytes = bytes;
  this.records = records;
  this.option = option;
}

/**
Expand All @@ -89,6 +108,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoin
}
bytes = pbFlightInfo.getTotalBytes();
records = pbFlightInfo.getTotalRecords();
option = new IpcOption();
}

public Schema getSchema() {
Expand Down Expand Up @@ -118,7 +138,7 @@ Flight.FlightInfo toProtocol() {
// Encode schema in a Message payload
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema);
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.arrow.flight;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -31,16 +32,19 @@
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;

import com.google.common.util.concurrent.SettableFuture;

Expand Down Expand Up @@ -72,6 +76,8 @@ public class FlightStream implements AutoCloseable {
private volatile VectorLoader loader;
private volatile Throwable ex;
private volatile ArrowBuf applicationMetadata = null;
@VisibleForTesting
volatile MetadataVersion metadataVersion = null;

/**
* Constructs a new instance.
Expand Down Expand Up @@ -212,13 +218,15 @@ public boolean next() {
fulfilledRoot.clear();
}
} else if (msg.getMessageType() == HeaderType.RECORD_BATCH) {
checkMetadataVersion(msg);
// Ensure we have the root
root.get().clear();
try (ArrowRecordBatch arb = msg.asRecordBatch()) {
loader.load(arb);
}
updateMetadata(msg);
} else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
checkMetadataVersion(msg);
// Ensure we have the root
root.get().clear();
try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
Expand Down Expand Up @@ -253,7 +261,7 @@ public boolean next() {
}
}

/** Update our metdata reference with a new one from this message. */
/** Update our metadata reference with a new one from this message. */
private void updateMetadata(ArrowMessage msg) {
if (this.applicationMetadata != null) {
this.applicationMetadata.close();
Expand All @@ -264,6 +272,18 @@ private void updateMetadata(ArrowMessage msg) {
}
}

/**
 * Verify that a message carries the same Arrow metadata version that this stream has recorded.
 *
 * <p>Messages for which {@code asSchemaMessage()} returns null are skipped.
 *
 * @param msg the incoming message to check.
 * @throws IllegalStateException if the message's metadata version differs from the recorded one.
 */
private void checkMetadataVersion(ArrowMessage msg) {
  if (msg.asSchemaMessage() != null) {
    final MetadataVersion incoming =
        MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
    // MetadataVersion values are compared by identity, exactly as before.
    if (incoming != this.metadataVersion) {
      throw new IllegalStateException("Metadata version mismatch: stream started as " +
          this.metadataVersion + " but got message with version " + incoming);
    }
  }
}

/**
* Get the current vector data from the stream.
*
Expand Down Expand Up @@ -343,13 +363,21 @@ public void onNext(ArrowMessage msg) {
dictionaries.put(entry.getValue());
}
schema = new Schema(fields, schema.getCustomMetadata());
metadataVersion = MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
try {
MetadataV4UnionChecker.checkRead(schema, metadataVersion);
} catch (IOException e) {
queue.add(DONE_EX);
ex = e;
break;
}

fulfilledRoot = VectorSchemaRoot.create(schema, allocator);
loader = new VectorLoader(fulfilledRoot);
if (msg.getDescriptor() != null) {
descriptor.set(new FlightDescriptor(msg.getDescriptor()));
}
root.set(fulfilledRoot);

break;
}
case RECORD_BATCH:
Expand Down
Loading