doAction(Action action, CallOption... options) {
+ // TODO: need to wrap all methods to catch exceptions
return Iterators
.transform(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
}
@@ -241,7 +257,7 @@ public void onNext(ArrowMessage value) {
@Override
public void onError(Throwable t) {
- delegate.onError(t);
+ delegate.onError(StatusUtils.toGrpcException(t));
}
@Override
@@ -313,7 +329,7 @@ public void putNext(ArrowBuf appMetadata) {
@Override
public void error(Throwable ex) {
- observer.onError(ex);
+ observer.onError(StatusUtils.toGrpcException(ex));
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
index 30254ed2c9d..5ac8176bdc6 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
@@ -19,6 +19,9 @@
/**
* An exception raised from a Flight RPC.
+ *
+ * In service implementations, raising an instance of this exception will provide clients with a more detailed
+ * message and error code.
*/
public class FlightRuntimeException extends RuntimeException {
private final CallStatus status;
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
index 7637e89f9b5..31af04a1f0b 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
@@ -24,6 +24,7 @@
import org.apache.arrow.flight.auth.AuthConstants;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthWrapper;
+import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.ActionType;
import org.apache.arrow.flight.impl.Flight.Empty;
@@ -77,7 +78,7 @@ public void listFlights(Flight.Criteria criteria, StreamObserver) responseObserver), new Criteria(criteria),
StreamPipe.wrap(responseObserver, FlightInfo::toProtocol));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -91,7 +92,7 @@ public void doGetCustom(Flight.Ticket ticket, StreamObserver respo
producer.getStream(makeContext((ServerCallStreamObserver>) responseObserver), new Ticket(ticket),
new GetListener(responseObserver));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -101,7 +102,7 @@ public void doAction(Flight.Action request, StreamObserver respon
producer.doAction(makeContext((ServerCallStreamObserver>) responseObserver), new Action(request),
StreamPipe.wrap(responseObserver, Result::toProtocol));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -111,7 +112,7 @@ public void listActions(Empty request, StreamObserver responseObserv
producer.listActions(makeContext((ServerCallStreamObserver>) responseObserver),
StreamPipe.wrap(responseObserver, t -> t.toProtocol()));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -165,7 +166,7 @@ public void putNext(ArrowBuf metadata) {
@Override
public void error(Throwable ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
@Override
@@ -192,7 +193,7 @@ public StreamObserver doPutCustom(final StreamObserver wrapHandshake(ServerAuthHandler a
responseObserver.onError(Status.PERMISSION_DENIED.asException());
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
};
observer.future = executors.submit(r);
@@ -128,8 +129,8 @@ public void send(byte[] payload) {
}
@Override
- public void onError(String message, Throwable cause) {
- responseObserver.onError(cause);
+ public void onError(Throwable cause) {
+ responseObserver.onError(StatusUtils.toGrpcException(cause));
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
index d010fcea2c6..84576a53c86 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
@@ -18,10 +18,13 @@
package org.apache.arrow.flight.grpc;
import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import io.grpc.Status;
import io.grpc.Status.Code;
+import io.grpc.StatusException;
+import io.grpc.StatusRuntimeException;
/**
* Utilities to adapt gRPC and Flight status objects.
@@ -117,4 +120,27 @@ public static CallStatus fromGrpcStatus(Status status) {
public static Status toGrpcStatus(CallStatus status) {
return toGrpcStatusCode(status.code()).toStatus().withDescription(status.description()).withCause(status.cause());
}
+
+ public static FlightRuntimeException fromGrpcRuntimeException(StatusRuntimeException sre) {
+ return fromGrpcStatus(sre.getStatus()).toRuntimeException();
+ }
+
+ /**
+ * Convert arbitrary exceptions to a {@link StatusRuntimeException} or {@link StatusException}.
+ *
+ * Such exceptions can be passed to {@link io.grpc.stub.StreamObserver#onError(Throwable)} and will give the client
+ * a reasonable error message.
+ */
+ public static Throwable toGrpcException(Throwable ex) {
+ if (ex instanceof StatusRuntimeException) {
+ return ex;
+ } else if (ex instanceof StatusException) {
+ return ex;
+ } else if (ex instanceof FlightRuntimeException) {
+ final FlightRuntimeException fre = (FlightRuntimeException) ex;
+ return toGrpcStatus(fre.status()).asRuntimeException();
+ }
+ return Status.INTERNAL.withCause(ex).withDescription("There was an error servicing your request.")
+ .asRuntimeException();
+ }
}
From 7db0e541c790ae4971f0626af478859e8fcfbcb3 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 21 Jun 2019 12:21:17 -0400
Subject: [PATCH 3/3] Try to wrap gRPC exceptions everywhere
---
.../apache/arrow/flight/AsyncPutListener.java | 4 +-
.../org/apache/arrow/flight/CallStatus.java | 14 ++++
.../org/apache/arrow/flight/FlightClient.java | 70 +++++++++----------
.../arrow/flight/NoOpFlightProducer.java | 12 ++--
.../apache/arrow/flight/SyncPutListener.java | 4 +-
.../arrow/flight/auth/ClientAuthWrapper.java | 25 ++++---
.../flight/auth/ServerAuthInterceptor.java | 2 +-
.../arrow/flight/auth/ServerAuthWrapper.java | 4 +-
.../arrow/flight/example/InMemoryStore.java | 3 +-
.../apache/arrow/flight/grpc/StatusUtils.java | 46 ++++++++++++
.../arrow/flight/TestApplicationMetadata.java | 5 +-
.../arrow/flight/TestBasicOperation.java | 13 +++-
.../java/org/apache/arrow/flight/TestTls.java | 19 +++--
.../apache/arrow/flight/auth/TestAuth.java | 25 +++----
14 files changed, 168 insertions(+), 78 deletions(-)
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
index c8214e31953..c2182fc4e68 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
@@ -20,6 +20,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
/**
* A handler for server-sent application metadata messages during a Flight DoPut operation.
*
@@ -53,7 +55,7 @@ public void onNext(PutResult val) {
@Override
public final void onError(Throwable t) {
- completed.completeExceptionally(t);
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java b/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java
index 3de8b8e63db..10098ccdc6c 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java
@@ -72,6 +72,20 @@ public String description() {
return description;
}
+ /**
+ * Return a copy of this status with an error message.
+ */
+ public CallStatus withDescription(String message) {
+ return new CallStatus(code, cause, message);
+ }
+
+ /**
+ * Return a copy of this status with the given exception as the cause. This will not be sent over the wire.
+ */
+ public CallStatus withCause(Throwable t) {
+ return new CallStatus(code, t, description);
+ }
+
/**
* Convert the status to an equivalent exception.
*/
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 7a98a16d22c..4caaac1e5ca 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -21,7 +21,6 @@
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
@@ -44,8 +43,6 @@
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterators;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
@@ -105,18 +102,15 @@ public Iterable listFlights(Criteria criteria, CallOption... options
} catch (StatusRuntimeException sre) {
throw StatusUtils.fromGrpcRuntimeException(sre);
}
- return ImmutableList.copyOf(flights)
- .stream()
- .map(t -> {
- try {
- return new FlightInfo(t);
- } catch (URISyntaxException e) {
- // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
- // itself wouldn't be able to construct an invalid Location.
- throw new RuntimeException(e);
- }
- })
- .collect(Collectors.toList());
+ return () -> StatusUtils.wrapIterator(flights, t -> {
+ try {
+ return new FlightInfo(t);
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ }
+ });
}
/**
@@ -132,10 +126,7 @@ public Iterable listActions(CallOption... options) {
} catch (StatusRuntimeException sre) {
throw StatusUtils.fromGrpcRuntimeException(sre);
}
- return ImmutableList.copyOf(actions)
- .stream()
- .map(ActionType::new)
- .collect(Collectors.toList());
+ return () -> StatusUtils.wrapIterator(actions, ActionType::new);
}
/**
@@ -146,9 +137,8 @@ public Iterable listActions(CallOption... options) {
* @return An iterator of results.
*/
public Iterator doAction(Action action, CallOption... options) {
- // TODO: need to wrap all methods to catch exceptions
- return Iterators
- .transform(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
+ return StatusUtils
+ .wrapIterator(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
}
/**
@@ -199,16 +189,20 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo
Preconditions.checkNotNull(descriptor);
Preconditions.checkNotNull(root);
- SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
- final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
- ClientCallStreamObserver observer = (ClientCallStreamObserver)
- ClientCalls.asyncBidiStreamingCall(
- authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
- // send the schema to start.
- DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
- return new PutObserver(new VectorUnloader(
- root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
- observer, metadataListener);
+ try {
+ SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
+ ClientCallStreamObserver observer = (ClientCallStreamObserver)
+ ClientCalls.asyncBidiStreamingCall(
+ authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
+ // send the schema to start.
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
+ return new PutObserver(new VectorUnloader(
+ root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
+ observer, metadataListener);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
/**
@@ -223,6 +217,8 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) {
// We don't expect this will happen for conforming Flight implementations. For instance, a Java server
// itself wouldn't be able to construct an invalid Location.
throw new RuntimeException(e);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
}
}
@@ -290,7 +286,7 @@ public void onNext(Flight.PutResult value) {
@Override
public void onError(Throwable t) {
- listener.onError(t);
+ listener.onError(StatusUtils.fromThrowable(t));
}
@Override
@@ -323,8 +319,12 @@ public void putNext(ArrowBuf appMetadata) {
while (!observer.isReady()) {
/* busy wait */
}
- // Takes ownership of appMetadata
- observer.onNext(new ArrowMessage(batch, appMetadata));
+ try {
+ // Takes ownership of appMetadata
+ observer.onNext(new ArrowMessage(batch, appMetadata));
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
index eca32e1c679..d1432f514d8 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
@@ -25,37 +25,37 @@ public class NoOpFlightProducer implements FlightProducer {
@Override
public void getStream(CallContext context, Ticket ticket,
ServerStreamListener listener) {
- listener.error(new UnsupportedOperationException("NYI"));
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public void listFlights(CallContext context, Criteria criteria,
StreamListener listener) {
- listener.onError(new UnsupportedOperationException("NYI"));
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
- throw new UnsupportedOperationException("NYI");
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
}
@Override
public Runnable acceptPut(CallContext context,
FlightStream flightStream, StreamListener ackStream) {
- throw new UnsupportedOperationException("NYI");
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
}
@Override
public void doAction(CallContext context, Action action,
StreamListener listener) {
- throw new UnsupportedOperationException("NYI");
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public void listActions(CallContext context,
StreamListener listener) {
- listener.onError(new UnsupportedOperationException("NYI"));
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
index f1246a1d079..690e7742eac 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
@@ -22,6 +22,8 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
import io.netty.buffer.ArrowBuf;
/**
@@ -93,7 +95,7 @@ public void onNext(PutResult val) {
@Override
public void onError(Throwable t) {
- completed.completeExceptionally(t);
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
queue.add(DONE_WITH_EXCEPTION);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
index 42974cac81a..ed133fd3f51 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
@@ -29,6 +29,7 @@
import com.google.common.base.Throwables;
import com.google.protobuf.ByteString;
+import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
/**
@@ -43,10 +44,14 @@ public class ClientAuthWrapper {
* @param stub The service stub.
*/
public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub stub) {
- AuthObserver observer = new AuthObserver();
- observer.responseObserver = stub.handshake(observer);
- authHandler.authenticate(observer.sender, observer.iter);
- observer.responseObserver.onCompleted();
+ try {
+ AuthObserver observer = new AuthObserver();
+ observer.responseObserver = stub.handshake(observer);
+ authHandler.authenticate(observer.sender, observer.iter);
+ observer.responseObserver.onCompleted();
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
private static class AuthObserver implements StreamObserver {
@@ -98,16 +103,20 @@ public boolean hasNext() {
@Override
public void onError(Throwable t) {
- ex = t;
+ ex = StatusUtils.fromThrowable(t);
}
private class AuthSender implements ClientAuthSender {
@Override
public void send(byte[] payload) {
- responseObserver.onNext(HandshakeRequest.newBuilder()
- .setPayload(ByteString.copyFrom(payload))
- .build());
+ try {
+ responseObserver.onNext(HandshakeRequest.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .build());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
index f38dee74ce7..4ebd7424cb8 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
@@ -45,7 +45,7 @@ public Listener interceptCall(ServerCall call,
if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
final Optional peerIdentity = isValid(headers);
if (!peerIdentity.isPresent()) {
- call.close(Status.PERMISSION_DENIED, new Metadata());
+ call.close(Status.UNAUTHENTICATED, new Metadata());
// TODO: we should actually terminate here instead of causing an exception below.
return new NoopServerCallListener<>();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
index 0e7e43fafad..d91027b5b8b 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
@@ -22,6 +22,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
+import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.auth.ServerAuthHandler.ServerAuthSender;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
@@ -29,7 +30,6 @@
import com.google.protobuf.ByteString;
-import io.grpc.Status;
import io.grpc.stub.StreamObserver;
/**
@@ -57,7 +57,7 @@ public static StreamObserver wrapHandshake(ServerAuthHandler a
return;
}
- responseObserver.onError(Status.PERMISSION_DENIED.asException());
+ responseObserver.onError(StatusUtils.toGrpcException(CallStatus.UNAUTHENTICATED.toRuntimeException()));
} catch (Exception ex) {
responseObserver.onError(StatusUtils.toGrpcException(ex));
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
index 59324b30397..5508399685d 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
@@ -22,6 +22,7 @@
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
+import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
@@ -143,7 +144,7 @@ public void doAction(CallContext context, Action action,
break;
}
default: {
- listener.onError(new UnsupportedOperationException());
+ listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException());
}
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
index 84576a53c86..ce7b7f5a497 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
@@ -17,6 +17,10 @@
package org.apache.arrow.flight.grpc;
+import java.util.Iterator;
+import java.util.Objects;
+import java.util.function.Function;
+
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
@@ -121,10 +125,23 @@ public static Status toGrpcStatus(CallStatus status) {
return toGrpcStatusCode(status.code()).toStatus().withDescription(status.description()).withCause(status.cause());
}
+ /** Convert from a gRPC exception to a Flight exception. */
public static FlightRuntimeException fromGrpcRuntimeException(StatusRuntimeException sre) {
return fromGrpcStatus(sre.getStatus()).toRuntimeException();
}
+ /**
+ * Convert arbitrary exceptions to a {@link FlightRuntimeException}.
+ */
+ public static FlightRuntimeException fromThrowable(Throwable t) {
+ if (t instanceof StatusRuntimeException) {
+ return fromGrpcRuntimeException((StatusRuntimeException) t);
+ } else if (t instanceof FlightRuntimeException) {
+ return (FlightRuntimeException) t;
+ }
+ return CallStatus.UNKNOWN.withCause(t).withDescription(t.getMessage()).toRuntimeException();
+ }
+
/**
* Convert arbitrary exceptions to a {@link StatusRuntimeException} or {@link StatusException}.
*
@@ -143,4 +160,33 @@ public static Throwable toGrpcException(Throwable ex) {
return Status.INTERNAL.withCause(ex).withDescription("There was an error servicing your request.")
.asRuntimeException();
}
+
+ /**
+ * Maps a transformation function to the elements of an iterator, while wrapping exceptions in {@link
+ * FlightRuntimeException}.
+ */
+ public static Iterator wrapIterator(Iterator fromIterator,
+ Function super FROM, ? extends TO> transformer) {
+ Objects.requireNonNull(fromIterator);
+ Objects.requireNonNull(transformer);
+ return new Iterator() {
+ @Override
+ public boolean hasNext() {
+ try {
+ return fromIterator.hasNext();
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+
+ @Override
+ public TO next() {
+ try {
+ return transformer.apply(fromIterator.next());
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+ };
+ }
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
index ad2c58f3b78..e19bff07b25 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
@@ -36,7 +36,6 @@
import org.junit.Ignore;
import org.junit.Test;
-import io.grpc.Status;
import io.netty.buffer.ArrowBuf;
/**
@@ -225,9 +224,9 @@ public Runnable acceptPut(CallContext context, FlightStream stream, StreamListen
while (stream.next()) {
final ArrowBuf metadata = stream.getLatestMetadata();
if (current != metadata.getByte(0)) {
- ackStream.onError(Status.INVALID_ARGUMENT.withDescription(String
+ ackStream.onError(CallStatus.INVALID_ARGUMENT.withDescription(String
.format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0),
- current)).asRuntimeException());
+ current)).toRuntimeException());
return;
}
ackStream.onNext(PutResult.metadata(metadata));
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index abc5a2c321d..4a69d3fb6fe 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -29,8 +29,10 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+
import org.junit.Assert;
import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
import com.google.common.base.Charsets;
import com.google.protobuf.ByteString;
@@ -125,6 +127,14 @@ public void putStream() throws Exception {
});
}
+ @Test
+ public void propagateErrors() throws Exception {
+ test(client -> {
+ final FlightRuntimeException ex = Assertions.assertThrows(FlightRuntimeException.class,
+ () -> client.doAction(new Action("invalid-action")).forEachRemaining(action -> Assert.fail()));
+ Assert.assertEquals(FlightStatusCode.UNIMPLEMENTED, ex.status().code());
+ });
+ }
@Test
public void getStream() throws Exception {
@@ -274,7 +284,8 @@ public void doAction(CallContext context, Action action,
break;
}
default:
- listener.onError(new UnsupportedOperationException());
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Action not implemented: " + action.getType())
+ .toRuntimeException());
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
index b9d4dea5572..8d6e91522b2 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -30,6 +30,7 @@
import org.junit.Assert;
import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
/**
* Tests for TLS in Flight.
@@ -57,13 +58,14 @@ public void connectTls() {
/**
* Make sure that connections are rejected when the root certificate isn't trusted.
*/
- @Test(expected = io.grpc.StatusRuntimeException.class)
+ @Test
public void rejectInvalidCert() {
test((builder) -> {
try (final FlightClient client = builder.build()) {
final Iterator responses = client.doAction(new Action("hello-world"));
- responses.next().getBody();
- Assert.fail("Call should have failed");
+ final FlightRuntimeException ex = Assertions
+ .assertThrows(FlightRuntimeException.class, () -> responses.next().getBody());
+ Assert.assertEquals(FlightStatusCode.UNAVAILABLE, ex.status().code());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -73,15 +75,16 @@ public void rejectInvalidCert() {
/**
* Make sure that connections are rejected when the hostname doesn't match.
*/
- @Test(expected = io.grpc.StatusRuntimeException.class)
+ @Test
public void rejectHostname() {
test((builder) -> {
try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
.build()) {
final Iterator responses = client.doAction(new Action("hello-world"));
- responses.next().getBody();
- Assert.fail("Call should have failed");
+ final FlightRuntimeException ex = Assertions
+ .assertThrows(FlightRuntimeException.class, () -> responses.next().getBody());
+ Assert.assertEquals(FlightStatusCode.UNAVAILABLE, ex.status().code());
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
@@ -119,8 +122,10 @@ public void doAction(CallContext context, Action action, StreamListener
if (action.getType().equals("hello-world")) {
listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
listener.onCompleted();
+ return;
}
- listener.onError(new UnsupportedOperationException("Invalid action " + action.getType()));
+ listener
+ .onError(CallStatus.UNIMPLEMENTED.withDescription("Invalid action " + action.getType()).toRuntimeException());
}
@Override
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
index 54bbadb0369..2aaeefcb91d 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
@@ -26,7 +26,9 @@
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
@@ -46,11 +48,7 @@
import com.google.common.collect.ImmutableList;
-import io.grpc.StatusRuntimeException;
-
public class TestAuth {
- final String PERMISSION_DENIED = "PERMISSION_DENIED";
-
private static final String USERNAME = "flight";
private static final String PASSWORD = "woohoo";
private static final byte[] VALID_TOKEN = "my_token".getBytes();
@@ -79,20 +77,23 @@ public void asyncCall() {
@Test
public void invalidAuth() {
- assertThrows(StatusRuntimeException.class, () -> {
+ FlightRuntimeException ex = assertThrows(FlightRuntimeException.class, () -> {
client.authenticateBasic(USERNAME, "WRONG");
- }, PERMISSION_DENIED);
+ });
+ Assert.assertEquals(FlightStatusCode.UNAUTHENTICATED, ex.status().code());
- assertThrows(StatusRuntimeException.class, () -> {
- client.listFlights(Criteria.ALL);
- }, PERMISSION_DENIED);
+ ex = assertThrows(FlightRuntimeException.class, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
+ Assert.assertEquals(FlightStatusCode.UNAUTHENTICATED, ex.status().code());
}
@Test
public void didntAuth() {
- assertThrows(StatusRuntimeException.class, () -> {
- client.listFlights(Criteria.ALL);
- }, PERMISSION_DENIED);
+ FlightRuntimeException ex = assertThrows(FlightRuntimeException.class, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
+ Assert.assertEquals(FlightStatusCode.UNAUTHENTICATED, ex.status().code());
}
@Before