From 460442cea3ab4a67cb113bc9a48b9f09e1748eb2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 2 May 2019 12:47:19 -0400
Subject: [PATCH 1/5] Use URIs for Flight locations
---
cpp/src/arrow/flight/client.cc | 32 +++-
cpp/src/arrow/flight/client.h | 6 +-
cpp/src/arrow/flight/flight-benchmark.cc | 8 +-
cpp/src/arrow/flight/flight-test.cc | 33 ++++-
cpp/src/arrow/flight/internal.cc | 7 +-
cpp/src/arrow/flight/perf-server.cc | 7 +-
cpp/src/arrow/flight/server.cc | 25 +++-
cpp/src/arrow/flight/server.h | 7 +-
.../arrow/flight/test-integration-client.cc | 13 +-
.../arrow/flight/test-integration-server.cc | 4 +-
cpp/src/arrow/flight/test-server.cc | 4 +-
cpp/src/arrow/flight/test-util.cc | 21 ++-
cpp/src/arrow/flight/types.cc | 29 ++++
cpp/src/arrow/flight/types.h | 57 +++++++-
cpp/src/arrow/util/uri.cc | 6 +-
cpp/src/arrow/util/uri.h | 2 +
format/Flight.proto | 9 +-
.../org/apache/arrow/flight/FlightClient.java | 23 ++-
.../apache/arrow/flight/FlightEndpoint.java | 15 +-
.../org/apache/arrow/flight/FlightInfo.java | 9 +-
.../org/apache/arrow/flight/FlightServer.java | 24 ++-
.../org/apache/arrow/flight/Location.java | 50 ++++---
.../apache/arrow/flight/LocationSchemes.java | 28 ++++
.../flight/example/ExampleFlightServer.java | 4 +-
.../integration/IntegrationTestClient.java | 8 +-
.../integration/IntegrationTestServer.java | 2 +-
.../apache/arrow/flight/TestBackPressure.java | 7 +-
.../arrow/flight/TestBasicOperation.java | 20 ++-
.../apache/arrow/flight/TestCallOptions.java | 6 +-
.../apache/arrow/flight/TestLargeMessage.java | 12 +-
.../apache/arrow/flight/auth/TestAuth.java | 4 +-
.../flight/example/TestExampleServer.java | 4 +-
.../flight/perf/PerformanceTestServer.java | 2 +-
.../apache/arrow/flight/perf/TestPerf.java | 2 +-
python/examples/flight/client.py | 2 +-
python/pyarrow/_flight.pyx | 138 +++++++++++++-----
python/pyarrow/includes/libarrow_flight.pxd | 14 +-
python/pyarrow/tests/test_flight.py | 90 +++++++-----
38 files changed, 543 insertions(+), 191 deletions(-)
create mode 100644 java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 5f4d6dd8610..4a7beded950 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -37,6 +37,7 @@
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
#include "arrow/flight/client_auth.h"
#include "arrow/flight/internal.h"
@@ -230,11 +231,26 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
class FlightClient::FlightClientImpl {
public:
- Status Connect(const std::string& host, int port) {
- // TODO(wesm): Support other kinds of GRPC ChannelCredentials
- std::stringstream ss;
- ss << host << ":" << port;
- std::string uri = ss.str();
+ Status Connect(const Location& location) {
+ const std::string& scheme = location.scheme();
+
+ std::stringstream grpc_uri;
+ std::shared_ptr creds;
+ if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
+ // TODO(wesm): Support other kinds of GRPC ChannelCredentials
+ grpc_uri << location.uri_->host() << ":" << location.uri_->port_text();
+
+ if (scheme == "grpc+tls") {
+ creds = grpc::SslCredentials(grpc::SslCredentialsOptions());
+ } else {
+ creds = grpc::InsecureChannelCredentials();
+ }
+ } else if (scheme == kSchemeGrpcUnix) {
+ grpc_uri << "unix://" << location.uri_->path();
+ creds = grpc::InsecureChannelCredentials();
+ } else {
+ return Status::NotImplemented("Flight scheme " + scheme + " is not supported.");
+ }
grpc::ChannelArguments args;
// Try to reconnect quickly at first, in case the server is still starting up
@@ -242,7 +258,7 @@ class FlightClient::FlightClientImpl {
// Receive messages of any size
args.SetMaxReceiveMessageSize(-1);
stub_ = pb::FlightService::NewStub(
- grpc::CreateCustomChannel(ss.str(), grpc::InsecureChannelCredentials(), args));
+ grpc::CreateCustomChannel(grpc_uri.str(), creds, args));
return Status::OK();
}
@@ -383,10 +399,10 @@ FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }
FlightClient::~FlightClient() {}
-Status FlightClient::Connect(const std::string& host, int port,
+Status FlightClient::Connect(const Location& location,
std::unique_ptr* client) {
client->reset(new FlightClient);
- return (*client)->impl_->Connect(host, port);
+ return (*client)->impl_->Connect(location);
}
Status FlightClient::Authenticate(const FlightCallOptions& options,
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 3886360ee5c..f48f3edf839 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -64,13 +64,11 @@ class ARROW_EXPORT FlightClient {
~FlightClient();
/// \brief Connect to an unauthenticated flight service
- /// \param[in] host the hostname or IP address
- /// \param[in] port the port on the host
+ /// \param[in] location the URI
/// \param[out] client the created FlightClient
/// \return Status OK status may not indicate that the connection was
/// successful
- static Status Connect(const std::string& host, int port,
- std::unique_ptr* client);
+ static Status Connect(const Location& location, std::unique_ptr* client);
/// \brief Authenticate to the server using the given handler.
/// \param[in] options Per-RPC options
diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc
index d6318eaadef..20fe5ac179e 100644
--- a/cpp/src/arrow/flight/flight-benchmark.cc
+++ b/cpp/src/arrow/flight/flight-benchmark.cc
@@ -79,7 +79,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
// Construct client and plan the query
std::unique_ptr client;
- RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client));
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location));
+ RETURN_NOT_OK(FlightClient::Connect(location, &client));
FlightDescriptor descriptor;
descriptor.type = FlightDescriptor::CMD;
@@ -97,7 +99,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
auto ConsumeStream = [&stats, &hostname, &port](const FlightEndpoint& endpoint) {
// TODO(wesm): Use location from endpoint, same host/port for now
std::unique_ptr client;
- RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client));
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location));
+ RETURN_NOT_OK(FlightClient::Connect(location, &client));
perf::Token token;
token.ParseFromString(endpoint.ticket.ticket);
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index 43e6ddb1408..d99e7c4b020 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -61,8 +61,7 @@ void AssertEqual(const Ticket& expected, const Ticket& actual) {
}
void AssertEqual(const Location& expected, const Location& actual) {
- ASSERT_EQ(expected.host, actual.host);
- ASSERT_EQ(expected.port, actual.port);
+ ASSERT_EQ(expected, actual);
}
void AssertEqual(const std::vector& expected,
@@ -146,6 +145,20 @@ TEST(TestFlight, StartStopTestServer) {
ASSERT_FALSE(server.IsRunning());
}
+TEST(TestFlight, ConnectUri) {
+ TestServer server("flight-test-server", 30000);
+ server.Start();
+ ASSERT_TRUE(server.IsRunning());
+
+ std::unique_ptr client;
+ Location location1;
+ Location location2;
+ ASSERT_OK(Location::Parse("grpc://localhost:30000", &location1));
+ ASSERT_OK(Location::Parse("grpc://localhost:30000", &location2));
+ ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK(FlightClient::Connect(location2, &client));
+}
+
// ----------------------------------------------------------------------
// Client tests
@@ -169,7 +182,11 @@ class TestFlightClient : public ::testing::Test {
void TearDown() { server_->Stop(); }
- Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
+ Status ConnectClient() {
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
+ return FlightClient::Connect(location, &client_);
+ }
template
void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
@@ -248,7 +265,11 @@ class TestAuthHandler : public ::testing::Test {
void TearDown() { server_->Stop(); }
- Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
+ Status ConnectClient() {
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
+ return FlightClient::Connect(location, &client_);
+ }
protected:
int port_;
@@ -423,7 +444,9 @@ TEST_F(TestFlightClient, Issue5095) {
TEST_F(TestFlightClient, TimeoutFires) {
// Server does not exist on this port, so call should fail
std::unique_ptr client;
- ASSERT_OK(FlightClient::Connect("localhost", 30001, &client));
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location));
+ ASSERT_OK(FlightClient::Connect(location, &client));
FlightCallOptions options;
options.timeout = TimeoutDuration{0.2};
std::unique_ptr info;
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index 2e335936e51..c60e58597f6 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -119,14 +119,11 @@ Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) {
// Location
Status FromProto(const pb::Location& pb_location, Location* location) {
- location->host = pb_location.host();
- location->port = pb_location.port();
- return Status::OK();
+ return Location::Parse(pb_location.uri(), location);
}
void ToProto(const Location& location, pb::Location* pb_location) {
- pb_location->set_host(location.host);
- pb_location->set_port(location.port);
+ pb_location->set_uri(location.ToString());
}
// Ticket
diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc
index 3755f3d52d1..f65cbd15abf 100644
--- a/cpp/src/arrow/flight/perf-server.cc
+++ b/cpp/src/arrow/flight/perf-server.cc
@@ -141,7 +141,8 @@ Status GetPerfBatches(const perf::Token& token, const std::shared_ptr& s
class FlightPerfServer : public FlightServerBase {
public:
- FlightPerfServer() : location_(Location{"localhost", FLAGS_port}) {
+ FlightPerfServer() : location_() {
+ DCHECK_OK(Location::ForGrpcTcp("localhost", FLAGS_port, &location_));
perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()),
field("d", int64())});
}
@@ -204,8 +205,10 @@ int main(int argc, char** argv) {
g_server.reset(new arrow::flight::FlightPerfServer);
+ arrow::flight::Location location;
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
ARROW_CHECK_OK(
- g_server->Init(std::unique_ptr(), FLAGS_port));
+ g_server->Init(std::unique_ptr(), location));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server port: " << FLAGS_port << std::endl;
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 5679cdeedde..d8c5525ee98 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -40,6 +40,7 @@
#include "arrow/status.h"
#include "arrow/util/logging.h"
#include "arrow/util/stl.h"
+#include "arrow/util/uri.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/serialization-internal.h"
@@ -410,7 +411,6 @@ class FlightServiceImpl : public FlightService::Service {
#endif
struct FlightServerBase::Impl {
- std::string address_;
std::unique_ptr service_;
std::unique_ptr server_;
@@ -442,17 +442,34 @@ FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
FlightServerBase::~FlightServerBase() {}
-Status FlightServerBase::Init(std::unique_ptr auth_handler, int port) {
- impl_->address_ = "localhost:" + std::to_string(port);
+Status FlightServerBase::Init(std::unique_ptr auth_handler,
+ const Location& location) {
std::shared_ptr handler = std::move(auth_handler);
impl_->service_.reset(new FlightServiceImpl(handler, this));
grpc::ServerBuilder builder;
// Allow uploading messages of any length
builder.SetMaxReceiveMessageSize(-1);
- builder.AddListeningPort(impl_->address_, grpc::InsecureServerCredentials());
+
+ const std::string scheme = location.scheme();
+ if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) {
+ std::stringstream address;
+ address << location.uri_->host() << ':' << location.uri_->port_text();
+ builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
+ } else if (scheme == kSchemeGrpcUnix) {
+ std::stringstream address;
+ address << "unix:" << location.uri_->path();
+ builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
+ } else {
+ return Status::NotImplemented("Scheme is not supported: " + scheme);
+ }
+
builder.RegisterService(impl_->service_.get());
+ // Disable SO_REUSEPORT - it makes debugging/testing a pain as
+ // leftover processes can handle requests on accident
+ builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
+
impl_->server_ = builder.BuildAndStart();
if (!impl_->server_) {
return Status::UnknownError("Server did not start properly");
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 0a2b940f14b..752b89e85bf 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -98,13 +98,12 @@ class ARROW_EXPORT FlightServerBase {
// Lifecycle methods.
- /// \brief Initialize an insecure TCP server listening on localhost
- /// at the given port.
+ /// \brief Initialize a Flight server listening at the given location.
/// This method must be called before any other method.
- /// \param[in] port The port to serve on.
/// \param[in] auth_handler The authentication handler. May be
/// nullptr if no authentication is desired.
- Status Init(std::unique_ptr auth_handler, int port);
+ /// \param[in] location The location to serve on.
+ Status Init(std::unique_ptr auth_handler, const Location& location);
/// \brief Set the server to stop when receiving any of the given signal
/// numbers.
diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc
index 93a1a16c67e..d0b734007c4 100644
--- a/cpp/src/arrow/flight/test-integration-client.cc
+++ b/cpp/src/arrow/flight/test-integration-client.cc
@@ -89,8 +89,7 @@ arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location,
const std::shared_ptr& schema,
std::shared_ptr* retrieved_data) {
std::unique_ptr read_client;
- RETURN_NOT_OK(
- arrow::flight::FlightClient::Connect(location.host, location.port, &read_client));
+ RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client));
std::unique_ptr stream;
RETURN_NOT_OK(read_client->DoGet(ticket, &stream));
@@ -103,7 +102,10 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::unique_ptr client;
- ABORT_NOT_OK(arrow::flight::FlightClient::Connect(FLAGS_host, FLAGS_port, &client));
+ std::stringstream uri;
+ arrow::flight::Location location;
+ ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
+ ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client));
arrow::flight::FlightDescriptor descr{
arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}};
@@ -143,12 +145,11 @@ int main(int argc, char** argv) {
auto locations = endpoint.locations;
if (locations.size() == 0) {
- locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}};
+ locations = {location};
}
for (const auto location : locations) {
- std::cout << "Verifying location " << location.host << ':' << location.port
- << std::endl;
+ std::cout << "Verifying location " << location.ToString() << std::endl;
// 3. Download the data from the server.
std::shared_ptr retrieved_data;
ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data));
diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc
index 9a7c8787ba0..d72e838213b 100644
--- a/cpp/src/arrow/flight/test-integration-server.cc
+++ b/cpp/src/arrow/flight/test-integration-server.cc
@@ -124,7 +124,9 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
g_server.reset(new arrow::flight::FlightIntegrationTestServer);
- ARROW_CHECK_OK(g_server->Init(nullptr, FLAGS_port));
+ arrow::flight::Location location;
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
+ ARROW_CHECK_OK(g_server->Init(nullptr, location));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc
index a9070a495a5..29af87db601 100644
--- a/cpp/src/arrow/flight/test-server.cc
+++ b/cpp/src/arrow/flight/test-server.cc
@@ -139,8 +139,10 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
g_server.reset(new arrow::flight::FlightTestServer);
+ arrow::flight::Location location;
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
ARROW_CHECK_OK(
- g_server->Init(std::unique_ptr(), FLAGS_port));
+ g_server->Init(std::unique_ptr(), location));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index ac6df556f36..01577958640 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -119,7 +119,9 @@ bool TestServer::IsRunning() { return server_process_->running(); }
int TestServer::port() const { return port_; }
Status InProcessTestServer::Start(std::unique_ptr auth_handler) {
- RETURN_NOT_OK(server_->Init(std::move(auth_handler), port_));
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
+ RETURN_NOT_OK(server_->Init(std::move(auth_handler), location));
thread_ = std::thread([this]() { ARROW_EXPECT_OK(server_->Serve()); });
return Status::OK();
}
@@ -169,12 +171,21 @@ std::shared_ptr ExampleDictSchema() {
}
std::vector ExampleFlightInfo() {
+ Location location1;
+ Location location2;
+ Location location3;
+ Location location4;
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4));
+
FlightInfo::Data flight1, flight2, flight3;
- FlightEndpoint endpoint1({{"ticket-ints-1"}, {{"foo1.bar.com", 92385}}});
- FlightEndpoint endpoint2({{"ticket-ints-2"}, {{"foo2.bar.com", 92385}}});
- FlightEndpoint endpoint3({{"ticket-cmd"}, {{"foo3.bar.com", 92385}}});
- FlightEndpoint endpoint4({{"ticket-dicts-1"}, {{"foo4.bar.com", 92385}}});
+ FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}});
+ FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}});
+ FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}});
+ FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}});
FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}};
FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}};
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 77c8009e8bf..985cb5fc03a 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -25,6 +25,7 @@
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"
+#include "arrow/util/uri.h"
namespace arrow {
namespace flight {
@@ -83,6 +84,34 @@ Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo,
return Status::OK();
}
+Location::Location() { uri_ = std::make_shared(); }
+
+Status Location::Parse(const std::string& uri_string, Location* location) {
+ return location->uri_->Parse(uri_string);
+}
+
+Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+tcp://" << host << ':' << port;
+ return Location::Parse(uri_string.str(), location);
+}
+
+Status Location::ForGrpcUnix(const std::string& path, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+unix://" << path;
+ return Location::Parse(uri_string.str(), location);
+}
+
+std::string Location::ToString() const { return uri_->to_string(); }
+std::string Location::scheme() const {
+ std::string scheme = uri_->scheme();
+ if (scheme.empty()) {
+ // Default to grpc+tcp
+ return "grpc+tcp";
+ }
+ return scheme;
+}
+
SimpleFlightListing::SimpleFlightListing(const std::vector& flights)
: position_(0), flights_(flights) {}
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 8e85a41b3d7..9caee997ae1 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -39,7 +39,13 @@ namespace ipc {
class DictionaryMemo;
-}
+} // namespace ipc
+
+namespace internal {
+
+class Uri;
+
+} // namespace internal
namespace flight {
@@ -115,10 +121,53 @@ struct Ticket {
std::string ticket;
};
-/// \brief A host location (hostname and port)
+class FlightClient;
+class FlightServerBase;
+
+static const char* kSchemeGrpc = "grpc";
+static const char* kSchemeGrpcTcp = "grpc+tcp";
+static const char* kSchemeGrpcUnix = "grpc+unix";
+static const char* kSchemeGrpcTls = "grpc+tls";
+
+/// \brief A host location (a URI)
struct Location {
- std::string host;
- int32_t port;
+ public:
+ /// \brief Initialize a blank location.
+ Location();
+
+ /// \brief Initialize a location by parsing a URI string
+ static Status Parse(const std::string& uri_string, Location* location);
+
+ /// \brief Initialize a location for a non-TLS, gRPC-based Flight
+ /// service from a host and port
+ /// \param[in] host The hostname to connect to
+ /// \param[in] port The port
+ /// \param[out] location The resulting location
+ static Status ForGrpcTcp(const std::string& host, const int port, Location* location);
+
+ /// \brief Initialize a location for a domain socket-based Flight
+ /// service
+ /// \param[in] path The path to the domain socket
+ /// \param[out] location The resulting location
+ static Status ForGrpcUnix(const std::string& path, Location* location);
+
+ /// \brief Get a representation of this URI as a string.
+ std::string ToString() const;
+
+ /// \brief Get the scheme of this URI.
+ std::string scheme() const;
+
+ friend bool operator==(const Location& left, const Location& right) {
+ return left.ToString() == right.ToString();
+ }
+ friend bool operator!=(const Location& left, const Location& right) {
+ return !(left == right);
+ }
+
+ private:
+ friend class FlightClient;
+ friend class FlightServerBase;
+ std::shared_ptr uri_;
};
/// \brief A flight ticket and list of locations where the ticket can be
diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc
index 3a90612ea67..e79c9826b9b 100644
--- a/cpp/src/arrow/util/uri.cc
+++ b/cpp/src/arrow/util/uri.cc
@@ -52,7 +52,7 @@ bool IsTextRangeSet(const UriTextRangeStructA& range) { return range.first != nu
} // namespace
struct Uri::Impl {
- Impl() : port_(-1) { memset(&uri_, 0, sizeof(uri_)); }
+ Impl() : string_rep_(""), port_(-1) { memset(&uri_, 0, sizeof(uri_)); }
~Impl() { uriFreeUriMembersA(&uri_); }
@@ -71,6 +71,7 @@ struct Uri::Impl {
UriUriA uri_;
// Keep alive strings that uriparser stores pointers to
std::vector data_;
+ std::string string_rep_;
int32_t port_;
};
@@ -119,10 +120,13 @@ std::string Uri::path() const {
return ss.str();
}
+const std::string& Uri::to_string() const { return impl_->string_rep_; }
+
Status Uri::Parse(const std::string& uri_string) {
impl_->Reset();
const auto& s = impl_->KeepString(uri_string);
+ impl_->string_rep_ = s;
const char* error_pos;
if (uriParseSingleUriExA(&impl_->uri_, s.data(), s.data() + s.size(), &error_pos) !=
URI_SUCCESS) {
diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h
index 3d6949537ce..7327c2d7876 100644
--- a/cpp/src/arrow/util/uri.h
+++ b/cpp/src/arrow/util/uri.h
@@ -54,6 +54,8 @@ class ARROW_EXPORT Uri {
int32_t port() const;
/// The URI path component.
std::string path() const;
+ /// Get the string representation of this URI.
+ const std::string& to_string() const;
/// Factory function to parse a URI from its string representation.
Status Parse(const std::string& uri_string);
diff --git a/format/Flight.proto b/format/Flight.proto
index 1fcefe9a63e..7f0488b86c3 100644
--- a/format/Flight.proto
+++ b/format/Flight.proto
@@ -246,7 +246,7 @@ message FlightEndpoint {
Ticket ticket = 1;
/*
- * A list of locations where this ticket can be redeemed. If the list is
+ * A list of URIs where this ticket can be redeemed. If the list is
* empty, the expectation is that the ticket can only be redeemed on the
* current service where the ticket was generated.
*/
@@ -254,12 +254,11 @@ message FlightEndpoint {
}
/*
- * A location where a flight service will accept retrieval of a particular
- * stream given a ticket.
+ * A location where a Flight service will accept retrieval of a particular
+ * stream given a ticket.
*/
message Location {
- string host = 1;
- int32 port = 2;
+ string uri = 1;
}
/*
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index c13c5659585..178897da3b3 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -20,6 +20,8 @@
import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
+import java.net.URI;
+import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@@ -74,8 +76,7 @@ public class FlightClient implements AutoCloseable {
*/
public FlightClient(BufferAllocator incomingAllocator, Location location) {
final ManagedChannelBuilder> channelBuilder =
- ManagedChannelBuilder.forAddress(location.getHost(),
- location.getPort())
+ ManagedChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort())
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
.maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE)
.usePlaintext();
@@ -97,7 +98,15 @@ public FlightClient(BufferAllocator incomingAllocator, Location location) {
public Iterable listFlights(Criteria criteria, CallOption... options) {
return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options).listFlights(criteria.asCriteria()))
.stream()
- .map(FlightInfo::new)
+ .map(t -> {
+ try {
+ return new FlightInfo(t);
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ }
+ })
.collect(Collectors.toList());
}
@@ -177,7 +186,13 @@ public ClientStreamListener startPut(
* @param options RPC-layer hints for this call.
*/
public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) {
- return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol()));
+ try {
+ return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol()));
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ }
}
/**
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java
index b615ed6c5a5..a34c0a58aa1 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java
@@ -17,6 +17,8 @@
package org.apache.arrow.flight;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@@ -46,9 +48,11 @@ public FlightEndpoint(Ticket ticket, Location... locations) {
/**
* Constructs from the protocol buffer representation.
*/
- public FlightEndpoint(Flight.FlightEndpoint flt) {
- locations = flt.getLocationList().stream()
- .map(t -> new Location(t)).collect(Collectors.toList());
+ public FlightEndpoint(Flight.FlightEndpoint flt) throws URISyntaxException {
+ locations = new ArrayList<>();
+ for (final Flight.Location location : flt.getLocationList()) {
+ locations.add(new Location(location.getUri()));
+ }
ticket = new Ticket(flt.getTicket());
}
@@ -68,10 +72,7 @@ Flight.FlightEndpoint toProtocol() {
.setTicket(ticket.toProtocol());
for (Location l : locations) {
- b.addLocation(Flight.Location.newBuilder()
- .setHost(l.getHost())
- .setPort(l.getPort())
- .build());
+ b.addLocation(l.toProtocol());
}
return b.build();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java
index 053b89885ee..4159e4618f1 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java
@@ -19,8 +19,10 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
+import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@@ -67,7 +69,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List 0 ?
@@ -78,7 +80,10 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List new FlightEndpoint(t)).collect(Collectors.toList());
+ endpoints = new ArrayList<>();
+ for (final Flight.FlightEndpoint endpoint : pbFlightInfo.getEndpointList()) {
+ endpoints.add(new FlightEndpoint(endpoint));
+ }
bytes = pbFlightInfo.getTotalBytes();
records = pbFlightInfo.getTotalRecords();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index 0e93139e0e2..58afb5bafe8 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -18,6 +18,7 @@
package org.apache.arrow.flight;
import java.io.IOException;
+import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
@@ -26,8 +27,8 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import io.grpc.Server;
-import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
+import io.grpc.netty.NettyServerBuilder;
public class FlightServer implements AutoCloseable {
@@ -42,16 +43,31 @@ public class FlightServer implements AutoCloseable {
* Constructs a new instance.
*
* @param allocator The allocator to use for storing/copying arrow data.
- * @param port The port to bind to.
+ * @param location The location to serve on.
* @param producer The underlying business logic for the server.
* @param authHandler The authorization handler for the server.
*/
public FlightServer(
BufferAllocator allocator,
- int port,
+ Location location,
FlightProducer producer,
ServerAuthHandler authHandler) {
- this.server = ServerBuilder.forPort(port)
+ final NettyServerBuilder builder;
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ // TODO: need reflection to check if domain sockets are available
+ throw new UnsupportedOperationException("Domain sockets are not available.");
+ }
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE: {
+ builder = NettyServerBuilder
+ .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
+ }
+ this.server = builder
.maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE)
.addService(
ServerInterceptors.intercept(
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/src/main/java/org/apache/arrow/flight/Location.java
index 44ee865a14b..377edeba747 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/Location.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/Location.java
@@ -17,40 +17,56 @@
package org.apache.arrow.flight;
+import java.net.URI;
+import java.net.URISyntaxException;
+
import org.apache.arrow.flight.impl.Flight;
+/** A URI where a Flight stream is available. */
public class Location {
-
- private final String host;
- private final int port;
+ private final URI uri;
/**
* Constructs a new instance.
*
- * @param host the host containing a flight server
- * @param port the port the server is listening on.
+ * @param uri the URI of the Flight service
*/
- public Location(String host, int port) {
+ public Location(String uri) throws URISyntaxException {
super();
- this.host = host;
- this.port = port;
+ this.uri = new URI(uri);
}
- Location(Flight.Location location) {
- this.host = location.getHost();
- this.port = location.getPort();
+ public Location(URI uri) {
+ super();
+ this.uri = uri;
}
- public String getHost() {
- return host;
+ public URI getUri() {
+ return uri;
}
- public int getPort() {
- return port;
+ /**
+ * Convert this Location into its protocol-level representation.
+ */
+ public Flight.Location toProtocol() {
+ return Flight.Location.newBuilder().setUri(uri.toString()).build();
}
- Flight.Location toProtocol() {
- return Flight.Location.newBuilder().setHost(host).setPort(port).build();
+ /** Construct a URI for a Flight+gRPC server without transport security. */
+ public static Location forGrpcInsecure(String host, int port) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_INSECURE, null, host, port, null, null, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
}
+ /** Construct a URI for a Flight+gRPC server with transport security. */
+ public static Location forGrpcTls(String host, int port) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_TLS, null, host, port, null, null, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
new file mode 100644
index 00000000000..b652fed48a7
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Constants representing well-known URI schemes for Flight services.
+ */
+public class LocationSchemes {
+ public static final String GRPC = "grpc";
+ public static final String GRPC_INSECURE = "grpc+tcp";
+ public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
+ public static final String GRPC_TLS = "grpc+tls";
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
index cbeb8fbf04f..2d71b5d490b 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
@@ -46,7 +46,7 @@ public ExampleFlightServer(BufferAllocator allocator, Location location) {
this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE);
this.location = location;
this.mem = new InMemoryStore(this.allocator, location);
- this.flightServer = new FlightServer(allocator, location.getPort(), mem, ServerAuthHandler.NO_OP);
+ this.flightServer = new FlightServer(allocator, location, mem, ServerAuthHandler.NO_OP);
}
public Location getLocation() {
@@ -75,7 +75,7 @@ public void close() throws Exception {
*/
public static void main(String[] args) throws Exception {
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
- final ExampleFlightServer efs = new ExampleFlightServer(a, new Location("localhost", 12233));
+ final ExampleFlightServer efs = new ExampleFlightServer(a, Location.forGrpcInsecure("localhost", 12233));
efs.start();
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index ed450074a76..1e9f716f184 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -19,6 +19,7 @@
import java.io.File;
import java.io.IOException;
+import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
@@ -80,7 +81,8 @@ private void run(String[] args) throws ParseException, IOException {
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
- final FlightClient client = new FlightClient(allocator, new Location(host, port));
+ final Location defaultLocation = Location.forGrpcInsecure(host, port);
+ final FlightClient client = new FlightClient(allocator, defaultLocation);
final String inputPath = cmd.getOptionValue("j");
@@ -114,10 +116,10 @@ private void run(String[] args) throws ParseException, IOException {
// 3. Download the data from the server.
List locations = endpoint.getLocations();
if (locations.size() == 0) {
- locations = Collections.singletonList(new Location(host, port));
+ locations = Collections.singletonList(defaultLocation);
}
for (Location location : locations) {
- System.out.println("Verifying location " + location.getHost() + ":" + location.getPort());
+ System.out.println("Verifying location " + location.getUri());
FlightClient readClient = new FlightClient(allocator, location);
FlightStream stream = readClient.getStream(endpoint.getTicket());
VectorSchemaRoot downloadedRoot;
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
index fdad5d1a4e5..cfeaadb3da0 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
@@ -43,7 +43,7 @@ private void run(String[] args) throws Exception {
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
- final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port));
+ final ExampleFlightServer efs = new ExampleFlightServer(allocator, Location.forGrpcInsecure("localhost", port));
efs.start();
// Print out message for integration test script
System.out.println("Server listening on localhost:" + port);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
index dcf8debb55b..e5b8d980f7e 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
@@ -47,7 +47,7 @@ public void ensureIndependentSteams() throws Exception {
try (
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final PerformanceTestServer server = FlightTestUtil.getStartedServer(
- (port) -> (new PerformanceTestServer(a, new Location(FlightTestUtil.LOCALHOST, port))));
+ (port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))));
final FlightClient client = new FlightClient(a, server.getLocation())
) {
FlightStream fs1 = client.getStream(client.getInfo(
@@ -123,10 +123,11 @@ public void getStream(CallContext context, Ticket ticket,
BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
FlightServer server =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(serverAllocator, port, producer, ServerAuthHandler.NO_OP));
+ (port) -> new FlightServer(serverAllocator, Location.forGrpcInsecure("localhost", port), producer,
+ ServerAuthHandler.NO_OP));
BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
FlightClient client =
- new FlightClient(clientAllocator, new Location(FlightTestUtil.LOCALHOST, server.getPort()))
+ new FlightClient(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
) {
FlightStream stream = client.getStream(new Ticket(new byte[1]));
VectorSchemaRoot root = stream.getRoot();
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index 268580d8e62..a02cd764fc2 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -17,6 +17,7 @@
package org.apache.arrow.flight;
+import java.net.URISyntaxException;
import java.util.concurrent.Callable;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@@ -138,10 +139,12 @@ private void test(BiConsumer consumer) throws Exc
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer(a);
FlightServer s =
- FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) {
+ FlightTestUtil.getStartedServer(
+ (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
+ ServerAuthHandler.NO_OP))) {
try (
- FlightClient c = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()));
+ FlightClient c = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()));
) {
try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
consumer.accept(c, testAllocator);
@@ -170,7 +173,12 @@ public void listFlights(CallContext context, Criteria criteria,
.setType(DescriptorType.CMD)
.setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
.build();
- listener.onNext(new FlightInfo(getInfo));
+ try {
+ listener.onNext(new FlightInfo(getInfo));
+ } catch (URISyntaxException e) {
+ listener.onError(e);
+ return;
+ }
listener.onCompleted();
}
@@ -231,7 +239,11 @@ public FlightInfo getFlightInfo(CallContext context,
.setType(DescriptorType.CMD)
.setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
.build();
- return new FlightInfo(getInfo);
+ try {
+ return new FlightInfo(getInfo);
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
@Override
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
index 21d7ecae9aa..3af36cbe5c4 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
@@ -70,8 +70,10 @@ void test(Consumer testFn) {
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer(a);
FlightServer s =
- FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP));
- FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) {
+ FlightTestUtil.getStartedServer(
+ (port) -> new FlightServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer,
+ ServerAuthHandler.NO_OP));
+ FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) {
testFn.accept(client);
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
index ca31b02c729..2e69c0c9219 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
@@ -44,9 +44,11 @@ public void getLargeMessage() throws Exception {
try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final Producer producer = new Producer(a);
final FlightServer s =
- FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) {
+ FlightTestUtil.getStartedServer(
+ (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
+ ServerAuthHandler.NO_OP))) {
- try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) {
+ try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) {
FlightStream stream = client.getStream(new Ticket(new byte[]{}));
try (VectorSchemaRoot root = stream.getRoot()) {
while (stream.next()) {
@@ -73,9 +75,11 @@ public void putLargeMessage() throws Exception {
try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final Producer producer = new Producer(a);
final FlightServer s =
- FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) {
+ FlightTestUtil.getStartedServer(
+ (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
+ ServerAuthHandler.NO_OP))) {
- try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()));
+ try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()));
BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE);
VectorSchemaRoot root = generateData(testAllocator)) {
final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
index 2e967b02b7f..4dfe669d5be 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
@@ -121,7 +121,7 @@ public byte[] getToken(String username, String password) {
server = FlightTestUtil.getStartedServer((port) -> new FlightServer(
allocator,
- port,
+ Location.forGrpcInsecure("localhost", port),
new NoOpFlightProducer() {
@Override
public void listFlights(CallContext context, Criteria criteria,
@@ -151,7 +151,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
}
},
new BasicServerAuthHandler(validator)));
- client = new FlightClient(allocator, new Location(FlightTestUtil.LOCALHOST, server.getPort()));
+ client = new FlightClient(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()));
}
@After
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
index 3730a09c65a..20e04167771 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
@@ -18,12 +18,14 @@
package org.apache.arrow.flight.example;
import java.io.IOException;
+import java.net.URISyntaxException;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@@ -49,7 +51,7 @@ public class TestExampleServer {
public void start() throws IOException {
allocator = new RootAllocator(Long.MAX_VALUE);
- Location l = new Location("localhost", 12233);
+ Location l = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 12233);
if (!Boolean.getBoolean("disableServer")) {
System.out.println("Starting server.");
server = new ExampleFlightServer(allocator, l);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
index 78e5574457b..5d5e600ad29 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
@@ -64,7 +64,7 @@ public PerformanceTestServer(BufferAllocator incomingAllocator, Location locatio
this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE);
this.location = location;
this.producer = new PerfProducer();
- this.flightServer = new FlightServer(this.allocator, location.getPort(), producer, ServerAuthHandler.NO_OP);
+ this.flightServer = new FlightServer(this.allocator, location, producer, ServerAuthHandler.NO_OP);
}
public Location getLocation() {
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
index d580fe20f2a..d40773ef929 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
@@ -82,7 +82,7 @@ public void throughput() throws Exception {
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final PerformanceTestServer server =
FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a,
- new Location(FlightTestUtil.LOCALHOST, port)));
+ Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)));
final FlightClient client = new FlightClient(a, server.getLocation());
) {
final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));
diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py
index d1e60d2710c..0c91e3ec55f 100644
--- a/python/examples/flight/client.py
+++ b/python/examples/flight/client.py
@@ -131,7 +131,7 @@ def main():
}
host, port = args.host.split(':')
port = int(port)
- client = pyarrow.flight.FlightClient.connect(host, port)
+ client = pyarrow.flight.FlightClient.connect(f"grpc://{host}:{port}")
while True:
try:
action = pyarrow.flight.Action("healthcheck", b"")
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 474e0076a14..05d50cd629f 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -20,6 +20,8 @@
import collections
import enum
+import six
+
from cython.operator cimport dereference as deref
from pyarrow.compat import frombytes, tobytes
@@ -69,6 +71,13 @@ cdef class Action:
def body(self):
return pyarrow_wrap_buffer(self.action.body)
+ @staticmethod
+ cdef CAction unwrap(action):
+ if not isinstance(action, Action):
+ raise TypeError("Must provide Action, not '{}'".format(
+ type(action)))
+ return ( action).action
+
_ActionType = collections.namedtuple('_ActionType', ['type', 'description'])
@@ -158,18 +167,73 @@ cdef class FlightDescriptor:
def __repr__(self):
return "".format(self.descriptor_type)
+ @staticmethod
+ cdef FlightDescriptor unwrap(descriptor):
+ if not isinstance(descriptor, FlightDescriptor):
+ raise TypeError("Must provide a FlightDescriptor, not '{}'".format(
+ type(descriptor)))
+ return descriptor
+
-class Ticket:
+cdef class Ticket:
"""A ticket for requesting a Flight stream."""
+
+ cdef:
+ CTicket ticket
+
def __init__(self, ticket):
- self.ticket = ticket
+ self.ticket.ticket = tobytes(ticket)
def __repr__(self):
- return ''.format(self.ticket)
+ return ''.format(self.ticket.ticket)
-class Location(collections.namedtuple('Location', ['host', 'port'])):
- """A location where a Flight stream is available."""
+cdef class Location:
+ """The location of a Flight service."""
+ cdef:
+ CLocation location
+
+ def __init__(self, uri):
+ check_status(CLocation.Parse(tobytes(uri), &self.location))
+
+ def __repr__(self):
+ return ''.format(self.location.ToString())
+
+ @staticmethod
+ def for_grpc_tcp(host, port):
+ """Create a Location for a TCP-based gRPC service."""
+ cdef:
+ c_string c_host = tobytes(host)
+ int c_port = port
+ Location result = Location.__new__(Location)
+ check_status(CLocation.ForGrpcTcp(c_host, c_port, &result.location))
+ return result
+
+ @staticmethod
+ def for_grpc_unix(path):
+ """Create a Location for a domain socket-based gRPC service."""
+ cdef:
+ c_string c_path = tobytes(path)
+ Location result = Location.__new__(Location)
+ check_status(CLocation.ForGrpcUnix(c_path, &result.location))
+ return result
+
+ @staticmethod
+ cdef Location wrap(CLocation location):
+ cdef Location result = Location.__new__(Location)
+ result.location = location
+ return result
+
+ @staticmethod
+ cdef CLocation unwrap(object location):
+ cdef CLocation c_location
+ if isinstance(location, (six.text_type, six.binary_type)):
+ check_status(CLocation.Parse(tobytes(location), &c_location))
+ return c_location
+ elif not isinstance(location, Location):
+ raise TypeError("Must provide a Location, not '{}'".format(
+ type(location)))
+ return ( location).location
cdef class FlightEndpoint:
@@ -184,11 +248,16 @@ cdef class FlightEndpoint:
----------
ticket : Ticket or bytes
the ticket needed to access this flight
- locations : list of Location or tuples of (host, port)
+ locations : list of string URIs
locations where this flight is available
+
+ Raises
+ ------
+ ArrowException
+ If one of the location URIs is not a valid URI.
"""
cdef:
- CLocation c_location = CLocation()
+ CLocation c_location
if isinstance(ticket, Ticket):
self.endpoint.ticket.ticket = tobytes(ticket.ticket)
@@ -196,9 +265,8 @@ cdef class FlightEndpoint:
self.endpoint.ticket.ticket = tobytes(ticket)
for location in locations:
- # Accepts Location namedtuple or tuple
- c_location.host = tobytes(location[0])
- c_location.port = location[1]
+ c_location = CLocation()
+ check_status(CLocation.Parse(tobytes(location), &c_location))
self.endpoint.locations.push_back(c_location)
@property
@@ -207,7 +275,7 @@ cdef class FlightEndpoint:
@property
def locations(self):
- return [Location(frombytes(location.host), location.port)
+ return [Location.wrap(location)
for location in self.endpoint.locations]
@@ -313,27 +381,21 @@ cdef class FlightClient:
.format(self.__class__.__name__))
@staticmethod
- def connect(*args):
+ def connect(location):
"""Connect to a Flight service on the given host and port."""
cdef:
FlightClient result = FlightClient.__new__(FlightClient)
int c_port = 0
- c_string c_host
-
- if len(args) == 1:
- # Accept namedtuple or plain tuple
- c_host = tobytes(args[0][0])
- c_port = args[0][1]
- elif len(args) == 2:
- # Accept separate host, port
- c_host = tobytes(args[0])
- c_port = args[1]
+ CLocation c_location
+
+ if isinstance(location, Location):
+ c_location = ( location).location
else:
- raise TypeError("FlightClient.connect() takes 1 "
- "or 2 arguments ({} given)".format(len(args)))
+ c_location = CLocation()
+ check_status(CLocation.Parse(tobytes(location), &c_location))
with nogil:
- check_status(CFlightClient.Connect(c_host, c_port, &result.client))
+ check_status(CFlightClient.Connect(c_location, &result.client))
return result
@@ -375,10 +437,12 @@ cdef class FlightClient:
cdef:
unique_ptr[CResultStream] results
Result result
+ CAction c_action = Action.unwrap(action)
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(self.client.get().DoAction(deref(c_options),
- action.action, &results))
+ check_status(
+ self.client.get().DoAction(deref(c_options), c_action,
+ &results))
while True:
result = Result.__new__(Result)
@@ -414,25 +478,23 @@ cdef class FlightClient:
cdef:
FlightInfo result = FlightInfo.__new__(FlightInfo)
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ FlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
with nogil:
check_status(self.client.get().GetFlightInfo(
- deref(c_options), descriptor.descriptor, &result.info))
+ deref(c_options), c_descriptor.descriptor, &result.info))
return result
def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
"""Request the data for a flight."""
cdef:
- # TODO: introduce unwrap
- CTicket c_ticket
unique_ptr[CRecordBatchReader] reader
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
- c_ticket.ticket = ticket.ticket
with nogil:
- check_status(
- self.client.get().DoGet(deref(c_options), c_ticket, &reader))
+ check_status(self.client.get().DoGet(deref(c_options), ticket.ticket, &reader))
result = FlightRecordBatchReader()
result.reader.reset(reader.release())
return result
@@ -444,10 +506,12 @@ cdef class FlightClient:
shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
unique_ptr[CRecordBatchWriter] writer
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ FlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
with nogil:
check_status(self.client.get().DoPut(
- deref(c_options), descriptor.descriptor, c_schema, &writer))
+ deref(c_options), c_descriptor.descriptor, c_schema, &writer))
result = FlightRecordBatchWriter()
result.writer.reset(writer.release())
return result
@@ -921,11 +985,11 @@ cdef class FlightServerBase:
cdef:
unique_ptr[PyFlightServer] server
- def run(self, port, auth_handler=None):
+ def run(self, location, auth_handler=None):
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
- int c_port = port
PyFlightServer* c_server
+ CLocation c_location = Location.unwrap(location)
unique_ptr[CServerAuthHandler] c_auth_handler
if auth_handler:
@@ -945,7 +1009,7 @@ cdef class FlightServerBase:
c_server = new PyFlightServer(self, vtable)
self.server.reset(c_server)
with nogil:
- check_status(c_server.Init(move(c_auth_handler), c_port))
+ check_status(c_server.Init(move(c_auth_handler), c_location))
check_status(c_server.ServeWithSignals())
def list_flights(self, context, criteria):
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index bac7de54c89..fe2075e3675 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -80,9 +80,14 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CLocation" arrow::flight::Location":
CLocation()
+ c_string ToString()
- c_string host
- int32_t port
+ @staticmethod
+ CStatus Parse(c_string& uri_string, CLocation* location)
+ @staticmethod
+ CStatus ForGrpcTcp(c_string& host, int port, CLocation* location)
+ @staticmethod
+ CStatus ForGrpcUnix(c_string& path, CLocation* location)
cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint":
CFlightEndpoint()
@@ -150,7 +155,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CFlightClient" arrow::flight::FlightClient":
@staticmethod
- CStatus Connect(const c_string& host, int port,
+ CStatus Connect(const CLocation& location,
unique_ptr[CFlightClient]* client)
CStatus Authenticate(CFlightCallOptions& options,
@@ -224,7 +229,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef cppclass PyFlightServer:
PyFlightServer(object server, PyFlightServerVtable vtable)
- CStatus Init(unique_ptr[CServerAuthHandler] auth_handler, int port)
+ CStatus Init(unique_ptr[CServerAuthHandler] auth_handler,
+ const CLocation& location)
CStatus ServeWithSignals() except *
void Shutdown()
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 6872b6f77fc..03e7a40fe1f 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -19,6 +19,7 @@
import base64
import contextlib
import socket
+import tempfile
import threading
import time
@@ -206,26 +207,30 @@ def get_token(self):
@contextlib.contextmanager
def flight_server(server_base, *args, **kwargs):
"""Spawn a Flight server on a free port, shutting it down when done."""
- # Find a free port
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- with contextlib.closing(sock) as sock:
- sock.bind(('', 0))
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- port = sock.getsockname()[1]
-
- auth_handler = kwargs.get('auth_handler')
+ auth_handler = kwargs.pop('auth_handler', None)
+ location = kwargs.pop('location', None)
+
+ if location is None:
+ # Find a free port
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ with contextlib.closing(sock) as sock:
+ sock.bind(('', 0))
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ port = sock.getsockname()[1]
+ location = flight.Location.for_grpc_tcp("localhost", port)
+ else:
+ port = None
+
ctor_kwargs = kwargs
- if auth_handler:
- del ctor_kwargs['auth_handler']
server_instance = server_base(*args, **ctor_kwargs)
def _server_thread():
- server_instance.run(port, auth_handler=auth_handler)
+ server_instance.run(location, auth_handler=auth_handler)
thread = threading.Thread(target=_server_thread, daemon=True)
thread.start()
- yield port
+ yield location
server_instance.shutdown()
thread.join()
@@ -244,12 +249,29 @@ def test_flight_do_get_ints():
def test_flight_do_get_dicts():
table = simple_dicts_table()
- with flight_server(ConstantFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(ConstantFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
data = client.do_get(flight.Ticket(b'dicts')).read_all()
assert data.equals(table)
+def test_flight_domain_socket():
+ """Try a simple do_get call over a domain socket."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ with tempfile.NamedTemporaryFile() as sock:
+ sock.close()
+ location = flight.Location.for_grpc_unix(sock.name)
+ with flight_server(ConstantFlightServer,
+ location=location) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ data = client.do_get(flight.Ticket(b'')).read_all()
+ assert data.equals(table)
+
+
@pytest.mark.slow
def test_flight_large_message():
"""Try sending/receiving a large message via Flight.
@@ -261,8 +283,8 @@ def test_flight_large_message():
pa.array(range(0, 10 * 1024 * 1024))
], names=['a'])
- with flight_server(EchoFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(EchoFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
writer = client.do_put(flight.FlightDescriptor.for_path('test'),
data.schema)
# Write a single giant chunk
@@ -278,8 +300,8 @@ def test_flight_generator_stream():
pa.array(range(0, 10 * 1024))
], names=['a'])
- with flight_server(EchoStreamFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(EchoStreamFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
writer = client.do_put(flight.FlightDescriptor.for_path('test'),
data.schema)
writer.write_table(data)
@@ -290,8 +312,8 @@ def test_flight_generator_stream():
def test_flight_invalid_generator_stream():
"""Try streaming data with mismatched schemas."""
- with flight_server(InvalidStreamFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(InvalidStreamFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
with pytest.raises(pa.ArrowException):
client.do_get(flight.Ticket(b'')).read_all()
@@ -300,8 +322,8 @@ def test_timeout_fires():
"""Make sure timeouts fire on slow requests."""
# Do this in a separate thread so that if it fails, we don't hang
# the entire test process
- with flight_server(SlowFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(SlowFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
action = flight.Action("", b"")
options = flight.FlightCallOptions(timeout=0.2)
with pytest.raises(pa.ArrowIOError, match="Deadline Exceeded"):
@@ -310,8 +332,8 @@ def test_timeout_fires():
def test_timeout_passes():
"""Make sure timeouts do not fire on fast requests."""
- with flight_server(ConstantFlightServer) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ with flight_server(ConstantFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
options = flight.FlightCallOptions(timeout=0.2)
client.do_get(flight.Ticket(b'ints'), options=options).read_all()
@@ -328,8 +350,8 @@ def test_timeout_passes():
def test_http_basic_unauth():
"""Test that auth fails when not authenticated."""
with flight_server(EchoStreamFlightServer,
- auth_handler=basic_auth_handler) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ auth_handler=basic_auth_handler) as server_location:
+ client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
list(client.do_action(action))
@@ -338,8 +360,8 @@ def test_http_basic_unauth():
def test_http_basic_auth():
"""Test a Python implementation of HTTP basic authentication."""
with flight_server(EchoStreamFlightServer,
- auth_handler=basic_auth_handler) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ auth_handler=basic_auth_handler) as server_location:
+ client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
identity = next(client.do_action(action))
@@ -349,8 +371,8 @@ def test_http_basic_auth():
def test_http_basic_auth_invalid_password():
"""Test that auth fails with the wrong password."""
with flight_server(EchoStreamFlightServer,
- auth_handler=basic_auth_handler) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ auth_handler=basic_auth_handler) as server_location:
+ client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
client.authenticate(HttpBasicClientAuthHandler('test', 'wrong'))
with pytest.raises(pa.ArrowException, match=".*wrong password.*"):
@@ -360,8 +382,8 @@ def test_http_basic_auth_invalid_password():
def test_token_auth():
"""Test an auth mechanism that uses a handshake."""
with flight_server(EchoStreamFlightServer,
- auth_handler=token_auth_handler) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ auth_handler=token_auth_handler) as server_location:
+ client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
identity = next(client.do_action(action))
@@ -371,7 +393,7 @@ def test_token_auth():
def test_token_auth_invalid():
"""Test an auth mechanism that uses a handshake."""
with flight_server(EchoStreamFlightServer,
- auth_handler=token_auth_handler) as server_port:
- client = flight.FlightClient.connect('localhost', server_port)
+ auth_handler=token_auth_handler) as server_location:
+ client = flight.FlightClient.connect(server_location)
with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
client.authenticate(TokenClientAuthHandler('test', 'wrong'))
From 47441966720bf611ed2482f85fdcdf32a9ca231c Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 2 May 2019 17:31:56 -0400
Subject: [PATCH 2/5] Use builder for Flight server in Java
---
.../org/apache/arrow/flight/FlightClient.java | 75 +++++++++--
.../org/apache/arrow/flight/FlightServer.java | 127 ++++++++++++------
.../flight/example/ExampleFlightServer.java | 2 +-
.../integration/IntegrationTestClient.java | 4 +-
.../apache/arrow/flight/TestBackPressure.java | 10 +-
.../arrow/flight/TestBasicOperation.java | 7 +-
.../apache/arrow/flight/TestCallOptions.java | 7 +-
.../apache/arrow/flight/TestLargeMessage.java | 13 +-
.../apache/arrow/flight/auth/TestAuth.java | 8 +-
.../flight/example/TestExampleServer.java | 2 +-
.../flight/perf/PerformanceTestServer.java | 2 +-
.../apache/arrow/flight/perf/TestPerf.java | 2 +-
12 files changed, 177 insertions(+), 82 deletions(-)
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 178897da3b3..3a6fbae87bb 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -20,7 +20,6 @@
import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
-import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
@@ -50,8 +49,8 @@
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
-import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
+import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
@@ -71,17 +70,9 @@ public class FlightClient implements AutoCloseable {
private final MethodDescriptor doGetDescriptor;
private final MethodDescriptor doPutDescriptor;
- /**
- * Construct client for accessing RouteGuide server using the existing channel.
- */
- public FlightClient(BufferAllocator incomingAllocator, Location location) {
- final ManagedChannelBuilder> channelBuilder =
- ManagedChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort())
- .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
- .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE)
- .usePlaintext();
+ private FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) {
this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
- channel = channelBuilder.build();
+ this.channel = channel;
blockingStub = FlightServiceGrpc.newBlockingStub(channel).withInterceptors(authInterceptor);
asyncStub = FlightServiceGrpc.newStub(channel).withInterceptors(authInterceptor);
doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator);
@@ -324,4 +315,64 @@ public void close() throws InterruptedException {
allocator.close();
}
+ /**
+ * Create a builder for a Flight client.
+ * @param allocator The allocator to use for the client.
+ * @param location The location to connect to.
+ */
+ public static Builder builder(BufferAllocator allocator, Location location) {
+ return new Builder(allocator, location);
+ }
+
+ /**
+ * A builder for Flight clients.
+ */
+ public static final class Builder {
+
+ private final BufferAllocator allocator;
+ private final Location location;
+ private boolean forceTls = false;
+
+ private Builder(BufferAllocator allocator, Location location) {
+ this.allocator = allocator;
+ this.location = location;
+ }
+
+ /**
+ * Force the client to connect over TLS.
+ */
+ public Builder useTls() {
+ this.forceTls = true;
+ return this;
+ }
+
+ /**
+ * Create the client from this builder.
+ */
+ public FlightClient build() {
+ final NettyChannelBuilder builder;
+
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE:
+ case LocationSchemes.GRPC_TLS: {
+ builder = NettyChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort());
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
+ }
+
+ if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
+ builder.useTransportSecurity();
+ } else {
+ builder.usePlaintext();
+ }
+
+ builder
+ .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
+ .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE);
+ return new FlightClient(allocator, builder.build());
+ }
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index 58afb5bafe8..8c9ae1c2e4e 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -17,8 +17,13 @@
package org.apache.arrow.flight;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.InputStream;
import java.net.InetSocketAddress;
+import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
@@ -39,41 +44,9 @@ public class FlightServer implements AutoCloseable {
/** The maximum size of an individual gRPC message. This effectively disables the limit. */
static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;
- /**
- * Constructs a new instance.
- *
- * @param allocator The allocator to use for storing/copying arrow data.
- * @param location The location to serve on.
- * @param producer The underlying business logic for the server.
- * @param authHandler The authorization handler for the server.
- */
- public FlightServer(
- BufferAllocator allocator,
- Location location,
- FlightProducer producer,
- ServerAuthHandler authHandler) {
- final NettyServerBuilder builder;
- switch (location.getUri().getScheme()) {
- case LocationSchemes.GRPC_DOMAIN_SOCKET: {
- // TODO: need reflection to check if domain sockets are available
- throw new UnsupportedOperationException("Domain sockets are not available.");
- }
- case LocationSchemes.GRPC:
- case LocationSchemes.GRPC_INSECURE: {
- builder = NettyServerBuilder
- .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
- break;
- }
- default:
- throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
- }
- this.server = builder
- .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE)
- .addService(
- ServerInterceptors.intercept(
- new FlightBindingService(allocator, producer, authHandler),
- new ServerAuthInterceptor(authHandler)))
- .build();
+ /** Create a new instance from a gRPC server. For internal use only. */
+ private FlightServer(Server server) {
+ this.server = server;
}
/** Start the server. */
@@ -115,20 +88,86 @@ public void close() throws InterruptedException {
}
}
- public interface OutputFlight {
- void sendData(int count);
+ public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) {
+ return new Builder(allocator, location, producer);
+ }
- void done();
+ public static final class Builder {
- void fail(Throwable t);
- }
+ private final BufferAllocator allocator;
+ private final Location location;
+ private final FlightProducer producer;
+ private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
- public interface FlightServerHandler {
+ private InputStream certChain;
+ private InputStream key;
- public FlightInfo getFlightInfo(String descriptor) throws Exception;
+ Builder(BufferAllocator allocator, Location location, FlightProducer producer) {
+ this.allocator = allocator;
+ this.location = location;
+ this.producer = producer;
+ }
- public OutputFlight setupFlight(VectorSchemaRoot root);
+ public FlightServer build() {
+ final NettyServerBuilder builder;
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ // TODO: need reflection to check if domain sockets are available
+ throw new UnsupportedOperationException("Domain sockets are not available.");
+ }
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE: {
+ builder = NettyServerBuilder
+ .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
+ break;
+ }
+ case LocationSchemes.GRPC_TLS: {
+ if (certChain == null) {
+ throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS");
+ }
+ builder = NettyServerBuilder
+ .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
+ }
- }
+ if (certChain != null) {
+ builder.useTransportSecurity(certChain, key);
+ }
+ final Server server = builder
+ .executor(new ForkJoinPool())
+ .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE)
+ .addService(
+ ServerInterceptors.intercept(
+ new FlightBindingService(allocator, producer, authHandler),
+ new ServerAuthInterceptor(authHandler)))
+ .build();
+ return new FlightServer(server);
+ }
+
+ /**
+ * Enable TLS on the server.
+ * @param certChain The certificate chain to use.
+ * @param key The private key to use.
+ */
+ public Builder useTls(final File certChain, final File key) throws IOException {
+ this.certChain = new FileInputStream(certChain);
+ this.key = new FileInputStream(key);
+ return this;
+ }
+
+ public Builder useTls(final InputStream certChain, final InputStream key) {
+ this.certChain = certChain;
+ this.key = key;
+ return this;
+ }
+
+ public Builder setAuthHandler(ServerAuthHandler authHandler) {
+ this.authHandler = authHandler;
+ return this;
+ }
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
index 2d71b5d490b..a08f74faaa9 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
@@ -46,7 +46,7 @@ public ExampleFlightServer(BufferAllocator allocator, Location location) {
this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE);
this.location = location;
this.mem = new InMemoryStore(this.allocator, location);
- this.flightServer = new FlightServer(allocator, location, mem, ServerAuthHandler.NO_OP);
+ this.flightServer = FlightServer.builder(allocator, location, mem).build();
}
public Location getLocation() {
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index 1e9f716f184..6222093c344 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -82,7 +82,7 @@ private void run(String[] args) throws ParseException, IOException {
final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final Location defaultLocation = Location.forGrpcInsecure(host, port);
- final FlightClient client = new FlightClient(allocator, defaultLocation);
+ final FlightClient client = FlightClient.builder(allocator, defaultLocation).build();
final String inputPath = cmd.getOptionValue("j");
@@ -120,7 +120,7 @@ private void run(String[] args) throws ParseException, IOException {
}
for (Location location : locations) {
System.out.println("Verifying location " + location.getUri());
- FlightClient readClient = new FlightClient(allocator, location);
+ FlightClient readClient = FlightClient.builder(allocator, location).build();
FlightStream stream = readClient.getStream(endpoint.getTicket());
VectorSchemaRoot downloadedRoot;
try (VectorSchemaRoot root = stream.getRoot()) {
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
index e5b8d980f7e..b5dde67d5b4 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
@@ -48,7 +48,7 @@ public void ensureIndependentSteams() throws Exception {
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final PerformanceTestServer server = FlightTestUtil.getStartedServer(
(port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))));
- final FlightClient client = new FlightClient(a, server.getLocation())
+ final FlightClient client = FlightClient.builder(a, server.getLocation()).build()
) {
FlightStream fs1 = client.getStream(client.getInfo(
TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1))
@@ -123,11 +123,13 @@ public void getStream(CallContext context, Ticket ticket,
BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
FlightServer server =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(serverAllocator, Location.forGrpcInsecure("localhost", port), producer,
- ServerAuthHandler.NO_OP));
+ (port) -> FlightServer.builder(serverAllocator, Location.forGrpcInsecure("localhost", port), producer)
+ .build());
BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
FlightClient client =
- new FlightClient(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
+ FlightClient
+ .builder(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
+ .build()
) {
FlightStream stream = client.getStream(new Ticket(new byte[1]));
VectorSchemaRoot root = stream.getRoot();
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index a02cd764fc2..ab31eb8df92 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -140,11 +140,12 @@ private void test(BiConsumer consumer) throws Exc
Producer producer = new Producer(a);
FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
- ServerAuthHandler.NO_OP))) {
+ (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build()
+ )) {
try (
- FlightClient c = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()));
+ FlightClient c = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))
+ .build()
) {
try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
consumer.accept(c, testAllocator);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
index 3af36cbe5c4..9a24c66dceb 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
@@ -71,9 +71,10 @@ void test(Consumer testFn) {
Producer producer = new Producer(a);
FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer,
- ServerAuthHandler.NO_OP));
- FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) {
+ (port) -> FlightServer.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer)
+ .build());
+ FlightClient client = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))
+ .build()) {
testFn.accept(client);
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
index 2e69c0c9219..48c4a271be4 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
@@ -45,10 +45,10 @@ public void getLargeMessage() throws Exception {
final Producer producer = new Producer(a);
final FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
- ServerAuthHandler.NO_OP))) {
+ (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build())) {
- try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) {
+ try (FlightClient client = FlightClient
+ .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build()) {
FlightStream stream = client.getStream(new Ticket(new byte[]{}));
try (VectorSchemaRoot root = stream.getRoot()) {
while (stream.next()) {
@@ -76,10 +76,11 @@ public void putLargeMessage() throws Exception {
final Producer producer = new Producer(a);
final FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer,
- ServerAuthHandler.NO_OP))) {
+ (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build()
+ )) {
- try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()));
+ try (FlightClient client = FlightClient
+ .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build();
BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE);
VectorSchemaRoot root = generateData(testAllocator)) {
final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
index 4dfe669d5be..df16a11938a 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
@@ -119,7 +119,7 @@ public byte[] getToken(String username, String password) {
}
};
- server = FlightTestUtil.getStartedServer((port) -> new FlightServer(
+ server = FlightTestUtil.getStartedServer((port) -> FlightServer.builder(
allocator,
Location.forGrpcInsecure("localhost", port),
new NoOpFlightProducer() {
@@ -149,9 +149,9 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
root.clear();
listener.completed();
}
- },
- new BasicServerAuthHandler(validator)));
- client = new FlightClient(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()));
+ }).setAuthHandler(new BasicServerAuthHandler(validator)).build());
+ client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
+ .build();
}
@After
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
index 20e04167771..a580a6e1717 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
@@ -59,7 +59,7 @@ public void start() throws IOException {
} else {
System.out.println("Skipping server startup.");
}
- client = new FlightClient(allocator, l);
+ client = FlightClient.builder(allocator, l).build();
caseAllocator = allocator.newChildAllocator("test-case", 0, Long.MAX_VALUE);
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
index 5d5e600ad29..bc6d202d3bd 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
@@ -64,7 +64,7 @@ public PerformanceTestServer(BufferAllocator incomingAllocator, Location locatio
this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE);
this.location = location;
this.producer = new PerfProducer();
- this.flightServer = new FlightServer(this.allocator, location, producer, ServerAuthHandler.NO_OP);
+ this.flightServer = FlightServer.builder(this.allocator, location, producer).build();
}
public Location getLocation() {
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
index d40773ef929..a9b9d60b1b9 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
@@ -83,7 +83,7 @@ public void throughput() throws Exception {
final PerformanceTestServer server =
FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a,
Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)));
- final FlightClient client = new FlightClient(a, server.getLocation());
+ final FlightClient client = FlightClient.builder(a, server.getLocation()).build();
) {
final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));
ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4));
From 675acc9bc2051c804d0666f6f698f476ca87e463 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 3 May 2019 09:58:52 -0400
Subject: [PATCH 3/5] Introduce builder for C++ Flight servers
---
cpp/src/arrow/flight/client.cc | 16 +++--
cpp/src/arrow/flight/client.h | 14 +++++
cpp/src/arrow/flight/flight-test.cc | 36 ++++++-----
cpp/src/arrow/flight/perf-server.cc | 5 +-
cpp/src/arrow/flight/server.cc | 25 ++++++--
cpp/src/arrow/flight/server.h | 19 ++++--
.../arrow/flight/test-integration-server.cc | 4 +-
cpp/src/arrow/flight/test-server.cc | 6 +-
cpp/src/arrow/flight/test-util.cc | 9 +--
cpp/src/arrow/flight/test-util.h | 10 +--
cpp/src/arrow/flight/types.cc | 2 +-
cpp/src/arrow/util/uri.cc | 2 +-
cpp/src/arrow/util/uri.h | 3 +-
python/examples/flight/client.py | 12 +++-
python/examples/flight/server.py | 22 ++++++-
python/pyarrow/_flight.pyx | 62 +++++++++++--------
python/pyarrow/includes/libarrow_flight.pxd | 15 ++++-
python/pyarrow/tests/test_flight.py | 10 +++
18 files changed, 190 insertions(+), 82 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 4a7beded950..d8316475550 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -231,17 +231,20 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
class FlightClient::FlightClientImpl {
public:
- Status Connect(const Location& location) {
+ Status Connect(const Location& location, const FlightClientOptions& options) {
const std::string& scheme = location.scheme();
std::stringstream grpc_uri;
std::shared_ptr creds;
if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
- // TODO(wesm): Support other kinds of GRPC ChannelCredentials
grpc_uri << location.uri_->host() << ":" << location.uri_->port_text();
if (scheme == "grpc+tls") {
- creds = grpc::SslCredentials(grpc::SslCredentialsOptions());
+ grpc::SslCredentialsOptions ssl_options;
+ if (!options.tls_root_certs.empty()) {
+ ssl_options.pem_root_certs = options.tls_root_certs;
+ }
+ creds = grpc::SslCredentials(ssl_options);
} else {
creds = grpc::InsecureChannelCredentials();
}
@@ -401,8 +404,13 @@ FlightClient::~FlightClient() {}
Status FlightClient::Connect(const Location& location,
std::unique_ptr* client) {
+ return Connect(location, {}, client);
+}
+
+Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
+ std::unique_ptr* client) {
client->reset(new FlightClient);
- return (*client)->impl_->Connect(location);
+ return (*client)->impl_->Connect(location, options);
}
Status FlightClient::Authenticate(const FlightCallOptions& options,
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index f48f3edf839..276ffc70212 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -57,6 +57,11 @@ class ARROW_EXPORT FlightCallOptions {
TimeoutDuration timeout;
};
+class ARROW_EXPORT FlightClientOptions {
+ public:
+ std::string tls_root_certs;
+};
+
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_EXPORT FlightClient {
@@ -70,6 +75,15 @@ class ARROW_EXPORT FlightClient {
/// successful
static Status Connect(const Location& location, std::unique_ptr* client);
+ /// \brief Connect to an unauthenticated flight service
+ /// \param[in] location the URI
+ /// \param[in] options Other options for setting up the client
+ /// \param[out] client the created FlightClient
+ /// \return Status OK status may not indicate that the connection was
+ /// successful
+ static Status Connect(const Location& location, const FlightClientOptions& options,
+ std::unique_ptr* client);
+
/// \brief Authenticate to the server using the given handler.
/// \param[in] options Per-RPC options
/// \param[in] auth_handler The authentication mechanism to use
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index d99e7c4b020..e3e01df5037 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -255,24 +255,25 @@ class DoPutTestServer : public FlightServerBase {
class TestAuthHandler : public ::testing::Test {
public:
void SetUp() {
- port_ = 30000;
- server_.reset(new InProcessTestServer(
- std::unique_ptr(new AuthTestServer), port_));
- ASSERT_OK(server_->Start(std::unique_ptr(
- new TestServerAuthHandler("user", "p4ssw0rd"))));
+ Location location;
+ std::unique_ptr server(new AuthTestServer);
+
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));
+ FlightServerOptions options(location);
+ options.auth_handler =
+ std::unique_ptr(new TestServerAuthHandler("user", "p4ssw0rd"));
+ ASSERT_OK(server->Init(options));
+
+ server_.reset(new InProcessTestServer(std::move(server), location));
+ ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}
void TearDown() { server_->Stop(); }
- Status ConnectClient() {
- Location location;
- RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
- return FlightClient::Connect(location, &client_);
- }
+ Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
protected:
- int port_;
std::unique_ptr client_;
std::unique_ptr server_;
};
@@ -280,17 +281,21 @@ class TestAuthHandler : public ::testing::Test {
class TestDoPut : public ::testing::Test {
public:
void SetUp() {
- port_ = 30000;
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));
+
do_put_server_ = new DoPutTestServer();
server_.reset(new InProcessTestServer(
- std::unique_ptr(do_put_server_), port_));
- ASSERT_OK(server_->Start({}));
+ std::unique_ptr(do_put_server_), location));
+ FlightServerOptions options(location);
+ ASSERT_OK(do_put_server_->Init(options));
+ ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}
void TearDown() { server_->Stop(); }
- Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
+ Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
void CheckBatches(FlightDescriptor expected_descriptor,
const BatchVector& expected_batches) {
@@ -314,7 +319,6 @@ class TestDoPut : public ::testing::Test {
}
protected:
- int port_;
std::unique_ptr client_;
std::unique_ptr server_;
DoPutTestServer* do_put_server_;
diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc
index f65cbd15abf..d1131175422 100644
--- a/cpp/src/arrow/flight/perf-server.cc
+++ b/cpp/src/arrow/flight/perf-server.cc
@@ -207,8 +207,9 @@ int main(int argc, char** argv) {
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
- ARROW_CHECK_OK(
- g_server->Init(std::unique_ptr(), location));
+ arrow::flight::FlightServerOptions options(location);
+
+ ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server port: " << FLAGS_port << std::endl;
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index d8c5525ee98..46dd5bf0119 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -21,6 +21,7 @@
#include
#include
#include
+#include
#include
#include
@@ -438,24 +439,38 @@ void FlightServerBase::Impl::HandleSignal(int signum) {
}
}
+FlightServerOptions::FlightServerOptions(const Location& location_)
+ : location(location_), auth_handler(nullptr) {}
+
FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
FlightServerBase::~FlightServerBase() {}
-Status FlightServerBase::Init(std::unique_ptr auth_handler,
- const Location& location) {
- std::shared_ptr handler = std::move(auth_handler);
+Status FlightServerBase::Init(FlightServerOptions& options) {
+ std::shared_ptr handler = std::move(options.auth_handler);
impl_->service_.reset(new FlightServiceImpl(handler, this));
grpc::ServerBuilder builder;
// Allow uploading messages of any length
builder.SetMaxReceiveMessageSize(-1);
+ const Location& location = options.location;
const std::string scheme = location.scheme();
- if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) {
+ if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
std::stringstream address;
address << location.uri_->host() << ':' << location.uri_->port_text();
- builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
+
+ std::shared_ptr creds;
+ if (scheme == kSchemeGrpcTls) {
+ grpc::SslServerCredentialsOptions ssl_options;
+ ssl_options.pem_key_cert_pairs.push_back(
+ {options.tls_private_key, options.tls_cert_chain});
+ creds = grpc::SslServerCredentials(ssl_options);
+ } else {
+ creds = grpc::InsecureServerCredentials();
+ }
+
+ builder.AddListeningPort(address.str(), creds);
} else if (scheme == kSchemeGrpcUnix) {
std::stringstream address;
address << "unix:" << location.uri_->path();
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 752b89e85bf..28e87aa2fa9 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -24,6 +24,7 @@
#include
#include
+#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h" // IWYU pragma: keep
#include "arrow/ipc/dictionary.h"
#include "arrow/memory_pool.h"
@@ -38,8 +39,6 @@ class Status;
namespace flight {
-class ServerAuthHandler;
-
/// \brief Interface that produces a sequence of IPC payloads to be sent in
/// FlightData protobuf messages
class ARROW_EXPORT FlightDataStream {
@@ -89,6 +88,16 @@ class ARROW_EXPORT ServerCallContext {
virtual const std::string& peer_identity() const = 0;
};
+class ARROW_EXPORT FlightServerOptions {
+ public:
+ explicit FlightServerOptions(const Location& location_);
+
+ Location location;
+ std::unique_ptr auth_handler;
+ std::string tls_cert_chain;
+ std::string tls_private_key;
+};
+
/// \brief Skeleton RPC server implementation which can be used to create
/// custom servers by implementing its abstract methods
class ARROW_EXPORT FlightServerBase {
@@ -100,10 +109,8 @@ class ARROW_EXPORT FlightServerBase {
/// \brief Initialize a Flight server listening at the given location.
/// This method must be called before any other method.
- /// \param[in] auth_handler The authentication handler. May be
- /// nullptr if no authentication is desired.
- /// \param[in] location The location to serve on.
- Status Init(std::unique_ptr auth_handler, const Location& location);
+ /// \param[in] options The configuration for this server.
+ Status Init(FlightServerOptions& options);
/// \brief Set the server to stop when receiving any of the given signal
/// numbers.
diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc
index d72e838213b..c5bb180663b 100644
--- a/cpp/src/arrow/flight/test-integration-server.cc
+++ b/cpp/src/arrow/flight/test-integration-server.cc
@@ -126,7 +126,9 @@ int main(int argc, char** argv) {
g_server.reset(new arrow::flight::FlightIntegrationTestServer);
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
- ARROW_CHECK_OK(g_server->Init(nullptr, location));
+ arrow::flight::FlightServerOptions options(location);
+
+ ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc
index 29af87db601..f72fd3caeea 100644
--- a/cpp/src/arrow/flight/test-server.cc
+++ b/cpp/src/arrow/flight/test-server.cc
@@ -139,10 +139,12 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
g_server.reset(new arrow::flight::FlightTestServer);
+
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
- ARROW_CHECK_OK(
- g_server->Init(std::unique_ptr(), location));
+ arrow::flight::FlightServerOptions options(location);
+
+ ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index 01577958640..f870dcbfa71 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -118,22 +118,17 @@ bool TestServer::IsRunning() { return server_process_->running(); }
int TestServer::port() const { return port_; }
-Status InProcessTestServer::Start(std::unique_ptr auth_handler) {
- Location location;
- RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
- RETURN_NOT_OK(server_->Init(std::move(auth_handler), location));
+Status InProcessTestServer::Start() {
thread_ = std::thread([this]() { ARROW_EXPECT_OK(server_->Serve()); });
return Status::OK();
}
-Status InProcessTestServer::Start() { return Start({}); }
-
void InProcessTestServer::Stop() {
server_->Shutdown();
thread_.join();
}
-int InProcessTestServer::port() const { return port_; }
+const Location& InProcessTestServer::location() const { return location_; }
InProcessTestServer::~InProcessTestServer() {
// Make sure server shuts down properly
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index ef18e3721b6..b5bc31a4c3b 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -63,17 +63,17 @@ class ARROW_EXPORT TestServer {
class ARROW_EXPORT InProcessTestServer {
public:
- explicit InProcessTestServer(std::unique_ptr server, int port)
- : server_(std::move(server)), port_(port), thread_() {}
+ explicit InProcessTestServer(std::unique_ptr server,
+ const Location& location)
+ : server_(std::move(server)), location_(location), thread_() {}
~InProcessTestServer();
Status Start();
- Status Start(std::unique_ptr auth_handler);
void Stop();
- int port() const;
+ const Location& location() const;
private:
std::unique_ptr server_;
- int port_;
+ Location location_;
std::thread thread_;
};
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 985cb5fc03a..524b0ee716e 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -102,7 +102,7 @@ Status Location::ForGrpcUnix(const std::string& path, Location* location) {
return Location::Parse(uri_string.str(), location);
}
-std::string Location::ToString() const { return uri_->to_string(); }
+std::string Location::ToString() const { return uri_->ToString(); }
std::string Location::scheme() const {
std::string scheme = uri_->scheme();
if (scheme.empty()) {
diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc
index e79c9826b9b..579d4c7f243 100644
--- a/cpp/src/arrow/util/uri.cc
+++ b/cpp/src/arrow/util/uri.cc
@@ -120,7 +120,7 @@ std::string Uri::path() const {
return ss.str();
}
-const std::string& Uri::to_string() const { return impl_->string_rep_; }
+const std::string& Uri::ToString() const { return impl_->string_rep_; }
Status Uri::Parse(const std::string& uri_string) {
impl_->Reset();
diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h
index 7327c2d7876..ce082ccc8e6 100644
--- a/cpp/src/arrow/util/uri.h
+++ b/cpp/src/arrow/util/uri.h
@@ -54,8 +54,9 @@ class ARROW_EXPORT Uri {
int32_t port() const;
/// The URI path component.
std::string path() const;
+
/// Get the string representation of this URI.
- const std::string& to_string() const;
+ const std::string& ToString() const;
/// Factory function to parse a URI from its string representation.
Status Parse(const std::string& uri_string);
diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py
index 0c91e3ec55f..15db5124c42 100644
--- a/python/examples/flight/client.py
+++ b/python/examples/flight/client.py
@@ -90,6 +90,8 @@ def get_flight(args, client):
def _add_common_arguments(parser):
+ parser.add_argument('--tls', action='store_true')
+ parser.add_argument('--tls-roots', default=None)
parser.add_argument('host', type=str,
help="The host to connect to.")
@@ -131,7 +133,15 @@ def main():
}
host, port = args.host.split(':')
port = int(port)
- client = pyarrow.flight.FlightClient.connect(f"grpc://{host}:{port}")
+ scheme = "grpc+tcp"
+ connection_args = {}
+ if args.tls:
+ scheme = "grpc+tls"
+ if args.tls_roots:
+ with open(args.tls_roots, "rb") as root_certs:
+ connection_args["tls_root_certs"] = root_certs.read()
+ client = pyarrow.flight.FlightClient.connect(f"{scheme}://{host}:{port}",
+ **connection_args)
while True:
try:
action = pyarrow.flight.Action("healthcheck", b"")
diff --git a/python/examples/flight/server.py b/python/examples/flight/server.py
index e6edb8d3a1b..5a3d101c3a9 100644
--- a/python/examples/flight/server.py
+++ b/python/examples/flight/server.py
@@ -17,6 +17,7 @@
"""An example Flight Python server."""
+import argparse
import ast
import threading
import time
@@ -76,7 +77,7 @@ def do_get(self, context, ticket):
return None
return pyarrow.flight.RecordBatchStream(self.flights[key])
- def list_actions(self, context):
+ def list_actions(self):
return [
("clear", "Clear the stored flights."),
("shutdown", "Shut down this server."),
@@ -100,8 +101,25 @@ def _shutdown(self):
def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", type=int, default=5005)
+ parser.add_argument("--tls", nargs=2, default=None)
+
+ args = parser.parse_args()
+
server = FlightServer()
- server.run(5005)
+ kwargs = {}
+ scheme = "grpc+tcp"
+ if args.tls:
+ scheme = "grpc+tls"
+ with open(args.tls[0], "rb") as cert_file:
+ kwargs["tls_cert_chain"] = cert_file.read()
+ with open(args.tls[1], "rb") as key_file:
+ kwargs["tls_private_key"] = key_file.read()
+
+ location = "{}://0.0.0.0:{}".format(scheme, args.port)
+ print("Serving on", location)
+ server.run(location, **kwargs)
if __name__ == '__main__':
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 05d50cd629f..430f2b2abc5 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -72,7 +72,7 @@ cdef class Action:
return pyarrow_wrap_buffer(self.action.body)
@staticmethod
- cdef CAction unwrap(action):
+ cdef CAction unwrap(action) except *:
if not isinstance(action, Action):
raise TypeError("Must provide Action, not '{}'".format(
type(action)))
@@ -168,11 +168,11 @@ cdef class FlightDescriptor:
return "".format(self.descriptor_type)
@staticmethod
- cdef FlightDescriptor unwrap(descriptor):
+ cdef CFlightDescriptor unwrap(descriptor) except *:
if not isinstance(descriptor, FlightDescriptor):
raise TypeError("Must provide a FlightDescriptor, not '{}'".format(
type(descriptor)))
- return descriptor
+ return ( descriptor).descriptor
cdef class Ticket:
@@ -225,9 +225,9 @@ cdef class Location:
return result
@staticmethod
- cdef CLocation unwrap(object location):
+ cdef CLocation unwrap(object location) except *:
cdef CLocation c_location
- if isinstance(location, (six.text_type, six.binary_type)):
+ if isinstance(location, six.text_type):
check_status(CLocation.Parse(tobytes(location), &c_location))
return c_location
elif not isinstance(location, Location):
@@ -381,21 +381,20 @@ cdef class FlightClient:
.format(self.__class__.__name__))
@staticmethod
- def connect(location):
+ def connect(location, **kwargs):
"""Connect to a Flight service on the given host and port."""
cdef:
FlightClient result = FlightClient.__new__(FlightClient)
int c_port = 0
- CLocation c_location
+ CLocation c_location = Location.unwrap(location)
+ CFlightClientOptions c_options
- if isinstance(location, Location):
- c_location = ( location).location
- else:
- c_location = CLocation()
- check_status(CLocation.Parse(tobytes(location), &c_location))
+ if "tls_root_certs" in kwargs:
+ c_options.tls_root_certs = tobytes(kwargs["tls_root_certs"])
with nogil:
- check_status(CFlightClient.Connect(c_location, &result.client))
+ check_status(CFlightClient.Connect(c_location, c_options,
+ &result.client))
return result
@@ -478,12 +477,12 @@ cdef class FlightClient:
cdef:
FlightInfo result = FlightInfo.__new__(FlightInfo)
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
- FlightDescriptor c_descriptor = \
+ CFlightDescriptor c_descriptor = \
FlightDescriptor.unwrap(descriptor)
with nogil:
check_status(self.client.get().GetFlightInfo(
- deref(c_options), c_descriptor.descriptor, &result.info))
+ deref(c_options), c_descriptor, &result.info))
return result
@@ -494,7 +493,8 @@ cdef class FlightClient:
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(self.client.get().DoGet(deref(c_options), ticket.ticket, &reader))
+ check_status(self.client.get().DoGet(
+ deref(c_options), ticket.ticket, &reader))
result = FlightRecordBatchReader()
result.reader.reset(reader.release())
return result
@@ -506,12 +506,12 @@ cdef class FlightClient:
shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
unique_ptr[CRecordBatchWriter] writer
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
- FlightDescriptor c_descriptor = \
+ CFlightDescriptor c_descriptor = \
FlightDescriptor.unwrap(descriptor)
with nogil:
check_status(self.client.get().DoPut(
- deref(c_options), c_descriptor.descriptor, c_schema, &writer))
+ deref(c_options), c_descriptor, c_schema, &writer))
result = FlightRecordBatchWriter()
result.writer.reset(writer.release())
return result
@@ -520,7 +520,7 @@ cdef class FlightClient:
cdef class FlightDataStream:
"""Abstract base class for Flight data streams."""
- cdef CFlightDataStream* to_stream(self):
+ cdef CFlightDataStream* to_stream(self) except *:
"""Create the C++ data stream for the backing Python object.
We don't expose the C++ object to Python, so we can manage its
@@ -547,7 +547,7 @@ cdef class RecordBatchStream(FlightDataStream):
"but got: {}".format(type(data_source)))
self.data_source = data_source
- cdef CFlightDataStream* to_stream(self):
+ cdef CFlightDataStream* to_stream(self) except *:
cdef:
shared_ptr[CRecordBatchReader] reader
if isinstance(self.data_source, _CRecordBatchReader):
@@ -585,7 +585,7 @@ cdef class GeneratorStream(FlightDataStream):
self.schema = pyarrow_unwrap_schema(schema)
self.generator = iter(generator)
- cdef CFlightDataStream* to_stream(self):
+ cdef CFlightDataStream* to_stream(self) except *:
cdef:
function[cb_data_stream_next] callback = &_data_stream_next
return new CPyGeneratorFlightDataStream(self, self.schema, callback)
@@ -985,20 +985,30 @@ cdef class FlightServerBase:
cdef:
unique_ptr[PyFlightServer] server
- def run(self, location, auth_handler=None):
+ def run(self, location, auth_handler=None, **kwargs):
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
PyFlightServer* c_server
- CLocation c_location = Location.unwrap(location)
- unique_ptr[CServerAuthHandler] c_auth_handler
+ unique_ptr[CFlightServerOptions] c_options
+
+ c_options.reset(new CFlightServerOptions(Location.unwrap(location)))
if auth_handler:
if not isinstance(auth_handler, ServerAuthHandler):
raise TypeError("auth_handler must be a ServerAuthHandler, "
"not a '{}'".format(type(auth_handler)))
- c_auth_handler.reset(
+ c_options.get().auth_handler.reset(
( auth_handler).to_handler())
+ if "tls_cert_chain" in kwargs:
+ if "tls_private_key" not in kwargs:
+ raise ValueError(
+ "Must provide both cert chain and private key")
+ c_options.get().tls_cert_chain = tobytes(
+ kwargs["tls_cert_chain"])
+ c_options.get().tls_private_key = tobytes(
+ kwargs["tls_private_key"])
+
vtable.list_flights = &_list_flights
vtable.get_flight_info = &_get_flight_info
vtable.do_put = &_do_put
@@ -1009,7 +1019,7 @@ cdef class FlightServerBase:
c_server = new PyFlightServer(self, vtable)
self.server.reset(c_server)
with nogil:
- check_status(c_server.Init(move(c_auth_handler), c_location))
+ check_status(c_server.Init(deref(c_options)))
check_status(c_server.ServeWithSignals())
def list_flights(self, context, criteria):
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index fe2075e3675..d3ff4f259b0 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -153,9 +153,21 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CFlightCallOptions()
CTimeoutDuration timeout
+ cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
+ CFlightServerOptions(const CLocation& location)
+ CLocation location
+ unique_ptr[CServerAuthHandler] auth_handler
+ c_string tls_cert_chain
+ c_string tls_private_key
+
+ cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
+ CFlightClientOptions()
+ c_string tls_root_certs
+
cdef cppclass CFlightClient" arrow::flight::FlightClient":
@staticmethod
CStatus Connect(const CLocation& location,
+ const CFlightClientOptions& options,
unique_ptr[CFlightClient]* client)
CStatus Authenticate(CFlightCallOptions& options,
@@ -229,8 +241,7 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef cppclass PyFlightServer:
PyFlightServer(object server, PyFlightServerVtable vtable)
- CStatus Init(unique_ptr[CServerAuthHandler] auth_handler,
- const CLocation& location)
+ CStatus Init(CFlightServerOptions& options)
CStatus ServeWithSignals() except *
void Shutdown()
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 03e7a40fe1f..acdaacfaafe 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -397,3 +397,13 @@ def test_token_auth_invalid():
client = flight.FlightClient.connect(server_location)
with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
client.authenticate(TokenClientAuthHandler('test', 'wrong'))
+
+
+def test_location_invalid():
+ """Test constructing invalid URIs."""
+ with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"):
+ flight.FlightClient.connect("%")
+
+ server = ConstantFlightServer()
+ with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"):
+ server.run("%")
From 5c127631b41986afa5a54d55512d677f95ecf39c Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 13 May 2019 15:57:43 -0400
Subject: [PATCH 4/5] Make Python Flight bindings more complete
---
cpp/src/arrow/flight/types.cc | 4 ++
cpp/src/arrow/flight/types.h | 4 +-
cpp/src/arrow/python/flight.cc | 2 +-
cpp/src/arrow/python/flight.h | 2 +-
python/pyarrow/_flight.pyx | 23 +++++++-
python/pyarrow/includes/libarrow_flight.pxd | 9 +--
python/pyarrow/tests/test_flight.py | 64 +++++++++++++++++++++
7 files changed, 99 insertions(+), 9 deletions(-)
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 524b0ee716e..dadb51066cf 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -112,6 +112,10 @@ std::string Location::scheme() const {
return scheme;
}
+bool Location::Equals(const Location& other) const {
+ return ToString() == other.ToString();
+}
+
SimpleFlightListing::SimpleFlightListing(const std::vector& flights)
: position_(0), flights_(flights) {}
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 9caee997ae1..ba07b999635 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -157,8 +157,10 @@ struct Location {
/// \brief Get the scheme of this URI.
std::string scheme() const;
+ bool Equals(const Location& other) const;
+
friend bool operator==(const Location& left, const Location& right) {
- return left.ToString() == right.ToString();
+ return left.Equals(right);
}
friend bool operator!=(const Location& left, const Location& right) {
return !(left == right);
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index b033b341ff2..4db31570d81 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -217,7 +217,7 @@ Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) {
Status CreateFlightInfo(const std::shared_ptr& schema,
const arrow::flight::FlightDescriptor& descriptor,
const std::vector& endpoints,
- uint64_t total_records, uint64_t total_bytes,
+ int64_t total_records, int64_t total_bytes,
std::unique_ptr* out) {
arrow::flight::FlightInfo::Data flight_data;
RETURN_NOT_OK(arrow::flight::internal::SchemaToString(*schema, &flight_data.schema));
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index 19fbb02c592..432885cb764 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -197,7 +197,7 @@ ARROW_PYTHON_EXPORT
Status CreateFlightInfo(const std::shared_ptr& schema,
const arrow::flight::FlightDescriptor& descriptor,
const std::vector& endpoints,
- uint64_t total_records, uint64_t total_bytes,
+ int64_t total_records, int64_t total_bytes,
std::unique_ptr* out);
} // namespace flight
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 430f2b2abc5..82745d34c8c 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -184,6 +184,10 @@ cdef class Ticket:
def __init__(self, ticket):
self.ticket.ticket = tobytes(ticket)
+ @property
+ def ticket(self):
+ return self.ticket.ticket
+
def __repr__(self):
return ''.format(self.ticket.ticket)
@@ -199,6 +203,18 @@ cdef class Location:
def __repr__(self):
return ''.format(self.location.ToString())
+ @property
+ def uri(self):
+ return self.location.ToString()
+
+ def equals(self, Location other):
+ return self == other
+
+ def __eq__(self, other):
+ if not isinstance(other, Location):
+ return NotImplemented
+ return self.location.Equals(( other).location)
+
@staticmethod
def for_grpc_tcp(host, port):
"""Create a Location for a TCP-based gRPC service."""
@@ -265,8 +281,11 @@ cdef class FlightEndpoint:
self.endpoint.ticket.ticket = tobytes(ticket)
for location in locations:
- c_location = CLocation()
- check_status(CLocation.Parse(tobytes(location), &c_location))
+ if isinstance(location, Location):
+ c_location = ( location).location
+ else:
+ c_location = CLocation()
+ check_status(CLocation.Parse(tobytes(location), &c_location))
self.endpoint.locations.push_back(c_location)
@property
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index d3ff4f259b0..4b749903d3d 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -81,6 +81,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CLocation" arrow::flight::Location":
CLocation()
c_string ToString()
+ c_bool Equals(const CLocation& other)
@staticmethod
CStatus Parse(c_string& uri_string, CLocation* location)
@@ -97,8 +98,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CFlightInfo" arrow::flight::FlightInfo":
CFlightInfo(CFlightInfo info)
- uint64_t total_records()
- uint64_t total_bytes()
+ int64_t total_records()
+ int64_t total_bytes()
CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out)
CFlightDescriptor& descriptor()
const vector[CFlightEndpoint]& endpoints()
@@ -274,8 +275,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
shared_ptr[CSchema] schema,
CFlightDescriptor& descriptor,
vector[CFlightEndpoint] endpoints,
- uint64_t total_records,
- uint64_t total_bytes,
+ int64_t total_records,
+ int64_t total_bytes,
unique_ptr[CFlightInfo]* out)
cdef extern from "" namespace "std":
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index acdaacfaafe..bbe5507ffdd 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -103,6 +103,42 @@ def do_action(self, context, action):
raise NotImplementedError
+class GetInfoFlightServer(flight.FlightServerBase):
+ """A Flight server that tests GetFlightInfo."""
+
+ def get_flight_info(self, context, descriptor):
+ return flight.FlightInfo(
+ pa.schema([('a', pa.int32())]),
+ descriptor,
+ [
+ flight.FlightEndpoint(b'', ['grpc://test']),
+ flight.FlightEndpoint(
+ b'',
+ [flight.Location.for_grpc_tcp('localhost', 5005)],
+ ),
+ ],
+ -1,
+ -1,
+ )
+
+
+class CheckTicketFlightServer(flight.FlightServerBase):
+ """A Flight server that compares the given ticket to an expected value."""
+
+ def __init__(self, expected_ticket):
+ super(CheckTicketFlightServer, self).__init__()
+ self.expected_ticket = expected_ticket
+
+ def do_get(self, context, ticket):
+ assert self.expected_ticket == ticket.ticket
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ table = pa.Table.from_arrays(data1, names=['a'])
+ return flight.RecordBatchStream(table)
+
+ def do_put(self, context, descriptor, reader):
+ self.last_message = reader.read_all()
+
+
class InvalidStreamFlightServer(flight.FlightServerBase):
"""A Flight server that tries to return messages with differing schemas."""
@@ -255,6 +291,34 @@ def test_flight_do_get_dicts():
assert data.equals(table)
+def test_flight_do_get_ticket():
+ """Make sure Tickets get passed to the server."""
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ table = pa.Table.from_arrays(data1, names=['a'])
+ with flight_server(
+ CheckTicketFlightServer,
+ expected_ticket=b'the-ticket',
+ ) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ data = client.do_get(flight.Ticket(b'the-ticket')).read_all()
+ assert data.equals(table)
+
+
+def test_flight_get_info():
+ """Make sure FlightEndpoint accepts string and object URIs."""
+ with flight_server(GetInfoFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ info = client.get_flight_info(flight.FlightDescriptor.for_command(b''))
+ assert info.total_records == -1
+ assert info.total_bytes == -1
+ assert info.schema == pa.schema([('a', pa.int32())])
+ assert len(info.endpoints) == 2
+ assert len(info.endpoints[0].locations) == 1
+ assert info.endpoints[0].locations[0] == flight.Location('grpc://test')
+ assert info.endpoints[1].locations[0] == \
+ flight.Location.for_grpc_tcp('localhost', 5005)
+
+
def test_flight_domain_socket():
"""Try a simple do_get call over a domain socket."""
data = [
From 870f6ebb8d07797a7a408249cc350f1baf7f14a1 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 15 May 2019 08:35:13 -0400
Subject: [PATCH 5/5] Add more builder options for Java Flight servers
---
.../arrow/flight/test-integration-client.cc | 1 -
java/flight/pom.xml | 56 +++++++
.../org/apache/arrow/flight/FlightClient.java | 108 ++++++++++++-
.../org/apache/arrow/flight/FlightServer.java | 144 +++++++++++++++---
.../org/apache/arrow/flight/Location.java | 66 +++++++-
.../apache/arrow/flight/LocationSchemes.java | 6 +-
.../apache/arrow/flight/FlightTestUtil.java | 22 +++
.../arrow/flight/TestBasicOperation.java | 2 +-
.../arrow/flight/TestServerOptions.java | 63 ++++++++
.../apache/arrow/flight/auth/TestAuth.java | 2 +-
python/pyarrow/_flight.pyx | 31 ++--
python/pyarrow/tests/test_flight.py | 4 +-
12 files changed, 455 insertions(+), 50 deletions(-)
create mode 100644 java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java
diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc
index d0b734007c4..abaa3bc4221 100644
--- a/cpp/src/arrow/flight/test-integration-client.cc
+++ b/cpp/src/arrow/flight/test-integration-client.cc
@@ -102,7 +102,6 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::unique_ptr client;
- std::stringstream uri;
arrow::flight::Location location;
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client));
diff --git a/java/flight/pom.xml b/java/flight/pom.xml
index c9c06b497d7..7d01a6e118e 100644
--- a/java/flight/pom.xml
+++ b/java/flight/pom.xml
@@ -73,6 +73,16 @@
io.netty
netty-buffer
+
+ io.netty
+ netty-handler
+ ${dep.netty.version}
+
+
+ io.netty
+ netty-transport
+ ${dep.netty.version}
+
com.google.guava
guava
@@ -302,4 +312,50 @@
+
+
+ linux-netty-native
+
+
+ linux
+
+
+
+
+ io.netty
+ netty-transport-native-unix-common
+ ${dep.netty.version}
+ ${os.detected.name}-${os.detected.arch}
+
+
+ io.netty
+ netty-transport-native-epoll
+ ${dep.netty.version}
+ ${os.detected.name}-${os.detected.arch}
+
+
+
+
+ mac-netty-native
+
+
+ mac
+
+
+
+
+ io.netty
+ netty-transport-native-unix-common
+ ${dep.netty.version}
+ ${os.detected.name}-${os.detected.arch}
+
+
+ io.netty
+ netty-transport-native-kqueue
+ ${dep.netty.version}
+ ${os.detected.name}-${os.detected.arch}
+
+
+
+
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 3a6fbae87bb..e74c0eefb69 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -20,11 +20,14 @@
import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
+import java.io.InputStream;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import javax.net.ssl.SSLException;
+
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthInterceptor;
@@ -50,13 +53,19 @@
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
+import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+
/**
- * Client for flight servers.
+ * Client for Flight services.
*/
public class FlightClient implements AutoCloseable {
private static final int PENDING_REQUESTS = 5;
@@ -315,6 +324,13 @@ public void close() throws InterruptedException {
allocator.close();
}
+ /**
+ * Create a builder for a Flight client.
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
/**
* Create a builder for a Flight client.
* @param allocator The allocator to use for the client.
@@ -329,13 +345,20 @@ public static Builder builder(BufferAllocator allocator, Location location) {
*/
public static final class Builder {
- private final BufferAllocator allocator;
- private final Location location;
+ private BufferAllocator allocator;
+ private Location location;
private boolean forceTls = false;
+ private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
+ private InputStream trustedCertificates = null;
+ private InputStream clientCertificate = null;
+ private InputStream clientKey = null;
+
+ private Builder() {
+ }
private Builder(BufferAllocator allocator, Location location) {
- this.allocator = allocator;
- this.location = location;
+ this.allocator = Preconditions.checkNotNull(allocator);
+ this.location = Preconditions.checkNotNull(location);
}
/**
@@ -346,6 +369,37 @@ public Builder useTls() {
return this;
}
+ /** Set the maximum inbound message size. */
+ public Builder maxInboundMessageSize(int maxSize) {
+ Preconditions.checkArgument(maxSize > 0);
+ this.maxInboundMessageSize = maxSize;
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public Builder trustedCertificates(final InputStream stream) {
+ this.trustedCertificates = Preconditions.checkNotNull(stream);
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public Builder clientCertificate(final InputStream clientCertificate, final InputStream clientKey) {
+ Preconditions.checkNotNull(clientKey);
+ this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
+ this.clientKey = Preconditions.checkNotNull(clientKey);
+ return this;
+ }
+
+ public Builder allocator(BufferAllocator allocator) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ return this;
+ }
+
+ public Builder location(Location location) {
+ this.location = Preconditions.checkNotNull(location);
+ return this;
+ }
+
/**
* Create the client from this builder.
*/
@@ -356,7 +410,32 @@ public FlightClient build() {
case LocationSchemes.GRPC:
case LocationSchemes.GRPC_INSECURE:
case LocationSchemes.GRPC_TLS: {
- builder = NettyChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort());
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ break;
+ }
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ // The implementation is platform-specific, so we have to find the classes at runtime
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ try {
+ try {
+ // Linux
+ builder.channelType(
+ (Class extends ServerChannel>) Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ } catch (ClassNotFoundException e) {
+ // BSD
+ builder.channelType(
+ (Class extends ServerChannel>) Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ }
+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+ throw new UnsupportedOperationException(
+ "Could not find suitable Netty native transport implementation for domain socket address.");
+ }
break;
}
default:
@@ -365,13 +444,28 @@ public FlightClient build() {
if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
builder.useTransportSecurity();
+
+ if (this.trustedCertificates != null || this.clientCertificate != null || this.clientKey != null) {
+ final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
+ if (this.trustedCertificates != null) {
+ sslContextBuilder.trustManager(this.trustedCertificates);
+ }
+ if (this.clientCertificate != null && this.clientKey != null) {
+ sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
+ }
+ try {
+ builder.sslContext(sslContextBuilder.build());
+ } catch (SSLException e) {
+ throw new RuntimeException(e);
+ }
+ }
} else {
builder.usePlaintext();
}
builder
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
- .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE);
+ .maxInboundMessageSize(maxInboundMessageSize);
return new FlightClient(allocator, builder.build());
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index 8c9ae1c2e4e..0812a04637d 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -19,22 +19,26 @@
import java.io.File;
import java.io.FileInputStream;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
-import java.net.InetSocketAddress;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthInterceptor;
import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.util.Preconditions;
import io.grpc.Server;
import io.grpc.ServerInterceptors;
import io.grpc.netty.NettyServerBuilder;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+
public class FlightServer implements AutoCloseable {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class);
@@ -88,45 +92,80 @@ public void close() throws InterruptedException {
}
}
+ /** Create a builder for a Flight server. */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** Create a builder for a Flight server. */
public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) {
return new Builder(allocator, location, producer);
}
+ /** A builder for Flight servers. */
public static final class Builder {
-
- private final BufferAllocator allocator;
- private final Location location;
- private final FlightProducer producer;
+ private BufferAllocator allocator;
+ private Location location;
+ private FlightProducer producer;
+ private final Map builderOptions;
private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
-
+ private Executor executor = null;
+ private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE;
private InputStream certChain;
private InputStream key;
+ Builder() {
+ builderOptions = new HashMap<>();
+ }
+
Builder(BufferAllocator allocator, Location location, FlightProducer producer) {
- this.allocator = allocator;
- this.location = location;
- this.producer = producer;
+ this.allocator = Preconditions.checkNotNull(allocator);
+ this.location = Preconditions.checkNotNull(location);
+ this.producer = Preconditions.checkNotNull(producer);
+ builderOptions = new HashMap<>();
}
+ /** Create the server for this builder. */
public FlightServer build() {
final NettyServerBuilder builder;
switch (location.getUri().getScheme()) {
case LocationSchemes.GRPC_DOMAIN_SOCKET: {
- // TODO: need reflection to check if domain sockets are available
- throw new UnsupportedOperationException("Domain sockets are not available.");
+ // The implementation is platform-specific, so we have to find the classes at runtime
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
+ try {
+ try {
+ // Linux
+ builder.channelType(
+ (Class extends ServerChannel>) Class
+ .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
+ .newInstance();
+ builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg);
+ } catch (ClassNotFoundException e) {
+ // BSD
+ builder.channelType(
+ (Class extends ServerChannel>) Class
+ .forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
+ .newInstance();
+ builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg);
+ }
+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+ throw new UnsupportedOperationException(
+ "Could not find suitable Netty native transport implementation for domain socket address.");
+ }
+ break;
}
case LocationSchemes.GRPC:
case LocationSchemes.GRPC_INSECURE: {
- builder = NettyServerBuilder
- .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
break;
}
case LocationSchemes.GRPC_TLS: {
if (certChain == null) {
throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS");
}
- builder = NettyServerBuilder
- .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort()));
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
break;
}
default:
@@ -137,15 +176,33 @@ public FlightServer build() {
builder.useTransportSecurity(certChain, key);
}
- final Server server = builder
- .executor(new ForkJoinPool())
- .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE)
+ builder
+ .executor(executor != null ? executor : new ForkJoinPool())
+ .maxInboundMessageSize(maxInboundMessageSize)
.addService(
ServerInterceptors.intercept(
new FlightBindingService(allocator, producer, authHandler),
- new ServerAuthInterceptor(authHandler)))
- .build();
- return new FlightServer(server);
+ new ServerAuthInterceptor(authHandler)));
+
+ // Allow setting some Netty-specific options
+ builderOptions.computeIfPresent("netty.bossEventLoopGroup", (key, elg) -> {
+ builder.bossEventLoopGroup((EventLoopGroup) elg);
+ return null;
+ });
+ builderOptions.computeIfPresent("netty.workerEventLoopGroup", (key, elg) -> {
+ builder.workerEventLoopGroup((EventLoopGroup) elg);
+ return null;
+ });
+
+ return new FlightServer(builder.build());
+ }
+
+ /**
+ * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying transport.
+ */
+ public Builder maxInboundMessageSize(int maxMessageSize) {
+ this.maxInboundMessageSize = maxMessageSize;
+ return this;
}
/**
@@ -159,15 +216,54 @@ public Builder useTls(final File certChain, final File key) throws IOException {
return this;
}
+ /**
+ * Enable TLS on the server.
+ * @param certChain The certificate chain to use.
+ * @param key The private key to use.
+ */
public Builder useTls(final InputStream certChain, final InputStream key) {
this.certChain = certChain;
this.key = key;
return this;
}
- public Builder setAuthHandler(ServerAuthHandler authHandler) {
+ /**
+ * Set the executor used by the server.
+ */
+ public Builder executor(Executor executor) {
+ this.executor = executor;
+ return this;
+ }
+
+ /**
+ * Set the authentication handler.
+ */
+ public Builder authHandler(ServerAuthHandler authHandler) {
this.authHandler = authHandler;
return this;
}
+
+ /**
+ * Provide a transport-specific option. Not guaranteed to have any effect.
+ */
+ public Builder transportHint(final String key, Object option) {
+ builderOptions.put(key, option);
+ return this;
+ }
+
+ public Builder allocator(BufferAllocator allocator) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ return this;
+ }
+
+ public Builder location(Location location) {
+ this.location = Preconditions.checkNotNull(location);
+ return this;
+ }
+
+ public Builder producer(FlightProducer producer) {
+ this.producer = Preconditions.checkNotNull(producer);
+ return this;
+ }
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/src/main/java/org/apache/arrow/flight/Location.java
index 377edeba747..b4169fc9030 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/Location.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/Location.java
@@ -17,6 +17,9 @@
package org.apache.arrow.flight;
+import java.lang.reflect.InvocationTargetException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
@@ -30,21 +33,69 @@ public class Location {
* Constructs a new instance.
*
* @param uri the URI of the Flight service
+ * @throws IllegalArgumentException if the URI scheme is unsupported
*/
public Location(String uri) throws URISyntaxException {
- super();
- this.uri = new URI(uri);
+ this(new URI(uri));
}
+ /**
+ * Construct a new instance from an existing URI.
+ *
+ * @param uri the URI of the Flight service
+ * @throws IllegalArgumentException if the URI scheme is unsupported
+ */
public Location(URI uri) {
super();
this.uri = uri;
+ // Validate the scheme
+ switch (uri.getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_DOMAIN_SOCKET:
+ case LocationSchemes.GRPC_INSECURE:
+ case LocationSchemes.GRPC_TLS: {
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + this.uri);
+ }
}
public URI getUri() {
return uri;
}
+ /**
+ * Helper method to turn this Location into a SocketAddress.
+ *
+ * @return null if could not be converted
+ */
+ SocketAddress toSocketAddress() {
+ switch (uri.getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_TLS:
+ case LocationSchemes.GRPC_INSECURE: {
+ return new InetSocketAddress(uri.getHost(), uri.getPort());
+ }
+
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ try {
+ // This dependency is not available on non-Unix platforms.
+ return (SocketAddress) Class.forName("io.netty.channel.unix.DomainSocketAddress")
+ .getConstructor(String.class)
+ .newInstance(uri.getPath());
+ } catch (InstantiationException | ClassNotFoundException | InvocationTargetException |
+ NoSuchMethodException | IllegalAccessException e) {
+ return null;
+ }
+ }
+
+ default: {
+ return null;
+ }
+ }
+ }
+
/**
* Convert this Location into its protocol-level representation.
*/
@@ -69,4 +120,15 @@ public static Location forGrpcTls(String host, int port) {
throw new IllegalArgumentException(e);
}
}
+
+ /**
+ * Construct a URI for a Flight+gRPC server over a Unix domain socket.
+ */
+ public static Location forGrpcDomainSocket(String path) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_DOMAIN_SOCKET, null, path, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
index b652fed48a7..872e5b1c22d 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java
@@ -20,9 +20,13 @@
/**
* Constants representing well-known URI schemes for Flight services.
*/
-public class LocationSchemes {
+public final class LocationSchemes {
public static final String GRPC = "grpc";
public static final String GRPC_INSECURE = "grpc+tcp";
public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
public static final String GRPC_TLS = "grpc+tls";
+
+ private LocationSchemes() {
+ throw new AssertionError("Do not instantiate this class.");
+ }
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
index 11b476ccceb..f6b9e867807 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -62,6 +62,28 @@ public static T getStartedServer(Function newServerFromPort) thr
return server;
}
+ static boolean isEpollAvailable() {
+ try {
+ Class> epoll = Class.forName("io.netty.channel.epoll.Epoll");
+ return (Boolean) epoll.getMethod("isAvailable").invoke(null);
+ } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ return false;
+ }
+ }
+
+ static boolean isKqueueAvailable() {
+ try {
+ Class> kqueue = Class.forName("io.netty.channel.kqueue.KQueue");
+ return (Boolean) kqueue.getMethod("isAvailable").invoke(null);
+ } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ return false;
+ }
+ }
+
+ static boolean isNativeTransportAvailable() {
+ return isEpollAvailable() || isKqueueAvailable();
+ }
+
private FlightTestUtil() {
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index ab31eb8df92..dfb77ada666 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -157,7 +157,7 @@ private void test(BiConsumer consumer) throws Exc
/**
* An example FlightProducer for test purposes.
*/
- public class Producer implements FlightProducer, AutoCloseable {
+ public static class Producer implements FlightProducer, AutoCloseable {
private final BufferAllocator allocator;
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java
new file mode 100644
index 00000000000..e3ac3908941
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.File;
+
+import org.apache.arrow.flight.TestBasicOperation.Producer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class TestServerOptions {
+
+ @Test
+ public void domainSocket() throws Exception {
+ Assume.assumeTrue("We have a native transport available", FlightTestUtil.isNativeTransportAvailable());
+ final File domainSocket = File.createTempFile("flight-unit-test-", ".sock");
+ Assert.assertTrue(domainSocket.delete());
+ final Location location = Location.forGrpcDomainSocket(domainSocket.getAbsolutePath());
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer(a);
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (port) -> FlightServer.builder(a, location, producer).build()
+ )) {
+ try (FlightClient c = FlightClient.builder(a, location).build()) {
+ FlightStream stream = c.getStream(new Ticket(new byte[0]));
+ VectorSchemaRoot root = stream.getRoot();
+ IntVector iv = (IntVector) root.getVector("c1");
+ int value = 0;
+ while (stream.next()) {
+ for (int i = 0; i < root.getRowCount(); i++) {
+ Assert.assertEquals(value, iv.get(i));
+ value++;
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
index df16a11938a..39b2924c620 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
@@ -149,7 +149,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
root.clear();
listener.completed();
}
- }).setAuthHandler(new BasicServerAuthHandler(validator)).build());
+ }).authHandler(new BasicServerAuthHandler(validator)).build());
client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
.build();
}
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 82745d34c8c..806796f37f1 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -400,16 +400,26 @@ cdef class FlightClient:
.format(self.__class__.__name__))
@staticmethod
- def connect(location, **kwargs):
- """Connect to a Flight service on the given host and port."""
+ def connect(location, tls_root_certs=None):
+ """
+ Connect to a Flight service on the given host and port.
+
+ Parameters
+ ----------
+ location : Location
+ location to connect to
+
+ tls_root_certs : bytes
+ PEM-encoded
+ """
cdef:
FlightClient result = FlightClient.__new__(FlightClient)
int c_port = 0
CLocation c_location = Location.unwrap(location)
CFlightClientOptions c_options
- if "tls_root_certs" in kwargs:
- c_options.tls_root_certs = tobytes(kwargs["tls_root_certs"])
+ if tls_root_certs:
+ c_options.tls_root_certs = tobytes(tls_root_certs)
with nogil:
check_status(CFlightClient.Connect(c_location, c_options,
@@ -1004,7 +1014,8 @@ cdef class FlightServerBase:
cdef:
unique_ptr[PyFlightServer] server
- def run(self, location, auth_handler=None, **kwargs):
+ def run(self, location, auth_handler=None,
+ tls_cert_chain=None, tls_private_key=None):
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
PyFlightServer* c_server
@@ -1019,14 +1030,12 @@ cdef class FlightServerBase:
c_options.get().auth_handler.reset(
( auth_handler).to_handler())
- if "tls_cert_chain" in kwargs:
- if "tls_private_key" not in kwargs:
+ if tls_cert_chain:
+ if not tls_private_key:
raise ValueError(
"Must provide both cert chain and private key")
- c_options.get().tls_cert_chain = tobytes(
- kwargs["tls_cert_chain"])
- c_options.get().tls_private_key = tobytes(
- kwargs["tls_private_key"])
+ c_options.get().tls_cert_chain = tobytes(tls_cert_chain)
+ c_options.get().tls_private_key = tobytes(tls_private_key)
vtable.list_flights = &_list_flights
vtable.get_flight_info = &_get_flight_info
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index bbe5507ffdd..6ab512565cd 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -465,9 +465,9 @@ def test_token_auth_invalid():
def test_location_invalid():
"""Test constructing invalid URIs."""
- with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"):
+ with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
flight.FlightClient.connect("%")
server = ConstantFlightServer()
- with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"):
+ with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
server.run("%")