From 581fc75826a3841c7ef9edd514ca8e54d8478052 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 18 Jun 2019 09:27:58 -0400
Subject: [PATCH] Add ability to override SSL hostname checking
---
cpp/src/arrow/flight/client.cc | 5 +
cpp/src/arrow/flight/client.h | 4 +
cpp/src/arrow/flight/flight-test.cc | 18 +++
java/flight/pom.xml | 15 +-
.../org/apache/arrow/flight/FlightClient.java | 11 ++
.../org/apache/arrow/flight/FlightServer.java | 17 ++-
.../apache/arrow/flight/FlightTestUtil.java | 44 ++++++
.../java/org/apache/arrow/flight/TestTls.java | 130 ++++++++++++++++++
python/pyarrow/_flight.pyx | 8 +-
python/pyarrow/includes/libarrow_flight.pxd | 1 +
python/pyarrow/tests/test_flight.py | 15 ++
testing | 2 +-
12 files changed, 259 insertions(+), 11 deletions(-)
create mode 100644 java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 2b7c6991976..1926928c643 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -259,6 +259,11 @@ class FlightClient::FlightClientImpl {
args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 100);
// Receive messages of any size
args.SetMaxReceiveMessageSize(-1);
+
+ if (options.override_hostname != "") {
+ args.SetSslTargetNameOverride(options.override_hostname);
+ }
+
stub_ = pb::FlightService::NewStub(
grpc::CreateCustomChannel(grpc_uri.str(), creds, args));
return Status::OK();
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 689c9f8c5b5..b8a5d4f4b91 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -59,7 +59,11 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {
class ARROW_FLIGHT_EXPORT FlightClientOptions {
public:
+ /// \brief Root certificates to use for validating server
+ /// certificates.
std::string tls_root_certs;
+ /// \brief Override the hostname checked by TLS. Use with caution.
+ std::string override_hostname;
};
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index b2958786415..3c0b67cd992 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -675,5 +675,23 @@ TEST_F(TestTls, DoAction) {
ASSERT_EQ(result->body->ToString(), "Hello, world!");
}
+TEST_F(TestTls, OverrideHostname) {
+ std::unique_ptr client;
+ auto client_options = FlightClientOptions();
+ client_options.override_hostname = "fakehostname";
+ CertKeyPair root_cert;
+ ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
+ client_options.tls_root_certs = root_cert.pem_cert;
+ ASSERT_OK(FlightClient::Connect(server_->location(), client_options, &client));
+
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr results;
+ ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+}
+
} // namespace flight
} // namespace arrow
diff --git a/java/flight/pom.xml b/java/flight/pom.xml
index 3745207c998..b41ce2413b5 100644
--- a/java/flight/pom.xml
+++ b/java/flight/pom.xml
@@ -1,10 +1,10 @@
-
4.0.0
@@ -137,6 +137,9 @@
maven-surefire-plugin
false
+
+ ${project.basedir}/../../testing/data
+
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 221423e8d9a..1e44b3068e5 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -354,6 +354,7 @@ public static final class Builder {
private InputStream trustedCertificates = null;
private InputStream clientCertificate = null;
private InputStream clientKey = null;
+ private String overrideHostname = null;
private Builder() {
}
@@ -371,6 +372,12 @@ public Builder useTls() {
return this;
}
+ /** Override the hostname checked for TLS. Use with caution in production. */
+ public Builder overrideHostname(final String hostname) {
+ this.overrideHostname = hostname;
+ return this;
+ }
+
/** Set the maximum inbound message size. */
public Builder maxInboundMessageSize(int maxSize) {
Preconditions.checkArgument(maxSize > 0);
@@ -461,6 +468,10 @@ public FlightClient build() {
throw new RuntimeException(e);
}
}
+
+ if (this.overrideHostname != null) {
+ builder.overrideAuthority(this.overrideHostname);
+ }
} else {
builder.usePlaintext();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index eaea0441b2c..cd59a75cbbb 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -72,10 +72,23 @@ public void awaitTermination() throws InterruptedException {
server.awaitTermination();
}
+ /** Request that the server shut down. */
+ public void shutdown() {
+ server.shutdown();
+ }
+
+ /**
+ * Wait for the server to shut down with a timeout.
+ * @return true if the server shut down successfully.
+ */
+ public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
+ return server.awaitTermination(timeout, unit);
+ }
+
/** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */
public void close() throws InterruptedException {
- server.shutdown();
- final boolean terminated = server.awaitTermination(3000, TimeUnit.MILLISECONDS);
+ shutdown();
+ final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS);
if (terminated) {
logger.debug("Server was terminated within 3s");
return;
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
index f6b9e867807..3cb09ef5cd9 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -17,8 +17,14 @@
package org.apache.arrow.flight;
+import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
import java.util.Random;
import java.util.function.Function;
@@ -26,9 +32,12 @@
* Utility methods and constants for testing flight servers.
*/
public class FlightTestUtil {
+
private static final Random RANDOM = new Random();
public static final String LOCALHOST = "localhost";
+ public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA";
+ public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot";
/**
* Returns a a FlightServer (actually anything that is startable)
@@ -62,6 +71,30 @@ public static T getStartedServer(Function newServerFromPort) thr
return server;
}
+ static Path getTestDataRoot() {
+ String path = System.getenv(TEST_DATA_ENV_VAR);
+ if (path == null) {
+ path = System.getProperty(TEST_DATA_PROPERTY);
+ }
+ return Paths.get(Objects.requireNonNull(path,
+ String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.",
+ TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY)));
+ }
+
+ static Path getFlightTestDataRoot() {
+ return getTestDataRoot().resolve("flight");
+ }
+
+ static Path exampleTlsRootCert() {
+ return getFlightTestDataRoot().resolve("root-ca.pem");
+ }
+
+ static List exampleTlsCerts() {
+ final Path root = getFlightTestDataRoot();
+ return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()),
+ new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile()));
+ }
+
static boolean isEpollAvailable() {
try {
Class> epoll = Class.forName("io.netty.channel.epoll.Epoll");
@@ -84,6 +117,17 @@ static boolean isNativeTransportAvailable() {
return isEpollAvailable() || isKqueueAvailable();
}
+ public static class CertKeyPair {
+
+ public final File cert;
+ public final File key;
+
+ public CertKeyPair(File cert, File key) {
+ this.cert = cert;
+ this.key = key;
+ }
+ }
+
private FlightTestUtil() {
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
new file mode 100644
index 00000000000..c22304d5647
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightClient.Builder;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for TLS in Flight.
+ */
+public class TestTls {
+
+ /**
+ * Test a basic request over TLS.
+ */
+ @Test
+ public void connectTls() {
+ test((builder) -> {
+ try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+ final FlightClient client = builder.trustedCertificates(roots).build()) {
+ final Iterator responses = client.doAction(new Action("hello-world"));
+ final byte[] response = responses.next().getBody();
+ Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
+ Assert.assertFalse(responses.hasNext());
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Make sure that connections are rejected when the root certificate isn't trusted.
+ */
+ @Test(expected = io.grpc.StatusRuntimeException.class)
+ public void rejectInvalidCert() {
+ test((builder) -> {
+ try (final FlightClient client = builder.build()) {
+ final Iterator responses = client.doAction(new Action("hello-world"));
+ responses.next().getBody();
+ Assert.fail("Call should have failed");
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Make sure that connections are rejected when the hostname doesn't match.
+ */
+ @Test(expected = io.grpc.StatusRuntimeException.class)
+ public void rejectHostname() {
+ test((builder) -> {
+ try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+ final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
+ .build()) {
+ final Iterator responses = client.doAction(new Action("hello-world"));
+ responses.next().getBody();
+ Assert.fail("Call should have failed");
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+
+ void test(Consumer testFn) {
+ final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer();
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (port) -> {
+ try {
+ return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer)
+ .useTls(certKey.cert, certKey.key)
+ .build();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ })) {
+ final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort()));
+ testFn.accept(builder);
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static class Producer extends NoOpFlightProducer implements AutoCloseable {
+
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener listener) {
+ if (action.getType().equals("hello-world")) {
+ listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
+ listener.onCompleted();
+ }
+ listener.onError(new UnsupportedOperationException("Invalid action " + action.getType()));
+ }
+
+ @Override
+ public void close() {
+ }
+ }
+}
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c916e6bcf56..7ca83a94994 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -419,7 +419,7 @@ cdef class FlightClient:
.format(self.__class__.__name__))
@staticmethod
- def connect(location, tls_root_certs=None):
+ def connect(location, tls_root_certs=None, override_hostname=None):
"""
Connect to a Flight service on the given host and port.
@@ -428,8 +428,10 @@ cdef class FlightClient:
location : Location
location to connect to
- tls_root_certs : bytes
+ tls_root_certs : bytes or None
PEM-encoded
+ unsafe_override_hostname : str or None
+ Override the hostname checked by TLS. Insecure, use with caution.
"""
cdef:
FlightClient result = FlightClient.__new__(FlightClient)
@@ -439,6 +441,8 @@ cdef class FlightClient:
if tls_root_certs:
c_options.tls_root_certs = tobytes(tls_root_certs)
+ if override_hostname:
+ c_options.override_hostname = tobytes(override_hostname)
with nogil:
check_status(CFlightClient.Connect(c_location, c_options,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 14d1ed163d1..61e9571995d 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -170,6 +170,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
CFlightClientOptions()
c_string tls_root_certs
+ c_string override_hostname
cdef cppclass CFlightClient" arrow::flight::FlightClient":
@staticmethod
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index f4c9cc12bee..3088a7a86f1 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -570,3 +570,18 @@ def test_tls_do_get():
server_location, tls_root_certs=certs["root_cert"])
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
+
+
+def test_tls_override_hostname():
+ """Check that incorrectly overriding the hostname fails."""
+ certs = example_tls_certs()
+
+ with flight_server(
+ ConstantFlightServer, tls_certificates=certs["certificates"],
+ connect_args=dict(tls_root_certs=certs["root_cert"]),
+ ) as server_location:
+ client = flight.FlightClient.connect(
+ server_location, tls_root_certs=certs["root_cert"],
+ override_hostname="fakehostname")
+ with pytest.raises(pa.ArrowIOError):
+ client.do_get(flight.Ticket(b'ints'))
diff --git a/testing b/testing
index 12f9dbd2a37..a674dac190c 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a
+Subproject commit a674dac190c5fc626964c9b611c67552fa2e530d