Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,12 @@ class FlightClient::FlightClientImpl {
GrpcClientAuthSender outgoing{stream};
GrpcClientAuthReader incoming{stream};
RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
// Explicitly close our side of the connection
bool finished_writes = stream->WritesDone();
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
if (!finished_writes) {
return Status::UnknownError("Could not finish writing before closing");
}
return Status::OK();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
package org.apache.arrow.flight.auth;

import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender;
import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;

import com.google.common.base.Throwables;
import com.google.protobuf.ByteString;

import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;

/**
Expand All @@ -45,19 +48,29 @@ public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub
AuthObserver observer = new AuthObserver();
observer.responseObserver = stub.handshake(observer);
authHandler.authenticate(observer.sender, observer.iter);
observer.responseObserver.onCompleted();
if (!observer.sender.errored) {
observer.responseObserver.onCompleted();
}
try {
if (!observer.completed.get()) {
// TODO: ARROW-5681
throw new RuntimeException("Unauthenticated");
}
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

private static class AuthObserver implements StreamObserver<HandshakeResponse> {

private volatile StreamObserver<HandshakeRequest> responseObserver;
private final LinkedBlockingQueue<byte[]> messages = new LinkedBlockingQueue<>();
private final AuthSender sender = new AuthSender();
private volatile boolean completed = false;
private Throwable ex = null;
private CompletableFuture<Boolean> completed;

public AuthObserver() {
super();
completed = new CompletableFuture<>();
}

@Override
Expand All @@ -72,7 +85,7 @@ public void onNext(HandshakeResponse value) {

@Override
public byte[] next() {
while (ex == null && (!completed || !messages.isEmpty())) {
while (!completed.isDone() || !messages.isEmpty()) {
byte[] bytes = messages.poll();
if (bytes == null) {
// busy wait.
Expand All @@ -82,8 +95,19 @@ public byte[] next() {
}
}

if (ex != null) {
throw Throwables.propagate(ex);
if (completed.isCompletedExceptionally()) {
// Preserve prior exception behavior
// TODO: with ARROW-5681, throw an appropriate Flight exception if gRPC raised an exception
try {
completed.get();
} catch (InterruptedException e) {
throw new RuntimeException(e);
} catch (ExecutionException e) {
if (e.getCause() instanceof StatusRuntimeException) {
throw (StatusRuntimeException) e.getCause();
}
throw new RuntimeException(e);
}
}

throw new IllegalStateException("You attempted to retrieve messages after there were none.");
Expand All @@ -97,11 +121,13 @@ public boolean hasNext() {

@Override
public void onError(Throwable t) {
ex = t;
completed.completeExceptionally(t);
}

private class AuthSender implements ClientAuthSender {

private boolean errored = false;

@Override
public void send(byte[] payload) {
responseObserver.onNext(HandshakeRequest.newBuilder()
Expand All @@ -111,14 +137,16 @@ public void send(byte[] payload) {

@Override
public void onError(String message, Throwable cause) {
this.errored = true;
Objects.requireNonNull(cause);
responseObserver.onError(cause);
}

}

@Override
public void onCompleted() {
completed = true;
completed.complete(true);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public interface ServerAuthHandler {
/**
* Handle the initial handshake with the client.
*
* @param outgoing A writer to send messages to the client.
* @param incoming An iterator of messages from the client.
* @return true if client is authenticated, false otherwise.
*/
boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public static StreamObserver<HandshakeRequest> wrapHandshake(ServerAuthHandler a

responseObserver.onError(Status.PERMISSION_DENIED.asException());
} catch (Exception ex) {
ex.printStackTrace();
responseObserver.onError(ex);
}
};
Expand Down Expand Up @@ -109,6 +110,7 @@ public boolean hasNext() {

@Override
public void onError(Throwable t) {
completed = true;
while (future == null) {/* busy wait */}
future.cancel(true);
}
Expand Down
94 changes: 94 additions & 0 deletions java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

import java.util.Iterator;
import java.util.Optional;

import org.apache.arrow.flight.auth.ClientAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

import org.junit.Test;

public class TestAuth {

/** An auth handler that does not send messages should not block the server forever. */
@Test(expected = RuntimeException.class)
public void noMessages() throws Exception {
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final FlightServer s = FlightTestUtil
.getStartedServer(
location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
new OneshotAuthHandler()).build());
final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
client.authenticate(new ClientAuthHandler() {
@Override
public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) {
}

@Override
public byte[] getCallToken() {
return new byte[0];
}
});
}
}

/** An auth handler that sends an error should not block the server forever. */
@Test(expected = RuntimeException.class)
public void clientError() throws Exception {
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final FlightServer s = FlightTestUtil
.getStartedServer(
location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
new OneshotAuthHandler()).build());
final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
client.authenticate(new ClientAuthHandler() {
@Override
public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) {
outgoing.send(new byte[0]);
// Ensure the server-side runs
incoming.next();
outgoing.onError("test", new RuntimeException("test"));
}

@Override
public byte[] getCallToken() {
return new byte[0];
}
});
}
}

private static class OneshotAuthHandler implements ServerAuthHandler {

@Override
public Optional<String> isValid(byte[] token) {
return Optional.of("test");
}

@Override
public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
incoming.next();
outgoing.send(new byte[0]);
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

import io.grpc.StatusRuntimeException;

public class TestAuth {
public class TestBasicAuth {
final String PERMISSION_DENIED = "PERMISSION_DENIED";

private static final String USERNAME = "flight";
Expand Down