diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index c508dca769f..cb4b92d7cb9 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -430,7 +430,12 @@ class FlightClient::FlightClientImpl { GrpcClientAuthSender outgoing{stream}; GrpcClientAuthReader incoming{stream}; RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming)); + // Explicitly close our side of the connection + bool finished_writes = stream->WritesDone(); RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish())); + if (!finished_writes) { + return Status::UnknownError("Could not finish writing before closing"); + } return Status::OK(); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java index f916c9217d0..9b8034003cd 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java @@ -18,6 +18,9 @@ package org.apache.arrow.flight.auth; import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender; @@ -25,9 +28,9 @@ import org.apache.arrow.flight.impl.Flight.HandshakeResponse; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; -import com.google.common.base.Throwables; import com.google.protobuf.ByteString; +import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; /** @@ -45,7 +48,17 @@ public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub AuthObserver observer = new AuthObserver(); observer.responseObserver = stub.handshake(observer); authHandler.authenticate(observer.sender, observer.iter); - observer.responseObserver.onCompleted(); + if (!observer.sender.errored) { + observer.responseObserver.onCompleted(); + } + try { + if (!observer.completed.get()) { + // TODO: ARROW-5681 + throw new RuntimeException("Unauthenticated"); + } + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } } private static class AuthObserver implements StreamObserver { @@ -53,11 +66,11 @@ private static class AuthObserver implements StreamObserver { private volatile StreamObserver responseObserver; private final LinkedBlockingQueue messages = new LinkedBlockingQueue<>(); private final AuthSender sender = new AuthSender(); - private volatile boolean completed = false; - private Throwable ex = null; + private CompletableFuture completed; public AuthObserver() { super(); + completed = new CompletableFuture<>(); } @Override @@ -72,7 +85,7 @@ public void onNext(HandshakeResponse value) { @Override public byte[] next() { - while (ex == null && (!completed || !messages.isEmpty())) { + while (!completed.isDone() || !messages.isEmpty()) { byte[] bytes = messages.poll(); if (bytes == null) { // busy wait. @@ -82,8 +95,19 @@ public byte[] next() { } } - if (ex != null) { - throw Throwables.propagate(ex); + if (completed.isCompletedExceptionally()) { + // Preserve prior exception behavior + // TODO: with ARROW-5681, throw an appropriate Flight exception if gRPC raised an exception + try { + completed.get(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + if (e.getCause() instanceof StatusRuntimeException) { + throw (StatusRuntimeException) e.getCause(); + } + throw new RuntimeException(e); + } } throw new IllegalStateException("You attempted to retrieve messages after there were none."); @@ -97,11 +121,13 @@ public boolean hasNext() { @Override public void onError(Throwable t) { - ex = t; + completed.completeExceptionally(t); } private class AuthSender implements ClientAuthSender { + private boolean errored = false; + @Override public void send(byte[] payload) { responseObserver.onNext(HandshakeRequest.newBuilder() @@ -111,6 +137,8 @@ public void send(byte[] payload) { @Override public void onError(String message, Throwable cause) { + this.errored = true; + Objects.requireNonNull(cause); responseObserver.onError(cause); } @@ -118,7 +146,7 @@ public void onError(String message, Throwable cause) { @Override public void onCompleted() { - completed = true; + completed.complete(true); } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java index a19126b6ae9..0507d3b72fd 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java @@ -36,6 +36,8 @@ public interface ServerAuthHandler { /** * Handle the initial handshake with the client. * + * @param outgoing A writer to send messages to the client. + * @param incoming An iterator of messages from the client. * @return true if client is authenticated, false otherwise. */ boolean authenticate(ServerAuthSender outgoing, Iterator incoming); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java index f0c5dae757a..a3c698b53bf 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java @@ -58,6 +58,7 @@ public static StreamObserver wrapHandshake(ServerAuthHandler a responseObserver.onError(Status.PERMISSION_DENIED.asException()); } catch (Exception ex) { + ex.printStackTrace(); responseObserver.onError(ex); } }; @@ -109,6 +110,7 @@ public boolean hasNext() { @Override public void onError(Throwable t) { + completed = true; while (future == null) {/* busy wait */} future.cancel(true); } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java new file mode 100644 index 00000000000..bfaf660b26b --- /dev/null +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Iterator; +import java.util.Optional; + +import org.apache.arrow.flight.auth.ClientAuthHandler; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +import org.junit.Test; + +public class TestAuth { + + /** An auth handler that does not send messages should not block the server forever. */ + @Test(expected = RuntimeException.class) + public void noMessages() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator incoming) { + } + + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + } + + /** An auth handler that sends an error should not block the server forever. */ + @Test(expected = RuntimeException.class) + public void clientError() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator incoming) { + outgoing.send(new byte[0]); + // Ensure the server-side runs + incoming.next(); + outgoing.onError("test", new RuntimeException("test")); + } + + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + } + + private static class OneshotAuthHandler implements ServerAuthHandler { + + @Override + public Optional isValid(byte[] token) { + return Optional.of("test"); + } + + @Override + public boolean authenticate(ServerAuthSender outgoing, Iterator incoming) { + incoming.next(); + outgoing.send(new byte[0]); + return false; + } + } +} diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java similarity index 99% rename from java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java rename to java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java index 54bbadb0369..9fe6b04140c 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java @@ -48,7 +48,7 @@ import io.grpc.StatusRuntimeException; -public class TestAuth { +public class TestBasicAuth { final String PERMISSION_DENIED = "PERMISSION_DENIED"; private static final String USERNAME = "flight";