From fc35d190c0a9e340c6cc455b5da82db2ceef5738 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 9 Jul 2019 14:01:19 -0400
Subject: [PATCH] Wait for authentication to complete server-side
---
cpp/src/arrow/flight/client.cc | 5 +
.../arrow/flight/auth/ClientAuthWrapper.java | 46 +++++++--
.../arrow/flight/auth/ServerAuthHandler.java | 2 +
.../arrow/flight/auth/ServerAuthWrapper.java | 2 +
.../org/apache/arrow/flight/TestAuth.java | 94 +++++++++++++++++++
.../{TestAuth.java => TestBasicAuth.java} | 2 +-
6 files changed, 141 insertions(+), 10 deletions(-)
create mode 100644 java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
rename java/flight/src/test/java/org/apache/arrow/flight/auth/{TestAuth.java => TestBasicAuth.java} (99%)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index c508dca769f..cb4b92d7cb9 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -430,7 +430,12 @@ class FlightClient::FlightClientImpl {
GrpcClientAuthSender outgoing{stream};
GrpcClientAuthReader incoming{stream};
RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
+ // Explicitly close our side of the connection
+ bool finished_writes = stream->WritesDone();
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
+ if (!finished_writes) {
+ return Status::UnknownError("Could not finish writing before closing");
+ }
return Status::OK();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
index f916c9217d0..9b8034003cd 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
@@ -18,6 +18,9 @@
package org.apache.arrow.flight.auth;
import java.util.Iterator;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender;
@@ -25,9 +28,9 @@
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
-import com.google.common.base.Throwables;
import com.google.protobuf.ByteString;
+import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
/**
@@ -45,7 +48,17 @@ public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub
AuthObserver observer = new AuthObserver();
observer.responseObserver = stub.handshake(observer);
authHandler.authenticate(observer.sender, observer.iter);
- observer.responseObserver.onCompleted();
+ if (!observer.sender.errored) {
+ observer.responseObserver.onCompleted();
+ }
+ try {
+ if (!observer.completed.get()) {
+ // TODO: ARROW-5681
+ throw new RuntimeException("Unauthenticated");
+ }
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
}
private static class AuthObserver implements StreamObserver {
@@ -53,11 +66,11 @@ private static class AuthObserver implements StreamObserver {
private volatile StreamObserver responseObserver;
private final LinkedBlockingQueue messages = new LinkedBlockingQueue<>();
private final AuthSender sender = new AuthSender();
- private volatile boolean completed = false;
- private Throwable ex = null;
+ private CompletableFuture completed;
public AuthObserver() {
super();
+ completed = new CompletableFuture<>();
}
@Override
@@ -72,7 +85,7 @@ public void onNext(HandshakeResponse value) {
@Override
public byte[] next() {
- while (ex == null && (!completed || !messages.isEmpty())) {
+ while (!completed.isDone() || !messages.isEmpty()) {
byte[] bytes = messages.poll();
if (bytes == null) {
// busy wait.
@@ -82,8 +95,19 @@ public byte[] next() {
}
}
- if (ex != null) {
- throw Throwables.propagate(ex);
+ if (completed.isCompletedExceptionally()) {
+ // Preserve prior exception behavior
+ // TODO: with ARROW-5681, throw an appropriate Flight exception if gRPC raised an exception
+ try {
+ completed.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ if (e.getCause() instanceof StatusRuntimeException) {
+ throw (StatusRuntimeException) e.getCause();
+ }
+ throw new RuntimeException(e);
+ }
}
throw new IllegalStateException("You attempted to retrieve messages after there were none.");
@@ -97,11 +121,13 @@ public boolean hasNext() {
@Override
public void onError(Throwable t) {
- ex = t;
+ completed.completeExceptionally(t);
}
private class AuthSender implements ClientAuthSender {
+ private boolean errored = false;
+
@Override
public void send(byte[] payload) {
responseObserver.onNext(HandshakeRequest.newBuilder()
@@ -111,6 +137,8 @@ public void send(byte[] payload) {
@Override
public void onError(String message, Throwable cause) {
+ this.errored = true;
+ Objects.requireNonNull(cause);
responseObserver.onError(cause);
}
@@ -118,7 +146,7 @@ public void onError(String message, Throwable cause) {
@Override
public void onCompleted() {
- completed = true;
+ completed.complete(true);
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
index a19126b6ae9..0507d3b72fd 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
@@ -36,6 +36,8 @@ public interface ServerAuthHandler {
/**
* Handle the initial handshake with the client.
*
+ * @param outgoing A writer to send messages to the client.
+ * @param incoming An iterator of messages from the client.
* @return true if client is authenticated, false otherwise.
*/
boolean authenticate(ServerAuthSender outgoing, Iterator incoming);
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
index f0c5dae757a..a3c698b53bf 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
@@ -58,6 +58,7 @@ public static StreamObserver wrapHandshake(ServerAuthHandler a
responseObserver.onError(Status.PERMISSION_DENIED.asException());
} catch (Exception ex) {
+ ex.printStackTrace();
responseObserver.onError(ex);
}
};
@@ -109,6 +110,7 @@ public boolean hasNext() {
@Override
public void onError(Throwable t) {
+ completed = true;
while (future == null) {/* busy wait */}
future.cancel(true);
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
new file mode 100644
index 00000000000..bfaf660b26b
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import org.apache.arrow.flight.auth.ClientAuthHandler;
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import org.junit.Test;
+
+public class TestAuth {
+
+ /** An auth handler that does not send messages should not block the server forever. */
+ @Test(expected = RuntimeException.class)
+ public void noMessages() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightServer s = FlightTestUtil
+ .getStartedServer(
+ location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
+ new OneshotAuthHandler()).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ client.authenticate(new ClientAuthHandler() {
+ @Override
+ public void authenticate(ClientAuthSender outgoing, Iterator incoming) {
+ }
+
+ @Override
+ public byte[] getCallToken() {
+ return new byte[0];
+ }
+ });
+ }
+ }
+
+ /** An auth handler that sends an error should not block the server forever. */
+ @Test(expected = RuntimeException.class)
+ public void clientError() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightServer s = FlightTestUtil
+ .getStartedServer(
+ location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
+ new OneshotAuthHandler()).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ client.authenticate(new ClientAuthHandler() {
+ @Override
+ public void authenticate(ClientAuthSender outgoing, Iterator incoming) {
+ outgoing.send(new byte[0]);
+ // Ensure the server-side runs
+ incoming.next();
+ outgoing.onError("test", new RuntimeException("test"));
+ }
+
+ @Override
+ public byte[] getCallToken() {
+ return new byte[0];
+ }
+ });
+ }
+ }
+
+ private static class OneshotAuthHandler implements ServerAuthHandler {
+
+ @Override
+ public Optional isValid(byte[] token) {
+ return Optional.of("test");
+ }
+
+ @Override
+ public boolean authenticate(ServerAuthSender outgoing, Iterator incoming) {
+ incoming.next();
+ outgoing.send(new byte[0]);
+ return false;
+ }
+ }
+}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
similarity index 99%
rename from java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
rename to java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
index 54bbadb0369..9fe6b04140c 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
@@ -48,7 +48,7 @@
import io.grpc.StatusRuntimeException;
-public class TestAuth {
+public class TestBasicAuth {
final String PERMISSION_DENIED = "PERMISSION_DENIED";
private static final String USERNAME = "flight";