From 460442cea3ab4a67cb113bc9a48b9f09e1748eb2 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 2 May 2019 12:47:19 -0400 Subject: [PATCH 1/5] Use URIs for Flight locations --- cpp/src/arrow/flight/client.cc | 32 +++- cpp/src/arrow/flight/client.h | 6 +- cpp/src/arrow/flight/flight-benchmark.cc | 8 +- cpp/src/arrow/flight/flight-test.cc | 33 ++++- cpp/src/arrow/flight/internal.cc | 7 +- cpp/src/arrow/flight/perf-server.cc | 7 +- cpp/src/arrow/flight/server.cc | 25 +++- cpp/src/arrow/flight/server.h | 7 +- .../arrow/flight/test-integration-client.cc | 13 +- .../arrow/flight/test-integration-server.cc | 4 +- cpp/src/arrow/flight/test-server.cc | 4 +- cpp/src/arrow/flight/test-util.cc | 21 ++- cpp/src/arrow/flight/types.cc | 29 ++++ cpp/src/arrow/flight/types.h | 57 +++++++- cpp/src/arrow/util/uri.cc | 6 +- cpp/src/arrow/util/uri.h | 2 + format/Flight.proto | 9 +- .../org/apache/arrow/flight/FlightClient.java | 23 ++- .../apache/arrow/flight/FlightEndpoint.java | 15 +- .../org/apache/arrow/flight/FlightInfo.java | 9 +- .../org/apache/arrow/flight/FlightServer.java | 24 ++- .../org/apache/arrow/flight/Location.java | 50 ++++--- .../apache/arrow/flight/LocationSchemes.java | 28 ++++ .../flight/example/ExampleFlightServer.java | 4 +- .../integration/IntegrationTestClient.java | 8 +- .../integration/IntegrationTestServer.java | 2 +- .../apache/arrow/flight/TestBackPressure.java | 7 +- .../arrow/flight/TestBasicOperation.java | 20 ++- .../apache/arrow/flight/TestCallOptions.java | 6 +- .../apache/arrow/flight/TestLargeMessage.java | 12 +- .../apache/arrow/flight/auth/TestAuth.java | 4 +- .../flight/example/TestExampleServer.java | 4 +- .../flight/perf/PerformanceTestServer.java | 2 +- .../apache/arrow/flight/perf/TestPerf.java | 2 +- python/examples/flight/client.py | 2 +- python/pyarrow/_flight.pyx | 138 +++++++++++++----- python/pyarrow/includes/libarrow_flight.pxd | 14 +- python/pyarrow/tests/test_flight.py | 90 +++++++----- 38 files changed, 543 insertions(+), 191 deletions(-) create mode 100644 java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 5f4d6dd8610..4a7beded950 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -37,6 +37,7 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/logging.h" +#include "arrow/util/uri.h" #include "arrow/flight/client_auth.h" #include "arrow/flight/internal.h" @@ -230,11 +231,26 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { class FlightClient::FlightClientImpl { public: - Status Connect(const std::string& host, int port) { - // TODO(wesm): Support other kinds of GRPC ChannelCredentials - std::stringstream ss; - ss << host << ":" << port; - std::string uri = ss.str(); + Status Connect(const Location& location) { + const std::string& scheme = location.scheme(); + + std::stringstream grpc_uri; + std::shared_ptr creds; + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { + // TODO(wesm): Support other kinds of GRPC ChannelCredentials + grpc_uri << location.uri_->host() << ":" << location.uri_->port_text(); + + if (scheme == "grpc+tls") { + creds = grpc::SslCredentials(grpc::SslCredentialsOptions()); + } else { + creds = grpc::InsecureChannelCredentials(); + } + } else if (scheme == kSchemeGrpcUnix) { + grpc_uri << "unix://" << location.uri_->path(); + creds = grpc::InsecureChannelCredentials(); + } else { + return Status::NotImplemented("Flight scheme " + scheme + " is not supported."); + } grpc::ChannelArguments args; // Try to reconnect quickly at first, in case the server is still starting up @@ -242,7 +258,7 @@ class FlightClient::FlightClientImpl { // Receive messages of any size args.SetMaxReceiveMessageSize(-1); stub_ = pb::FlightService::NewStub( - grpc::CreateCustomChannel(ss.str(), grpc::InsecureChannelCredentials(), args)); + grpc::CreateCustomChannel(grpc_uri.str(), creds, args)); return Status::OK(); } @@ -383,10 +399,10 @@ FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); } FlightClient::~FlightClient() {} -Status FlightClient::Connect(const std::string& host, int port, +Status FlightClient::Connect(const Location& location, std::unique_ptr* client) { client->reset(new FlightClient); - return (*client)->impl_->Connect(host, port); + return (*client)->impl_->Connect(location); } Status FlightClient::Authenticate(const FlightCallOptions& options, diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 3886360ee5c..f48f3edf839 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -64,13 +64,11 @@ class ARROW_EXPORT FlightClient { ~FlightClient(); /// \brief Connect to an unauthenticated flight service - /// \param[in] host the hostname or IP address - /// \param[in] port the port on the host + /// \param[in] location the URI /// \param[out] client the created FlightClient /// \return Status OK status may not indicate that the connection was /// successful - static Status Connect(const std::string& host, int port, - std::unique_ptr* client); + static Status Connect(const Location& location, std::unique_ptr* client); /// \brief Authenticate to the server using the given handler. /// \param[in] options Per-RPC options diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc index d6318eaadef..20fe5ac179e 100644 --- a/cpp/src/arrow/flight/flight-benchmark.cc +++ b/cpp/src/arrow/flight/flight-benchmark.cc @@ -79,7 +79,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) { // Construct client and plan the query std::unique_ptr client; - RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client)); + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location)); + RETURN_NOT_OK(FlightClient::Connect(location, &client)); FlightDescriptor descriptor; descriptor.type = FlightDescriptor::CMD; @@ -97,7 +99,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) { auto ConsumeStream = [&stats, &hostname, &port](const FlightEndpoint& endpoint) { // TODO(wesm): Use location from endpoint, same host/port for now std::unique_ptr client; - RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client)); + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location)); + RETURN_NOT_OK(FlightClient::Connect(location, &client)); perf::Token token; token.ParseFromString(endpoint.ticket.ticket); diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index 43e6ddb1408..d99e7c4b020 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -61,8 +61,7 @@ void AssertEqual(const Ticket& expected, const Ticket& actual) { } void AssertEqual(const Location& expected, const Location& actual) { - ASSERT_EQ(expected.host, actual.host); - ASSERT_EQ(expected.port, actual.port); + ASSERT_EQ(expected, actual); } void AssertEqual(const std::vector& expected, @@ -146,6 +145,20 @@ TEST(TestFlight, StartStopTestServer) { ASSERT_FALSE(server.IsRunning()); } +TEST(TestFlight, ConnectUri) { + TestServer server("flight-test-server", 30000); + server.Start(); + ASSERT_TRUE(server.IsRunning()); + + std::unique_ptr client; + Location location1; + Location location2; + ASSERT_OK(Location::Parse("grpc://localhost:30000", &location1)); + ASSERT_OK(Location::Parse("grpc://localhost:30000", &location2)); + ASSERT_OK(FlightClient::Connect(location1, &client)); + ASSERT_OK(FlightClient::Connect(location2, &client)); +} + // ---------------------------------------------------------------------- // Client tests @@ -169,7 +182,11 @@ class TestFlightClient : public ::testing::Test { void TearDown() { server_->Stop(); } - Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); } + Status ConnectClient() { + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); + return FlightClient::Connect(location, &client_); + } template void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches, @@ -248,7 +265,11 @@ class TestAuthHandler : public ::testing::Test { void TearDown() { server_->Stop(); } - Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); } + Status ConnectClient() { + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); + return FlightClient::Connect(location, &client_); + } protected: int port_; @@ -423,7 +444,9 @@ TEST_F(TestFlightClient, Issue5095) { TEST_F(TestFlightClient, TimeoutFires) { // Server does not exist on this port, so call should fail std::unique_ptr client; - ASSERT_OK(FlightClient::Connect("localhost", 30001, &client)); + Location location; + ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location)); + ASSERT_OK(FlightClient::Connect(location, &client)); FlightCallOptions options; options.timeout = TimeoutDuration{0.2}; std::unique_ptr info; diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index 2e335936e51..c60e58597f6 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -119,14 +119,11 @@ Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) { // Location Status FromProto(const pb::Location& pb_location, Location* location) { - location->host = pb_location.host(); - location->port = pb_location.port(); - return Status::OK(); + return Location::Parse(pb_location.uri(), location); } void ToProto(const Location& location, pb::Location* pb_location) { - pb_location->set_host(location.host); - pb_location->set_port(location.port); + pb_location->set_uri(location.ToString()); } // Ticket diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc index 3755f3d52d1..f65cbd15abf 100644 --- a/cpp/src/arrow/flight/perf-server.cc +++ b/cpp/src/arrow/flight/perf-server.cc @@ -141,7 +141,8 @@ Status GetPerfBatches(const perf::Token& token, const std::shared_ptr& s class FlightPerfServer : public FlightServerBase { public: - FlightPerfServer() : location_(Location{"localhost", FLAGS_port}) { + FlightPerfServer() : location_() { + DCHECK_OK(Location::ForGrpcTcp("localhost", FLAGS_port, &location_)); perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()), field("d", int64())}); } @@ -204,8 +205,10 @@ int main(int argc, char** argv) { g_server.reset(new arrow::flight::FlightPerfServer); + arrow::flight::Location location; + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); ARROW_CHECK_OK( - g_server->Init(std::unique_ptr(), FLAGS_port)); + g_server->Init(std::unique_ptr(), location)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); std::cout << "Server port: " << FLAGS_port << std::endl; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 5679cdeedde..d8c5525ee98 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -40,6 +40,7 @@ #include "arrow/status.h" #include "arrow/util/logging.h" #include "arrow/util/stl.h" +#include "arrow/util/uri.h" #include "arrow/flight/internal.h" #include "arrow/flight/serialization-internal.h" @@ -410,7 +411,6 @@ class FlightServiceImpl : public FlightService::Service { #endif struct FlightServerBase::Impl { - std::string address_; std::unique_ptr service_; std::unique_ptr server_; @@ -442,17 +442,34 @@ FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } FlightServerBase::~FlightServerBase() {} -Status FlightServerBase::Init(std::unique_ptr auth_handler, int port) { - impl_->address_ = "localhost:" + std::to_string(port); +Status FlightServerBase::Init(std::unique_ptr auth_handler, + const Location& location) { std::shared_ptr handler = std::move(auth_handler); impl_->service_.reset(new FlightServiceImpl(handler, this)); grpc::ServerBuilder builder; // Allow uploading messages of any length builder.SetMaxReceiveMessageSize(-1); - builder.AddListeningPort(impl_->address_, grpc::InsecureServerCredentials()); + + const std::string scheme = location.scheme(); + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) { + std::stringstream address; + address << location.uri_->host() << ':' << location.uri_->port_text(); + builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials()); + } else if (scheme == kSchemeGrpcUnix) { + std::stringstream address; + address << "unix:" << location.uri_->path(); + builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials()); + } else { + return Status::NotImplemented("Scheme is not supported: " + scheme); + } + builder.RegisterService(impl_->service_.get()); + // Disable SO_REUSEPORT - it makes debugging/testing a pain as + // leftover processes can handle requests on accident + builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); + impl_->server_ = builder.BuildAndStart(); if (!impl_->server_) { return Status::UnknownError("Server did not start properly"); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 0a2b940f14b..752b89e85bf 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -98,13 +98,12 @@ class ARROW_EXPORT FlightServerBase { // Lifecycle methods. - /// \brief Initialize an insecure TCP server listening on localhost - /// at the given port. + /// \brief Initialize a Flight server listening at the given location. /// This method must be called before any other method. - /// \param[in] port The port to serve on. /// \param[in] auth_handler The authentication handler. May be /// nullptr if no authentication is desired. - Status Init(std::unique_ptr auth_handler, int port); + /// \param[in] location The location to serve on. + Status Init(std::unique_ptr auth_handler, const Location& location); /// \brief Set the server to stop when receiving any of the given signal /// numbers. diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 93a1a16c67e..d0b734007c4 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -89,8 +89,7 @@ arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, const std::shared_ptr& schema, std::shared_ptr* retrieved_data) { std::unique_ptr read_client; - RETURN_NOT_OK( - arrow::flight::FlightClient::Connect(location.host, location.port, &read_client)); + RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client)); std::unique_ptr stream; RETURN_NOT_OK(read_client->DoGet(ticket, &stream)); @@ -103,7 +102,10 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); std::unique_ptr client; - ABORT_NOT_OK(arrow::flight::FlightClient::Connect(FLAGS_host, FLAGS_port, &client)); + std::stringstream uri; + arrow::flight::Location location; + ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); arrow::flight::FlightDescriptor descr{ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; @@ -143,12 +145,11 @@ int main(int argc, char** argv) { auto locations = endpoint.locations; if (locations.size() == 0) { - locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}}; + locations = {location}; } for (const auto location : locations) { - std::cout << "Verifying location " << location.host << ':' << location.port - << std::endl; + std::cout << "Verifying location " << location.ToString() << std::endl; // 3. Download the data from the server. std::shared_ptr retrieved_data; ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data)); diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 9a7c8787ba0..d72e838213b 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -124,7 +124,9 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); g_server.reset(new arrow::flight::FlightIntegrationTestServer); - ARROW_CHECK_OK(g_server->Init(nullptr, FLAGS_port)); + arrow::flight::Location location; + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); + ARROW_CHECK_OK(g_server->Init(nullptr, location)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc index a9070a495a5..29af87db601 100644 --- a/cpp/src/arrow/flight/test-server.cc +++ b/cpp/src/arrow/flight/test-server.cc @@ -139,8 +139,10 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); g_server.reset(new arrow::flight::FlightTestServer); + arrow::flight::Location location; + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); ARROW_CHECK_OK( - g_server->Init(std::unique_ptr(), FLAGS_port)); + g_server->Init(std::unique_ptr(), location)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index ac6df556f36..01577958640 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -119,7 +119,9 @@ bool TestServer::IsRunning() { return server_process_->running(); } int TestServer::port() const { return port_; } Status InProcessTestServer::Start(std::unique_ptr auth_handler) { - RETURN_NOT_OK(server_->Init(std::move(auth_handler), port_)); + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); + RETURN_NOT_OK(server_->Init(std::move(auth_handler), location)); thread_ = std::thread([this]() { ARROW_EXPECT_OK(server_->Serve()); }); return Status::OK(); } @@ -169,12 +171,21 @@ std::shared_ptr ExampleDictSchema() { } std::vector ExampleFlightInfo() { + Location location1; + Location location2; + Location location3; + Location location4; + ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1)); + ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2)); + ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3)); + ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4)); + FlightInfo::Data flight1, flight2, flight3; - FlightEndpoint endpoint1({{"ticket-ints-1"}, {{"foo1.bar.com", 92385}}}); - FlightEndpoint endpoint2({{"ticket-ints-2"}, {{"foo2.bar.com", 92385}}}); - FlightEndpoint endpoint3({{"ticket-cmd"}, {{"foo3.bar.com", 92385}}}); - FlightEndpoint endpoint4({{"ticket-dicts-1"}, {{"foo4.bar.com", 92385}}}); + FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}}); + FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}}); + FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}}); + FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}}); FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}}; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 77c8009e8bf..985cb5fc03a 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -25,6 +25,7 @@ #include "arrow/ipc/dictionary.h" #include "arrow/ipc/reader.h" #include "arrow/status.h" +#include "arrow/util/uri.h" namespace arrow { namespace flight { @@ -83,6 +84,34 @@ Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, return Status::OK(); } +Location::Location() { uri_ = std::make_shared(); } + +Status Location::Parse(const std::string& uri_string, Location* location) { + return location->uri_->Parse(uri_string); +} + +Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) { + std::stringstream uri_string; + uri_string << "grpc+tcp://" << host << ':' << port; + return Location::Parse(uri_string.str(), location); +} + +Status Location::ForGrpcUnix(const std::string& path, Location* location) { + std::stringstream uri_string; + uri_string << "grpc+unix://" << path; + return Location::Parse(uri_string.str(), location); +} + +std::string Location::ToString() const { return uri_->to_string(); } +std::string Location::scheme() const { + std::string scheme = uri_->scheme(); + if (scheme.empty()) { + // Default to grpc+tcp + return "grpc+tcp"; + } + return scheme; +} + SimpleFlightListing::SimpleFlightListing(const std::vector& flights) : position_(0), flights_(flights) {} diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 8e85a41b3d7..9caee997ae1 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -39,7 +39,13 @@ namespace ipc { class DictionaryMemo; -} +} // namespace ipc + +namespace internal { + +class Uri; + +} // namespace internal namespace flight { @@ -115,10 +121,53 @@ struct Ticket { std::string ticket; }; -/// \brief A host location (hostname and port) +class FlightClient; +class FlightServerBase; + +static const char* kSchemeGrpc = "grpc"; +static const char* kSchemeGrpcTcp = "grpc+tcp"; +static const char* kSchemeGrpcUnix = "grpc+unix"; +static const char* kSchemeGrpcTls = "grpc+tls"; + +/// \brief A host location (a URI) struct Location { - std::string host; - int32_t port; + public: + /// \brief Initialize a blank location. + Location(); + + /// \brief Initialize a location by parsing a URI string + static Status Parse(const std::string& uri_string, Location* location); + + /// \brief Initialize a location for a non-TLS, gRPC-based Flight + /// service from a host and port + /// \param[in] host The hostname to connect to + /// \param[in] port The port + /// \param[out] location The resulting location + static Status ForGrpcTcp(const std::string& host, const int port, Location* location); + + /// \brief Initialize a location for a domain socket-based Flight + /// service + /// \param[in] path The path to the domain socket + /// \param[out] location The resulting location + static Status ForGrpcUnix(const std::string& path, Location* location); + + /// \brief Get a representation of this URI as a string. + std::string ToString() const; + + /// \brief Get the scheme of this URI. + std::string scheme() const; + + friend bool operator==(const Location& left, const Location& right) { + return left.ToString() == right.ToString(); + } + friend bool operator!=(const Location& left, const Location& right) { + return !(left == right); + } + + private: + friend class FlightClient; + friend class FlightServerBase; + std::shared_ptr uri_; }; /// \brief A flight ticket and list of locations where the ticket can be diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc index 3a90612ea67..e79c9826b9b 100644 --- a/cpp/src/arrow/util/uri.cc +++ b/cpp/src/arrow/util/uri.cc @@ -52,7 +52,7 @@ bool IsTextRangeSet(const UriTextRangeStructA& range) { return range.first != nu } // namespace struct Uri::Impl { - Impl() : port_(-1) { memset(&uri_, 0, sizeof(uri_)); } + Impl() : string_rep_(""), port_(-1) { memset(&uri_, 0, sizeof(uri_)); } ~Impl() { uriFreeUriMembersA(&uri_); } @@ -71,6 +71,7 @@ struct Uri::Impl { UriUriA uri_; // Keep alive strings that uriparser stores pointers to std::vector data_; + std::string string_rep_; int32_t port_; }; @@ -119,10 +120,13 @@ std::string Uri::path() const { return ss.str(); } +const std::string& Uri::to_string() const { return impl_->string_rep_; } + Status Uri::Parse(const std::string& uri_string) { impl_->Reset(); const auto& s = impl_->KeepString(uri_string); + impl_->string_rep_ = s; const char* error_pos; if (uriParseSingleUriExA(&impl_->uri_, s.data(), s.data() + s.size(), &error_pos) != URI_SUCCESS) { diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index 3d6949537ce..7327c2d7876 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -54,6 +54,8 @@ class ARROW_EXPORT Uri { int32_t port() const; /// The URI path component. std::string path() const; + /// Get the string representation of this URI. + const std::string& to_string() const; /// Factory function to parse a URI from its string representation. Status Parse(const std::string& uri_string); diff --git a/format/Flight.proto b/format/Flight.proto index 1fcefe9a63e..7f0488b86c3 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -246,7 +246,7 @@ message FlightEndpoint { Ticket ticket = 1; /* - * A list of locations where this ticket can be redeemed. If the list is + * A list of URIs where this ticket can be redeemed. If the list is * empty, the expectation is that the ticket can only be redeemed on the * current service where the ticket was generated. */ @@ -254,12 +254,11 @@ message FlightEndpoint { } /* - * A location where a flight service will accept retrieval of a particular - * stream given a ticket. + * A location where a Flight service will accept retrieval of a particular + * stream given a ticket. */ message Location { - string host = 1; - int32 port = 2; + string uri = 1; } /* diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index c13c5659585..178897da3b3 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -20,6 +20,8 @@ import static io.grpc.stub.ClientCalls.asyncClientStreamingCall; import static io.grpc.stub.ClientCalls.asyncServerStreamingCall; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -74,8 +76,7 @@ public class FlightClient implements AutoCloseable { */ public FlightClient(BufferAllocator incomingAllocator, Location location) { final ManagedChannelBuilder channelBuilder = - ManagedChannelBuilder.forAddress(location.getHost(), - location.getPort()) + ManagedChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort()) .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE) .usePlaintext(); @@ -97,7 +98,15 @@ public FlightClient(BufferAllocator incomingAllocator, Location location) { public Iterable listFlights(Criteria criteria, CallOption... options) { return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options).listFlights(criteria.asCriteria())) .stream() - .map(FlightInfo::new) + .map(t -> { + try { + return new FlightInfo(t); + } catch (URISyntaxException e) { + // We don't expect this will happen for conforming Flight implementations. For instance, a Java server + // itself wouldn't be able to construct an invalid Location. + throw new RuntimeException(e); + } + }) .collect(Collectors.toList()); } @@ -177,7 +186,13 @@ public ClientStreamListener startPut( * @param options RPC-layer hints for this call. */ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { - return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol())); + try { + return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol())); + } catch (URISyntaxException e) { + // We don't expect this will happen for conforming Flight implementations. For instance, a Java server + // itself wouldn't be able to construct an invalid Location. + throw new RuntimeException(e); + } } /** diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java index b615ed6c5a5..a34c0a58aa1 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightEndpoint.java @@ -17,6 +17,8 @@ package org.apache.arrow.flight; +import java.net.URISyntaxException; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -46,9 +48,11 @@ public FlightEndpoint(Ticket ticket, Location... locations) { /** * Constructs from the protocol buffer representation. */ - public FlightEndpoint(Flight.FlightEndpoint flt) { - locations = flt.getLocationList().stream() - .map(t -> new Location(t)).collect(Collectors.toList()); + public FlightEndpoint(Flight.FlightEndpoint flt) throws URISyntaxException { + locations = new ArrayList<>(); + for (final Flight.Location location : flt.getLocationList()) { + locations.add(new Location(location.getUri())); + } ticket = new Ticket(flt.getTicket()); } @@ -68,10 +72,7 @@ Flight.FlightEndpoint toProtocol() { .setTicket(ticket.toProtocol()); for (Location l : locations) { - b.addLocation(Flight.Location.newBuilder() - .setHost(l.getHost()) - .setPort(l.getPort()) - .build()); + b.addLocation(l.toProtocol()); } return b.build(); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java index 053b89885ee..4159e4618f1 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -19,8 +19,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.channels.Channels; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -67,7 +69,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List 0 ? @@ -78,7 +80,10 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List new FlightEndpoint(t)).collect(Collectors.toList()); + endpoints = new ArrayList<>(); + for (final Flight.FlightEndpoint endpoint : pbFlightInfo.getEndpointList()) { + endpoints.add(new FlightEndpoint(endpoint)); + } bytes = pbFlightInfo.getTotalBytes(); records = pbFlightInfo.getTotalRecords(); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java index 0e93139e0e2..58afb5bafe8 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -18,6 +18,7 @@ package org.apache.arrow.flight; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.auth.ServerAuthHandler; @@ -26,8 +27,8 @@ import org.apache.arrow.vector.VectorSchemaRoot; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerInterceptors; +import io.grpc.netty.NettyServerBuilder; public class FlightServer implements AutoCloseable { @@ -42,16 +43,31 @@ public class FlightServer implements AutoCloseable { * Constructs a new instance. * * @param allocator The allocator to use for storing/copying arrow data. - * @param port The port to bind to. + * @param location The location to serve on. * @param producer The underlying business logic for the server. * @param authHandler The authorization handler for the server. */ public FlightServer( BufferAllocator allocator, - int port, + Location location, FlightProducer producer, ServerAuthHandler authHandler) { - this.server = ServerBuilder.forPort(port) + final NettyServerBuilder builder; + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + // TODO: need reflection to check if domain sockets are available + throw new UnsupportedOperationException("Domain sockets are not available."); + } + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: { + builder = NettyServerBuilder + .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); + } + this.server = builder .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE) .addService( ServerInterceptors.intercept( diff --git a/java/flight/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/src/main/java/org/apache/arrow/flight/Location.java index 44ee865a14b..377edeba747 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/Location.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/Location.java @@ -17,40 +17,56 @@ package org.apache.arrow.flight; +import java.net.URI; +import java.net.URISyntaxException; + import org.apache.arrow.flight.impl.Flight; +/** A URI where a Flight stream is available. */ public class Location { - - private final String host; - private final int port; + private final URI uri; /** * Constructs a new instance. * - * @param host the host containing a flight server - * @param port the port the server is listening on. + * @param uri the URI of the Flight service */ - public Location(String host, int port) { + public Location(String uri) throws URISyntaxException { super(); - this.host = host; - this.port = port; + this.uri = new URI(uri); } - Location(Flight.Location location) { - this.host = location.getHost(); - this.port = location.getPort(); + public Location(URI uri) { + super(); + this.uri = uri; } - public String getHost() { - return host; + public URI getUri() { + return uri; } - public int getPort() { - return port; + /** + * Convert this Location into its protocol-level representation. + */ + public Flight.Location toProtocol() { + return Flight.Location.newBuilder().setUri(uri.toString()).build(); } - Flight.Location toProtocol() { - return Flight.Location.newBuilder().setHost(host).setPort(port).build(); + /** Construct a URI for a Flight+gRPC server without transport security. */ + public static Location forGrpcInsecure(String host, int port) { + try { + return new Location(new URI(LocationSchemes.GRPC_INSECURE, null, host, port, null, null, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } } + /** Construct a URI for a Flight+gRPC server with transport security. */ + public static Location forGrpcTls(String host, int port) { + try { + return new Location(new URI(LocationSchemes.GRPC_TLS, null, host, port, null, null, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java new file mode 100644 index 00000000000..b652fed48a7 --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * Constants representing well-known URI schemes for Flight services. + */ +public class LocationSchemes { + public static final String GRPC = "grpc"; + public static final String GRPC_INSECURE = "grpc+tcp"; + public static final String GRPC_DOMAIN_SOCKET = "grpc+unix"; + public static final String GRPC_TLS = "grpc+tls"; +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java index cbeb8fbf04f..2d71b5d490b 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java @@ -46,7 +46,7 @@ public ExampleFlightServer(BufferAllocator allocator, Location location) { this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE); this.location = location; this.mem = new InMemoryStore(this.allocator, location); - this.flightServer = new FlightServer(allocator, location.getPort(), mem, ServerAuthHandler.NO_OP); + this.flightServer = new FlightServer(allocator, location, mem, ServerAuthHandler.NO_OP); } public Location getLocation() { @@ -75,7 +75,7 @@ public void close() throws Exception { */ public static void main(String[] args) throws Exception { final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); - final ExampleFlightServer efs = new ExampleFlightServer(a, new Location("localhost", 12233)); + final ExampleFlightServer efs = new ExampleFlightServer(a, Location.forGrpcInsecure("localhost", 12233)); efs.start(); Runtime.getRuntime().addShutdownHook(new Thread(() -> { try { diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index ed450074a76..1e9f716f184 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.net.URISyntaxException; import java.util.Collections; import java.util.List; @@ -80,7 +81,8 @@ private void run(String[] args) throws ParseException, IOException { final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightClient client = new FlightClient(allocator, new Location(host, port)); + final Location defaultLocation = Location.forGrpcInsecure(host, port); + final FlightClient client = new FlightClient(allocator, defaultLocation); final String inputPath = cmd.getOptionValue("j"); @@ -114,10 +116,10 @@ private void run(String[] args) throws ParseException, IOException { // 3. Download the data from the server. List locations = endpoint.getLocations(); if (locations.size() == 0) { - locations = Collections.singletonList(new Location(host, port)); + locations = Collections.singletonList(defaultLocation); } for (Location location : locations) { - System.out.println("Verifying location " + location.getHost() + ":" + location.getPort()); + System.out.println("Verifying location " + location.getUri()); FlightClient readClient = new FlightClient(allocator, location); FlightStream stream = readClient.getStream(endpoint.getTicket()); VectorSchemaRoot downloadedRoot; diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index fdad5d1a4e5..cfeaadb3da0 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -43,7 +43,7 @@ private void run(String[] args) throws Exception { final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port)); + final ExampleFlightServer efs = new ExampleFlightServer(allocator, Location.forGrpcInsecure("localhost", port)); efs.start(); // Print out message for integration test script System.out.println("Server listening on localhost:" + port); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java index dcf8debb55b..e5b8d980f7e 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -47,7 +47,7 @@ public void ensureIndependentSteams() throws Exception { try ( final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final PerformanceTestServer server = FlightTestUtil.getStartedServer( - (port) -> (new PerformanceTestServer(a, new Location(FlightTestUtil.LOCALHOST, port)))); + (port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)))); final FlightClient client = new FlightClient(a, server.getLocation()) ) { FlightStream fs1 = client.getStream(client.getInfo( @@ -123,10 +123,11 @@ public void getStream(CallContext context, Ticket ticket, BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE); FlightServer server = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(serverAllocator, port, producer, ServerAuthHandler.NO_OP)); + (port) -> new FlightServer(serverAllocator, Location.forGrpcInsecure("localhost", port), producer, + ServerAuthHandler.NO_OP)); BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE); FlightClient client = - new FlightClient(clientAllocator, new Location(FlightTestUtil.LOCALHOST, server.getPort())) + new FlightClient(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) ) { FlightStream stream = client.getStream(new Ticket(new byte[1])); VectorSchemaRoot root = stream.getRoot(); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 268580d8e62..a02cd764fc2 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight; +import java.net.URISyntaxException; import java.util.concurrent.Callable; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -138,10 +139,12 @@ private void test(BiConsumer consumer) throws Exc BufferAllocator a = new RootAllocator(Long.MAX_VALUE); Producer producer = new Producer(a); FlightServer s = - FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) { + FlightTestUtil.getStartedServer( + (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, + ServerAuthHandler.NO_OP))) { try ( - FlightClient c = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort())); + FlightClient c = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())); ) { try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) { consumer.accept(c, testAllocator); @@ -170,7 +173,12 @@ public void listFlights(CallContext context, Criteria criteria, .setType(DescriptorType.CMD) .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8))) .build(); - listener.onNext(new FlightInfo(getInfo)); + try { + listener.onNext(new FlightInfo(getInfo)); + } catch (URISyntaxException e) { + listener.onError(e); + return; + } listener.onCompleted(); } @@ -231,7 +239,11 @@ public FlightInfo getFlightInfo(CallContext context, .setType(DescriptorType.CMD) .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8))) .build(); - return new FlightInfo(getInfo); + try { + return new FlightInfo(getInfo); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } } @Override diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java index 21d7ecae9aa..3af36cbe5c4 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -70,8 +70,10 @@ void test(Consumer testFn) { BufferAllocator a = new RootAllocator(Long.MAX_VALUE); Producer producer = new Producer(a); FlightServer s = - FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP)); - FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) { + FlightTestUtil.getStartedServer( + (port) -> new FlightServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer, + ServerAuthHandler.NO_OP)); + FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) { testFn.accept(client); } catch (InterruptedException | IOException e) { throw new RuntimeException(e); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java index ca31b02c729..2e69c0c9219 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -44,9 +44,11 @@ public void getLargeMessage() throws Exception { try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final Producer producer = new Producer(a); final FlightServer s = - FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) { + FlightTestUtil.getStartedServer( + (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, + ServerAuthHandler.NO_OP))) { - try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) { + try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) { FlightStream stream = client.getStream(new Ticket(new byte[]{})); try (VectorSchemaRoot root = stream.getRoot()) { while (stream.next()) { @@ -73,9 +75,11 @@ public void putLargeMessage() throws Exception { try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final Producer producer = new Producer(a); final FlightServer s = - FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP))) { + FlightTestUtil.getStartedServer( + (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, + ServerAuthHandler.NO_OP))) { - try (FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort())); + try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())); BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE); VectorSchemaRoot root = generateData(testAllocator)) { final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java index 2e967b02b7f..4dfe669d5be 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java @@ -121,7 +121,7 @@ public byte[] getToken(String username, String password) { server = FlightTestUtil.getStartedServer((port) -> new FlightServer( allocator, - port, + Location.forGrpcInsecure("localhost", port), new NoOpFlightProducer() { @Override public void listFlights(CallContext context, Criteria criteria, @@ -151,7 +151,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l } }, new BasicServerAuthHandler(validator))); - client = new FlightClient(allocator, new Location(FlightTestUtil.LOCALHOST, server.getPort())); + client = new FlightClient(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())); } @After diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java index 3730a09c65a..20e04167771 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java @@ -18,12 +18,14 @@ package org.apache.arrow.flight.example; import java.io.IOException; +import java.net.URISyntaxException; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClient.ClientStreamListener; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.FlightTestUtil; import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -49,7 +51,7 @@ public class TestExampleServer { public void start() throws IOException { allocator = new RootAllocator(Long.MAX_VALUE); - Location l = new Location("localhost", 12233); + Location l = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 12233); if (!Boolean.getBoolean("disableServer")) { System.out.println("Starting server."); server = new ExampleFlightServer(allocator, l); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java index 78e5574457b..5d5e600ad29 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java @@ -64,7 +64,7 @@ public PerformanceTestServer(BufferAllocator incomingAllocator, Location locatio this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE); this.location = location; this.producer = new PerfProducer(); - this.flightServer = new FlightServer(this.allocator, location.getPort(), producer, ServerAuthHandler.NO_OP); + this.flightServer = new FlightServer(this.allocator, location, producer, ServerAuthHandler.NO_OP); } public Location getLocation() { diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java index d580fe20f2a..d40773ef929 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -82,7 +82,7 @@ public void throughput() throws Exception { final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final PerformanceTestServer server = FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a, - new Location(FlightTestUtil.LOCALHOST, port))); + Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))); final FlightClient client = new FlightClient(a, server.getLocation()); ) { final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2)); diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py index d1e60d2710c..0c91e3ec55f 100644 --- a/python/examples/flight/client.py +++ b/python/examples/flight/client.py @@ -131,7 +131,7 @@ def main(): } host, port = args.host.split(':') port = int(port) - client = pyarrow.flight.FlightClient.connect(host, port) + client = pyarrow.flight.FlightClient.connect(f"grpc://{host}:{port}") while True: try: action = pyarrow.flight.Action("healthcheck", b"") diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 474e0076a14..05d50cd629f 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -20,6 +20,8 @@ import collections import enum +import six + from cython.operator cimport dereference as deref from pyarrow.compat import frombytes, tobytes @@ -69,6 +71,13 @@ cdef class Action: def body(self): return pyarrow_wrap_buffer(self.action.body) + @staticmethod + cdef CAction unwrap(action): + if not isinstance(action, Action): + raise TypeError("Must provide Action, not '{}'".format( + type(action))) + return ( action).action + _ActionType = collections.namedtuple('_ActionType', ['type', 'description']) @@ -158,18 +167,73 @@ cdef class FlightDescriptor: def __repr__(self): return "".format(self.descriptor_type) + @staticmethod + cdef FlightDescriptor unwrap(descriptor): + if not isinstance(descriptor, FlightDescriptor): + raise TypeError("Must provide a FlightDescriptor, not '{}'".format( + type(descriptor))) + return descriptor + -class Ticket: +cdef class Ticket: """A ticket for requesting a Flight stream.""" + + cdef: + CTicket ticket + def __init__(self, ticket): - self.ticket = ticket + self.ticket.ticket = tobytes(ticket) def __repr__(self): - return ''.format(self.ticket) + return ''.format(self.ticket.ticket) -class Location(collections.namedtuple('Location', ['host', 'port'])): - """A location where a Flight stream is available.""" +cdef class Location: + """The location of a Flight service.""" + cdef: + CLocation location + + def __init__(self, uri): + check_status(CLocation.Parse(tobytes(uri), &self.location)) + + def __repr__(self): + return ''.format(self.location.ToString()) + + @staticmethod + def for_grpc_tcp(host, port): + """Create a Location for a TCP-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_status(CLocation.ForGrpcTcp(c_host, c_port, &result.location)) + return result + + @staticmethod + def for_grpc_unix(path): + """Create a Location for a domain socket-based gRPC service.""" + cdef: + c_string c_path = tobytes(path) + Location result = Location.__new__(Location) + check_status(CLocation.ForGrpcUnix(c_path, &result.location)) + return result + + @staticmethod + cdef Location wrap(CLocation location): + cdef Location result = Location.__new__(Location) + result.location = location + return result + + @staticmethod + cdef CLocation unwrap(object location): + cdef CLocation c_location + if isinstance(location, (six.text_type, six.binary_type)): + check_status(CLocation.Parse(tobytes(location), &c_location)) + return c_location + elif not isinstance(location, Location): + raise TypeError("Must provide a Location, not '{}'".format( + type(location))) + return ( location).location cdef class FlightEndpoint: @@ -184,11 +248,16 @@ cdef class FlightEndpoint: ---------- ticket : Ticket or bytes the ticket needed to access this flight - locations : list of Location or tuples of (host, port) + locations : list of string URIs locations where this flight is available + + Raises + ------ + ArrowException + If one of the location URIs is not a valid URI. """ cdef: - CLocation c_location = CLocation() + CLocation c_location if isinstance(ticket, Ticket): self.endpoint.ticket.ticket = tobytes(ticket.ticket) @@ -196,9 +265,8 @@ cdef class FlightEndpoint: self.endpoint.ticket.ticket = tobytes(ticket) for location in locations: - # Accepts Location namedtuple or tuple - c_location.host = tobytes(location[0]) - c_location.port = location[1] + c_location = CLocation() + check_status(CLocation.Parse(tobytes(location), &c_location)) self.endpoint.locations.push_back(c_location) @property @@ -207,7 +275,7 @@ cdef class FlightEndpoint: @property def locations(self): - return [Location(frombytes(location.host), location.port) + return [Location.wrap(location) for location in self.endpoint.locations] @@ -313,27 +381,21 @@ cdef class FlightClient: .format(self.__class__.__name__)) @staticmethod - def connect(*args): + def connect(location): """Connect to a Flight service on the given host and port.""" cdef: FlightClient result = FlightClient.__new__(FlightClient) int c_port = 0 - c_string c_host - - if len(args) == 1: - # Accept namedtuple or plain tuple - c_host = tobytes(args[0][0]) - c_port = args[0][1] - elif len(args) == 2: - # Accept separate host, port - c_host = tobytes(args[0]) - c_port = args[1] + CLocation c_location + + if isinstance(location, Location): + c_location = ( location).location else: - raise TypeError("FlightClient.connect() takes 1 " - "or 2 arguments ({} given)".format(len(args))) + c_location = CLocation() + check_status(CLocation.Parse(tobytes(location), &c_location)) with nogil: - check_status(CFlightClient.Connect(c_host, c_port, &result.client)) + check_status(CFlightClient.Connect(c_location, &result.client)) return result @@ -375,10 +437,12 @@ cdef class FlightClient: cdef: unique_ptr[CResultStream] results Result result + CAction c_action = Action.unwrap(action) CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) with nogil: - check_status(self.client.get().DoAction(deref(c_options), - action.action, &results)) + check_status( + self.client.get().DoAction(deref(c_options), c_action, + &results)) while True: result = Result.__new__(Result) @@ -414,25 +478,23 @@ cdef class FlightClient: cdef: FlightInfo result = FlightInfo.__new__(FlightInfo) CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + FlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) with nogil: check_status(self.client.get().GetFlightInfo( - deref(c_options), descriptor.descriptor, &result.info)) + deref(c_options), c_descriptor.descriptor, &result.info)) return result def do_get(self, ticket: Ticket, options: FlightCallOptions = None): """Request the data for a flight.""" cdef: - # TODO: introduce unwrap - CTicket c_ticket unique_ptr[CRecordBatchReader] reader CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) - c_ticket.ticket = ticket.ticket with nogil: - check_status( - self.client.get().DoGet(deref(c_options), c_ticket, &reader)) + check_status(self.client.get().DoGet(deref(c_options), ticket.ticket, &reader)) result = FlightRecordBatchReader() result.reader.reset(reader.release()) return result @@ -444,10 +506,12 @@ cdef class FlightClient: shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) unique_ptr[CRecordBatchWriter] writer CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + FlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) with nogil: check_status(self.client.get().DoPut( - deref(c_options), descriptor.descriptor, c_schema, &writer)) + deref(c_options), c_descriptor.descriptor, c_schema, &writer)) result = FlightRecordBatchWriter() result.writer.reset(writer.release()) return result @@ -921,11 +985,11 @@ cdef class FlightServerBase: cdef: unique_ptr[PyFlightServer] server - def run(self, port, auth_handler=None): + def run(self, location, auth_handler=None): cdef: PyFlightServerVtable vtable = PyFlightServerVtable() - int c_port = port PyFlightServer* c_server + CLocation c_location = Location.unwrap(location) unique_ptr[CServerAuthHandler] c_auth_handler if auth_handler: @@ -945,7 +1009,7 @@ cdef class FlightServerBase: c_server = new PyFlightServer(self, vtable) self.server.reset(c_server) with nogil: - check_status(c_server.Init(move(c_auth_handler), c_port)) + check_status(c_server.Init(move(c_auth_handler), c_location)) check_status(c_server.ServeWithSignals()) def list_flights(self, context, criteria): diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index bac7de54c89..fe2075e3675 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -80,9 +80,14 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CLocation" arrow::flight::Location": CLocation() + c_string ToString() - c_string host - int32_t port + @staticmethod + CStatus Parse(c_string& uri_string, CLocation* location) + @staticmethod + CStatus ForGrpcTcp(c_string& host, int port, CLocation* location) + @staticmethod + CStatus ForGrpcUnix(c_string& path, CLocation* location) cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint": CFlightEndpoint() @@ -150,7 +155,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CFlightClient" arrow::flight::FlightClient": @staticmethod - CStatus Connect(const c_string& host, int port, + CStatus Connect(const CLocation& location, unique_ptr[CFlightClient]* client) CStatus Authenticate(CFlightCallOptions& options, @@ -224,7 +229,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: cdef cppclass PyFlightServer: PyFlightServer(object server, PyFlightServerVtable vtable) - CStatus Init(unique_ptr[CServerAuthHandler] auth_handler, int port) + CStatus Init(unique_ptr[CServerAuthHandler] auth_handler, + const CLocation& location) CStatus ServeWithSignals() except * void Shutdown() diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 6872b6f77fc..03e7a40fe1f 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -19,6 +19,7 @@ import base64 import contextlib import socket +import tempfile import threading import time @@ -206,26 +207,30 @@ def get_token(self): @contextlib.contextmanager def flight_server(server_base, *args, **kwargs): """Spawn a Flight server on a free port, shutting it down when done.""" - # Find a free port - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - with contextlib.closing(sock) as sock: - sock.bind(('', 0)) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - port = sock.getsockname()[1] - - auth_handler = kwargs.get('auth_handler') + auth_handler = kwargs.pop('auth_handler', None) + location = kwargs.pop('location', None) + + if location is None: + # Find a free port + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with contextlib.closing(sock) as sock: + sock.bind(('', 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = sock.getsockname()[1] + location = flight.Location.for_grpc_tcp("localhost", port) + else: + port = None + ctor_kwargs = kwargs - if auth_handler: - del ctor_kwargs['auth_handler'] server_instance = server_base(*args, **ctor_kwargs) def _server_thread(): - server_instance.run(port, auth_handler=auth_handler) + server_instance.run(location, auth_handler=auth_handler) thread = threading.Thread(target=_server_thread, daemon=True) thread.start() - yield port + yield location server_instance.shutdown() thread.join() @@ -244,12 +249,29 @@ def test_flight_do_get_ints(): def test_flight_do_get_dicts(): table = simple_dicts_table() - with flight_server(ConstantFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(ConstantFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) data = client.do_get(flight.Ticket(b'dicts')).read_all() assert data.equals(table) +def test_flight_domain_socket(): + """Try a simple do_get call over a domain socket.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with tempfile.NamedTemporaryFile() as sock: + sock.close() + location = flight.Location.for_grpc_unix(sock.name) + with flight_server(ConstantFlightServer, + location=location) as server_location: + client = flight.FlightClient.connect(server_location) + data = client.do_get(flight.Ticket(b'')).read_all() + assert data.equals(table) + + @pytest.mark.slow def test_flight_large_message(): """Try sending/receiving a large message via Flight. @@ -261,8 +283,8 @@ def test_flight_large_message(): pa.array(range(0, 10 * 1024 * 1024)) ], names=['a']) - with flight_server(EchoFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(EchoFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) writer = client.do_put(flight.FlightDescriptor.for_path('test'), data.schema) # Write a single giant chunk @@ -278,8 +300,8 @@ def test_flight_generator_stream(): pa.array(range(0, 10 * 1024)) ], names=['a']) - with flight_server(EchoStreamFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(EchoStreamFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) writer = client.do_put(flight.FlightDescriptor.for_path('test'), data.schema) writer.write_table(data) @@ -290,8 +312,8 @@ def test_flight_generator_stream(): def test_flight_invalid_generator_stream(): """Try streaming data with mismatched schemas.""" - with flight_server(InvalidStreamFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(InvalidStreamFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) with pytest.raises(pa.ArrowException): client.do_get(flight.Ticket(b'')).read_all() @@ -300,8 +322,8 @@ def test_timeout_fires(): """Make sure timeouts fire on slow requests.""" # Do this in a separate thread so that if it fails, we don't hang # the entire test process - with flight_server(SlowFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(SlowFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) action = flight.Action("", b"") options = flight.FlightCallOptions(timeout=0.2) with pytest.raises(pa.ArrowIOError, match="Deadline Exceeded"): @@ -310,8 +332,8 @@ def test_timeout_fires(): def test_timeout_passes(): """Make sure timeouts do not fire on fast requests.""" - with flight_server(ConstantFlightServer) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + with flight_server(ConstantFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) options = flight.FlightCallOptions(timeout=0.2) client.do_get(flight.Ticket(b'ints'), options=options).read_all() @@ -328,8 +350,8 @@ def test_timeout_passes(): def test_http_basic_unauth(): """Test that auth fails when not authenticated.""" with flight_server(EchoStreamFlightServer, - auth_handler=basic_auth_handler) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + auth_handler=basic_auth_handler) as server_location: + client = flight.FlightClient.connect(server_location) action = flight.Action("who-am-i", b"") with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"): list(client.do_action(action)) @@ -338,8 +360,8 @@ def test_http_basic_unauth(): def test_http_basic_auth(): """Test a Python implementation of HTTP basic authentication.""" with flight_server(EchoStreamFlightServer, - auth_handler=basic_auth_handler) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + auth_handler=basic_auth_handler) as server_location: + client = flight.FlightClient.connect(server_location) action = flight.Action("who-am-i", b"") client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd')) identity = next(client.do_action(action)) @@ -349,8 +371,8 @@ def test_http_basic_auth(): def test_http_basic_auth_invalid_password(): """Test that auth fails with the wrong password.""" with flight_server(EchoStreamFlightServer, - auth_handler=basic_auth_handler) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + auth_handler=basic_auth_handler) as server_location: + client = flight.FlightClient.connect(server_location) action = flight.Action("who-am-i", b"") client.authenticate(HttpBasicClientAuthHandler('test', 'wrong')) with pytest.raises(pa.ArrowException, match=".*wrong password.*"): @@ -360,8 +382,8 @@ def test_http_basic_auth_invalid_password(): def test_token_auth(): """Test an auth mechanism that uses a handshake.""" with flight_server(EchoStreamFlightServer, - auth_handler=token_auth_handler) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + auth_handler=token_auth_handler) as server_location: + client = flight.FlightClient.connect(server_location) action = flight.Action("who-am-i", b"") client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd')) identity = next(client.do_action(action)) @@ -371,7 +393,7 @@ def test_token_auth(): def test_token_auth_invalid(): """Test an auth mechanism that uses a handshake.""" with flight_server(EchoStreamFlightServer, - auth_handler=token_auth_handler) as server_port: - client = flight.FlightClient.connect('localhost', server_port) + auth_handler=token_auth_handler) as server_location: + client = flight.FlightClient.connect(server_location) with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"): client.authenticate(TokenClientAuthHandler('test', 'wrong')) From 47441966720bf611ed2482f85fdcdf32a9ca231c Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 2 May 2019 17:31:56 -0400 Subject: [PATCH 2/5] Use builder for Flight server in Java --- .../org/apache/arrow/flight/FlightClient.java | 75 +++++++++-- .../org/apache/arrow/flight/FlightServer.java | 127 ++++++++++++------ .../flight/example/ExampleFlightServer.java | 2 +- .../integration/IntegrationTestClient.java | 4 +- .../apache/arrow/flight/TestBackPressure.java | 10 +- .../arrow/flight/TestBasicOperation.java | 7 +- .../apache/arrow/flight/TestCallOptions.java | 7 +- .../apache/arrow/flight/TestLargeMessage.java | 13 +- .../apache/arrow/flight/auth/TestAuth.java | 8 +- .../flight/example/TestExampleServer.java | 2 +- .../flight/perf/PerformanceTestServer.java | 2 +- .../apache/arrow/flight/perf/TestPerf.java | 2 +- 12 files changed, 177 insertions(+), 82 deletions(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 178897da3b3..3a6fbae87bb 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -20,7 +20,6 @@ import static io.grpc.stub.ClientCalls.asyncClientStreamingCall; import static io.grpc.stub.ClientCalls.asyncServerStreamingCall; -import java.net.URI; import java.net.URISyntaxException; import java.util.Iterator; import java.util.concurrent.TimeUnit; @@ -50,8 +49,8 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; +import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; @@ -71,17 +70,9 @@ public class FlightClient implements AutoCloseable { private final MethodDescriptor doGetDescriptor; private final MethodDescriptor doPutDescriptor; - /** - * Construct client for accessing RouteGuide server using the existing channel. - */ - public FlightClient(BufferAllocator incomingAllocator, Location location) { - final ManagedChannelBuilder channelBuilder = - ManagedChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort()) - .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) - .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE) - .usePlaintext(); + private FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) { this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE); - channel = channelBuilder.build(); + this.channel = channel; blockingStub = FlightServiceGrpc.newBlockingStub(channel).withInterceptors(authInterceptor); asyncStub = FlightServiceGrpc.newStub(channel).withInterceptors(authInterceptor); doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator); @@ -324,4 +315,64 @@ public void close() throws InterruptedException { allocator.close(); } + /** + * Create a builder for a Flight client. + * @param allocator The allocator to use for the client. + * @param location The location to connect to. + */ + public static Builder builder(BufferAllocator allocator, Location location) { + return new Builder(allocator, location); + } + + /** + * A builder for Flight clients. + */ + public static final class Builder { + + private final BufferAllocator allocator; + private final Location location; + private boolean forceTls = false; + + private Builder(BufferAllocator allocator, Location location) { + this.allocator = allocator; + this.location = location; + } + + /** + * Force the client to connect over TLS. + */ + public Builder useTls() { + this.forceTls = true; + return this; + } + + /** + * Create the client from this builder. + */ + public FlightClient build() { + final NettyChannelBuilder builder; + + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: + case LocationSchemes.GRPC_TLS: { + builder = NettyChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort()); + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); + } + + if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { + builder.useTransportSecurity(); + } else { + builder.usePlaintext(); + } + + builder + .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) + .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE); + return new FlightClient(allocator, builder.build()); + } + } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java index 58afb5bafe8..8c9ae1c2e4e 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -17,8 +17,13 @@ package org.apache.arrow.flight; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.net.InetSocketAddress; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.auth.ServerAuthHandler; @@ -39,41 +44,9 @@ public class FlightServer implements AutoCloseable { /** The maximum size of an individual gRPC message. This effectively disables the limit. */ static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE; - /** - * Constructs a new instance. - * - * @param allocator The allocator to use for storing/copying arrow data. - * @param location The location to serve on. - * @param producer The underlying business logic for the server. - * @param authHandler The authorization handler for the server. - */ - public FlightServer( - BufferAllocator allocator, - Location location, - FlightProducer producer, - ServerAuthHandler authHandler) { - final NettyServerBuilder builder; - switch (location.getUri().getScheme()) { - case LocationSchemes.GRPC_DOMAIN_SOCKET: { - // TODO: need reflection to check if domain sockets are available - throw new UnsupportedOperationException("Domain sockets are not available."); - } - case LocationSchemes.GRPC: - case LocationSchemes.GRPC_INSECURE: { - builder = NettyServerBuilder - .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); - break; - } - default: - throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); - } - this.server = builder - .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE) - .addService( - ServerInterceptors.intercept( - new FlightBindingService(allocator, producer, authHandler), - new ServerAuthInterceptor(authHandler))) - .build(); + /** Create a new instance from a gRPC server. For internal use only. */ + private FlightServer(Server server) { + this.server = server; } /** Start the server. */ @@ -115,20 +88,86 @@ public void close() throws InterruptedException { } } - public interface OutputFlight { - void sendData(int count); + public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) { + return new Builder(allocator, location, producer); + } - void done(); + public static final class Builder { - void fail(Throwable t); - } + private final BufferAllocator allocator; + private final Location location; + private final FlightProducer producer; + private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP; - public interface FlightServerHandler { + private InputStream certChain; + private InputStream key; - public FlightInfo getFlightInfo(String descriptor) throws Exception; + Builder(BufferAllocator allocator, Location location, FlightProducer producer) { + this.allocator = allocator; + this.location = location; + this.producer = producer; + } - public OutputFlight setupFlight(VectorSchemaRoot root); + public FlightServer build() { + final NettyServerBuilder builder; + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + // TODO: need reflection to check if domain sockets are available + throw new UnsupportedOperationException("Domain sockets are not available."); + } + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: { + builder = NettyServerBuilder + .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); + break; + } + case LocationSchemes.GRPC_TLS: { + if (certChain == null) { + throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS"); + } + builder = NettyServerBuilder + .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); + } - } + if (certChain != null) { + builder.useTransportSecurity(certChain, key); + } + final Server server = builder + .executor(new ForkJoinPool()) + .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE) + .addService( + ServerInterceptors.intercept( + new FlightBindingService(allocator, producer, authHandler), + new ServerAuthInterceptor(authHandler))) + .build(); + return new FlightServer(server); + } + + /** + * Enable TLS on the server. + * @param certChain The certificate chain to use. + * @param key The private key to use. + */ + public Builder useTls(final File certChain, final File key) throws IOException { + this.certChain = new FileInputStream(certChain); + this.key = new FileInputStream(key); + return this; + } + + public Builder useTls(final InputStream certChain, final InputStream key) { + this.certChain = certChain; + this.key = key; + return this; + } + + public Builder setAuthHandler(ServerAuthHandler authHandler) { + this.authHandler = authHandler; + return this; + } + } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java index 2d71b5d490b..a08f74faaa9 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java @@ -46,7 +46,7 @@ public ExampleFlightServer(BufferAllocator allocator, Location location) { this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE); this.location = location; this.mem = new InMemoryStore(this.allocator, location); - this.flightServer = new FlightServer(allocator, location, mem, ServerAuthHandler.NO_OP); + this.flightServer = FlightServer.builder(allocator, location, mem).build(); } public Location getLocation() { diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 1e9f716f184..6222093c344 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -82,7 +82,7 @@ private void run(String[] args) throws ParseException, IOException { final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); final Location defaultLocation = Location.forGrpcInsecure(host, port); - final FlightClient client = new FlightClient(allocator, defaultLocation); + final FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); final String inputPath = cmd.getOptionValue("j"); @@ -120,7 +120,7 @@ private void run(String[] args) throws ParseException, IOException { } for (Location location : locations) { System.out.println("Verifying location " + location.getUri()); - FlightClient readClient = new FlightClient(allocator, location); + FlightClient readClient = FlightClient.builder(allocator, location).build(); FlightStream stream = readClient.getStream(endpoint.getTicket()); VectorSchemaRoot downloadedRoot; try (VectorSchemaRoot root = stream.getRoot()) { diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java index e5b8d980f7e..b5dde67d5b4 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -48,7 +48,7 @@ public void ensureIndependentSteams() throws Exception { final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final PerformanceTestServer server = FlightTestUtil.getStartedServer( (port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)))); - final FlightClient client = new FlightClient(a, server.getLocation()) + final FlightClient client = FlightClient.builder(a, server.getLocation()).build() ) { FlightStream fs1 = client.getStream(client.getInfo( TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1)) @@ -123,11 +123,13 @@ public void getStream(CallContext context, Ticket ticket, BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE); FlightServer server = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(serverAllocator, Location.forGrpcInsecure("localhost", port), producer, - ServerAuthHandler.NO_OP)); + (port) -> FlightServer.builder(serverAllocator, Location.forGrpcInsecure("localhost", port), producer) + .build()); BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE); FlightClient client = - new FlightClient(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) + FlightClient + .builder(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) + .build() ) { FlightStream stream = client.getStream(new Ticket(new byte[1])); VectorSchemaRoot root = stream.getRoot(); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index a02cd764fc2..ab31eb8df92 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -140,11 +140,12 @@ private void test(BiConsumer consumer) throws Exc Producer producer = new Producer(a); FlightServer s = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, - ServerAuthHandler.NO_OP))) { + (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build() + )) { try ( - FlightClient c = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())); + FlightClient c = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())) + .build() ) { try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) { consumer.accept(c, testAllocator); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java index 3af36cbe5c4..9a24c66dceb 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -71,9 +71,10 @@ void test(Consumer testFn) { Producer producer = new Producer(a); FlightServer s = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer, - ServerAuthHandler.NO_OP)); - FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) { + (port) -> FlightServer.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer) + .build()); + FlightClient client = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())) + .build()) { testFn.accept(client); } catch (InterruptedException | IOException e) { throw new RuntimeException(e); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java index 2e69c0c9219..48c4a271be4 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -45,10 +45,10 @@ public void getLargeMessage() throws Exception { final Producer producer = new Producer(a); final FlightServer s = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, - ServerAuthHandler.NO_OP))) { + (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build())) { - try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))) { + try (FlightClient client = FlightClient + .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build()) { FlightStream stream = client.getStream(new Ticket(new byte[]{})); try (VectorSchemaRoot root = stream.getRoot()) { while (stream.next()) { @@ -76,10 +76,11 @@ public void putLargeMessage() throws Exception { final Producer producer = new Producer(a); final FlightServer s = FlightTestUtil.getStartedServer( - (port) -> new FlightServer(a, Location.forGrpcInsecure("localhost", port), producer, - ServerAuthHandler.NO_OP))) { + (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build() + )) { - try (FlightClient client = new FlightClient(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())); + try (FlightClient client = FlightClient + .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build(); BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE); VectorSchemaRoot root = generateData(testAllocator)) { final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java index 4dfe669d5be..df16a11938a 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java @@ -119,7 +119,7 @@ public byte[] getToken(String username, String password) { } }; - server = FlightTestUtil.getStartedServer((port) -> new FlightServer( + server = FlightTestUtil.getStartedServer((port) -> FlightServer.builder( allocator, Location.forGrpcInsecure("localhost", port), new NoOpFlightProducer() { @@ -149,9 +149,9 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l root.clear(); listener.completed(); } - }, - new BasicServerAuthHandler(validator))); - client = new FlightClient(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())); + }).setAuthHandler(new BasicServerAuthHandler(validator)).build()); + client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) + .build(); } @After diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java index 20e04167771..a580a6e1717 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java @@ -59,7 +59,7 @@ public void start() throws IOException { } else { System.out.println("Skipping server startup."); } - client = new FlightClient(allocator, l); + client = FlightClient.builder(allocator, l).build(); caseAllocator = allocator.newChildAllocator("test-case", 0, Long.MAX_VALUE); } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java index 5d5e600ad29..bc6d202d3bd 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java @@ -64,7 +64,7 @@ public PerformanceTestServer(BufferAllocator incomingAllocator, Location locatio this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE); this.location = location; this.producer = new PerfProducer(); - this.flightServer = new FlightServer(this.allocator, location, producer, ServerAuthHandler.NO_OP); + this.flightServer = FlightServer.builder(this.allocator, location, producer).build(); } public Location getLocation() { diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java index d40773ef929..a9b9d60b1b9 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -83,7 +83,7 @@ public void throughput() throws Exception { final PerformanceTestServer server = FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))); - final FlightClient client = new FlightClient(a, server.getLocation()); + final FlightClient client = FlightClient.builder(a, server.getLocation()).build(); ) { final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2)); ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4)); From 675acc9bc2051c804d0666f6f698f476ca87e463 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 3 May 2019 09:58:52 -0400 Subject: [PATCH 3/5] Introduce builder for C++ Flight servers --- cpp/src/arrow/flight/client.cc | 16 +++-- cpp/src/arrow/flight/client.h | 14 +++++ cpp/src/arrow/flight/flight-test.cc | 36 ++++++----- cpp/src/arrow/flight/perf-server.cc | 5 +- cpp/src/arrow/flight/server.cc | 25 ++++++-- cpp/src/arrow/flight/server.h | 19 ++++-- .../arrow/flight/test-integration-server.cc | 4 +- cpp/src/arrow/flight/test-server.cc | 6 +- cpp/src/arrow/flight/test-util.cc | 9 +-- cpp/src/arrow/flight/test-util.h | 10 +-- cpp/src/arrow/flight/types.cc | 2 +- cpp/src/arrow/util/uri.cc | 2 +- cpp/src/arrow/util/uri.h | 3 +- python/examples/flight/client.py | 12 +++- python/examples/flight/server.py | 22 ++++++- python/pyarrow/_flight.pyx | 62 +++++++++++-------- python/pyarrow/includes/libarrow_flight.pxd | 15 ++++- python/pyarrow/tests/test_flight.py | 10 +++ 18 files changed, 190 insertions(+), 82 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 4a7beded950..d8316475550 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -231,17 +231,20 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { class FlightClient::FlightClientImpl { public: - Status Connect(const Location& location) { + Status Connect(const Location& location, const FlightClientOptions& options) { const std::string& scheme = location.scheme(); std::stringstream grpc_uri; std::shared_ptr creds; if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { - // TODO(wesm): Support other kinds of GRPC ChannelCredentials grpc_uri << location.uri_->host() << ":" << location.uri_->port_text(); if (scheme == "grpc+tls") { - creds = grpc::SslCredentials(grpc::SslCredentialsOptions()); + grpc::SslCredentialsOptions ssl_options; + if (!options.tls_root_certs.empty()) { + ssl_options.pem_root_certs = options.tls_root_certs; + } + creds = grpc::SslCredentials(ssl_options); } else { creds = grpc::InsecureChannelCredentials(); } @@ -401,8 +404,13 @@ FlightClient::~FlightClient() {} Status FlightClient::Connect(const Location& location, std::unique_ptr* client) { + return Connect(location, {}, client); +} + +Status FlightClient::Connect(const Location& location, const FlightClientOptions& options, + std::unique_ptr* client) { client->reset(new FlightClient); - return (*client)->impl_->Connect(location); + return (*client)->impl_->Connect(location, options); } Status FlightClient::Authenticate(const FlightCallOptions& options, diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index f48f3edf839..276ffc70212 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -57,6 +57,11 @@ class ARROW_EXPORT FlightCallOptions { TimeoutDuration timeout; }; +class ARROW_EXPORT FlightClientOptions { + public: + std::string tls_root_certs; +}; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_EXPORT FlightClient { @@ -70,6 +75,15 @@ class ARROW_EXPORT FlightClient { /// successful static Status Connect(const Location& location, std::unique_ptr* client); + /// \brief Connect to an unauthenticated flight service + /// \param[in] location the URI + /// \param[in] options Other options for setting up the client + /// \param[out] client the created FlightClient + /// \return Status OK status may not indicate that the connection was + /// successful + static Status Connect(const Location& location, const FlightClientOptions& options, + std::unique_ptr* client); + /// \brief Authenticate to the server using the given handler. /// \param[in] options Per-RPC options /// \param[in] auth_handler The authentication mechanism to use diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index d99e7c4b020..e3e01df5037 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -255,24 +255,25 @@ class DoPutTestServer : public FlightServerBase { class TestAuthHandler : public ::testing::Test { public: void SetUp() { - port_ = 30000; - server_.reset(new InProcessTestServer( - std::unique_ptr(new AuthTestServer), port_)); - ASSERT_OK(server_->Start(std::unique_ptr( - new TestServerAuthHandler("user", "p4ssw0rd")))); + Location location; + std::unique_ptr server(new AuthTestServer); + + ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location)); + FlightServerOptions options(location); + options.auth_handler = + std::unique_ptr(new TestServerAuthHandler("user", "p4ssw0rd")); + ASSERT_OK(server->Init(options)); + + server_.reset(new InProcessTestServer(std::move(server), location)); + ASSERT_OK(server_->Start()); ASSERT_OK(ConnectClient()); } void TearDown() { server_->Stop(); } - Status ConnectClient() { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); - return FlightClient::Connect(location, &client_); - } + Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); } protected: - int port_; std::unique_ptr client_; std::unique_ptr server_; }; @@ -280,17 +281,21 @@ class TestAuthHandler : public ::testing::Test { class TestDoPut : public ::testing::Test { public: void SetUp() { - port_ = 30000; + Location location; + ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location)); + do_put_server_ = new DoPutTestServer(); server_.reset(new InProcessTestServer( - std::unique_ptr(do_put_server_), port_)); - ASSERT_OK(server_->Start({})); + std::unique_ptr(do_put_server_), location)); + FlightServerOptions options(location); + ASSERT_OK(do_put_server_->Init(options)); + ASSERT_OK(server_->Start()); ASSERT_OK(ConnectClient()); } void TearDown() { server_->Stop(); } - Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); } + Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); } void CheckBatches(FlightDescriptor expected_descriptor, const BatchVector& expected_batches) { @@ -314,7 +319,6 @@ class TestDoPut : public ::testing::Test { } protected: - int port_; std::unique_ptr client_; std::unique_ptr server_; DoPutTestServer* do_put_server_; diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc index f65cbd15abf..d1131175422 100644 --- a/cpp/src/arrow/flight/perf-server.cc +++ b/cpp/src/arrow/flight/perf-server.cc @@ -207,8 +207,9 @@ int main(int argc, char** argv) { arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); - ARROW_CHECK_OK( - g_server->Init(std::unique_ptr(), location)); + arrow::flight::FlightServerOptions options(location); + + ARROW_CHECK_OK(g_server->Init(options)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); std::cout << "Server port: " << FLAGS_port << std::endl; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index d8c5525ee98..46dd5bf0119 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -438,24 +439,38 @@ void FlightServerBase::Impl::HandleSignal(int signum) { } } +FlightServerOptions::FlightServerOptions(const Location& location_) + : location(location_), auth_handler(nullptr) {} + FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } FlightServerBase::~FlightServerBase() {} -Status FlightServerBase::Init(std::unique_ptr auth_handler, - const Location& location) { - std::shared_ptr handler = std::move(auth_handler); +Status FlightServerBase::Init(FlightServerOptions& options) { + std::shared_ptr handler = std::move(options.auth_handler); impl_->service_.reset(new FlightServiceImpl(handler, this)); grpc::ServerBuilder builder; // Allow uploading messages of any length builder.SetMaxReceiveMessageSize(-1); + const Location& location = options.location; const std::string scheme = location.scheme(); - if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) { + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { std::stringstream address; address << location.uri_->host() << ':' << location.uri_->port_text(); - builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials()); + + std::shared_ptr creds; + if (scheme == kSchemeGrpcTls) { + grpc::SslServerCredentialsOptions ssl_options; + ssl_options.pem_key_cert_pairs.push_back( + {options.tls_private_key, options.tls_cert_chain}); + creds = grpc::SslServerCredentials(ssl_options); + } else { + creds = grpc::InsecureServerCredentials(); + } + + builder.AddListeningPort(address.str(), creds); } else if (scheme == kSchemeGrpcUnix) { std::stringstream address; address << "unix:" << location.uri_->path(); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 752b89e85bf..28e87aa2fa9 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -24,6 +24,7 @@ #include #include +#include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" // IWYU pragma: keep #include "arrow/ipc/dictionary.h" #include "arrow/memory_pool.h" @@ -38,8 +39,6 @@ class Status; namespace flight { -class ServerAuthHandler; - /// \brief Interface that produces a sequence of IPC payloads to be sent in /// FlightData protobuf messages class ARROW_EXPORT FlightDataStream { @@ -89,6 +88,16 @@ class ARROW_EXPORT ServerCallContext { virtual const std::string& peer_identity() const = 0; }; +class ARROW_EXPORT FlightServerOptions { + public: + explicit FlightServerOptions(const Location& location_); + + Location location; + std::unique_ptr auth_handler; + std::string tls_cert_chain; + std::string tls_private_key; +}; + /// \brief Skeleton RPC server implementation which can be used to create /// custom servers by implementing its abstract methods class ARROW_EXPORT FlightServerBase { @@ -100,10 +109,8 @@ class ARROW_EXPORT FlightServerBase { /// \brief Initialize a Flight server listening at the given location. /// This method must be called before any other method. - /// \param[in] auth_handler The authentication handler. May be - /// nullptr if no authentication is desired. - /// \param[in] location The location to serve on. - Status Init(std::unique_ptr auth_handler, const Location& location); + /// \param[in] options The configuration for this server. + Status Init(FlightServerOptions& options); /// \brief Set the server to stop when receiving any of the given signal /// numbers. diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index d72e838213b..c5bb180663b 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -126,7 +126,9 @@ int main(int argc, char** argv) { g_server.reset(new arrow::flight::FlightIntegrationTestServer); arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); - ARROW_CHECK_OK(g_server->Init(nullptr, location)); + arrow::flight::FlightServerOptions options(location); + + ARROW_CHECK_OK(g_server->Init(options)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc index 29af87db601..f72fd3caeea 100644 --- a/cpp/src/arrow/flight/test-server.cc +++ b/cpp/src/arrow/flight/test-server.cc @@ -139,10 +139,12 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); g_server.reset(new arrow::flight::FlightTestServer); + arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); - ARROW_CHECK_OK( - g_server->Init(std::unique_ptr(), location)); + arrow::flight::FlightServerOptions options(location); + + ARROW_CHECK_OK(g_server->Init(options)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index 01577958640..f870dcbfa71 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -118,22 +118,17 @@ bool TestServer::IsRunning() { return server_process_->running(); } int TestServer::port() const { return port_; } -Status InProcessTestServer::Start(std::unique_ptr auth_handler) { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); - RETURN_NOT_OK(server_->Init(std::move(auth_handler), location)); +Status InProcessTestServer::Start() { thread_ = std::thread([this]() { ARROW_EXPECT_OK(server_->Serve()); }); return Status::OK(); } -Status InProcessTestServer::Start() { return Start({}); } - void InProcessTestServer::Stop() { server_->Shutdown(); thread_.join(); } -int InProcessTestServer::port() const { return port_; } +const Location& InProcessTestServer::location() const { return location_; } InProcessTestServer::~InProcessTestServer() { // Make sure server shuts down properly diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h index ef18e3721b6..b5bc31a4c3b 100644 --- a/cpp/src/arrow/flight/test-util.h +++ b/cpp/src/arrow/flight/test-util.h @@ -63,17 +63,17 @@ class ARROW_EXPORT TestServer { class ARROW_EXPORT InProcessTestServer { public: - explicit InProcessTestServer(std::unique_ptr server, int port) - : server_(std::move(server)), port_(port), thread_() {} + explicit InProcessTestServer(std::unique_ptr server, + const Location& location) + : server_(std::move(server)), location_(location), thread_() {} ~InProcessTestServer(); Status Start(); - Status Start(std::unique_ptr auth_handler); void Stop(); - int port() const; + const Location& location() const; private: std::unique_ptr server_; - int port_; + Location location_; std::thread thread_; }; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 985cb5fc03a..524b0ee716e 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -102,7 +102,7 @@ Status Location::ForGrpcUnix(const std::string& path, Location* location) { return Location::Parse(uri_string.str(), location); } -std::string Location::ToString() const { return uri_->to_string(); } +std::string Location::ToString() const { return uri_->ToString(); } std::string Location::scheme() const { std::string scheme = uri_->scheme(); if (scheme.empty()) { diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc index e79c9826b9b..579d4c7f243 100644 --- a/cpp/src/arrow/util/uri.cc +++ b/cpp/src/arrow/util/uri.cc @@ -120,7 +120,7 @@ std::string Uri::path() const { return ss.str(); } -const std::string& Uri::to_string() const { return impl_->string_rep_; } +const std::string& Uri::ToString() const { return impl_->string_rep_; } Status Uri::Parse(const std::string& uri_string) { impl_->Reset(); diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index 7327c2d7876..ce082ccc8e6 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -54,8 +54,9 @@ class ARROW_EXPORT Uri { int32_t port() const; /// The URI path component. std::string path() const; + /// Get the string representation of this URI. - const std::string& to_string() const; + const std::string& ToString() const; /// Factory function to parse a URI from its string representation. Status Parse(const std::string& uri_string); diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py index 0c91e3ec55f..15db5124c42 100644 --- a/python/examples/flight/client.py +++ b/python/examples/flight/client.py @@ -90,6 +90,8 @@ def get_flight(args, client): def _add_common_arguments(parser): + parser.add_argument('--tls', action='store_true') + parser.add_argument('--tls-roots', default=None) parser.add_argument('host', type=str, help="The host to connect to.") @@ -131,7 +133,15 @@ def main(): } host, port = args.host.split(':') port = int(port) - client = pyarrow.flight.FlightClient.connect(f"grpc://{host}:{port}") + scheme = "grpc+tcp" + connection_args = {} + if args.tls: + scheme = "grpc+tls" + if args.tls_roots: + with open(args.tls_roots, "rb") as root_certs: + connection_args["tls_root_certs"] = root_certs.read() + client = pyarrow.flight.FlightClient.connect(f"{scheme}://{host}:{port}", + **connection_args) while True: try: action = pyarrow.flight.Action("healthcheck", b"") diff --git a/python/examples/flight/server.py b/python/examples/flight/server.py index e6edb8d3a1b..5a3d101c3a9 100644 --- a/python/examples/flight/server.py +++ b/python/examples/flight/server.py @@ -17,6 +17,7 @@ """An example Flight Python server.""" +import argparse import ast import threading import time @@ -76,7 +77,7 @@ def do_get(self, context, ticket): return None return pyarrow.flight.RecordBatchStream(self.flights[key]) - def list_actions(self, context): + def list_actions(self): return [ ("clear", "Clear the stored flights."), ("shutdown", "Shut down this server."), @@ -100,8 +101,25 @@ def _shutdown(self): def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=5005) + parser.add_argument("--tls", nargs=2, default=None) + + args = parser.parse_args() + server = FlightServer() - server.run(5005) + kwargs = {} + scheme = "grpc+tcp" + if args.tls: + scheme = "grpc+tls" + with open(args.tls[0], "rb") as cert_file: + kwargs["tls_cert_chain"] = cert_file.read() + with open(args.tls[1], "rb") as key_file: + kwargs["tls_private_key"] = key_file.read() + + location = "{}://0.0.0.0:{}".format(scheme, args.port) + print("Serving on", location) + server.run(location, **kwargs) if __name__ == '__main__': diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 05d50cd629f..430f2b2abc5 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -72,7 +72,7 @@ cdef class Action: return pyarrow_wrap_buffer(self.action.body) @staticmethod - cdef CAction unwrap(action): + cdef CAction unwrap(action) except *: if not isinstance(action, Action): raise TypeError("Must provide Action, not '{}'".format( type(action))) @@ -168,11 +168,11 @@ cdef class FlightDescriptor: return "".format(self.descriptor_type) @staticmethod - cdef FlightDescriptor unwrap(descriptor): + cdef CFlightDescriptor unwrap(descriptor) except *: if not isinstance(descriptor, FlightDescriptor): raise TypeError("Must provide a FlightDescriptor, not '{}'".format( type(descriptor))) - return descriptor + return ( descriptor).descriptor cdef class Ticket: @@ -225,9 +225,9 @@ cdef class Location: return result @staticmethod - cdef CLocation unwrap(object location): + cdef CLocation unwrap(object location) except *: cdef CLocation c_location - if isinstance(location, (six.text_type, six.binary_type)): + if isinstance(location, six.text_type): check_status(CLocation.Parse(tobytes(location), &c_location)) return c_location elif not isinstance(location, Location): @@ -381,21 +381,20 @@ cdef class FlightClient: .format(self.__class__.__name__)) @staticmethod - def connect(location): + def connect(location, **kwargs): """Connect to a Flight service on the given host and port.""" cdef: FlightClient result = FlightClient.__new__(FlightClient) int c_port = 0 - CLocation c_location + CLocation c_location = Location.unwrap(location) + CFlightClientOptions c_options - if isinstance(location, Location): - c_location = ( location).location - else: - c_location = CLocation() - check_status(CLocation.Parse(tobytes(location), &c_location)) + if "tls_root_certs" in kwargs: + c_options.tls_root_certs = tobytes(kwargs["tls_root_certs"]) with nogil: - check_status(CFlightClient.Connect(c_location, &result.client)) + check_status(CFlightClient.Connect(c_location, c_options, + &result.client)) return result @@ -478,12 +477,12 @@ cdef class FlightClient: cdef: FlightInfo result = FlightInfo.__new__(FlightInfo) CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) - FlightDescriptor c_descriptor = \ + CFlightDescriptor c_descriptor = \ FlightDescriptor.unwrap(descriptor) with nogil: check_status(self.client.get().GetFlightInfo( - deref(c_options), c_descriptor.descriptor, &result.info)) + deref(c_options), c_descriptor, &result.info)) return result @@ -494,7 +493,8 @@ cdef class FlightClient: CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) with nogil: - check_status(self.client.get().DoGet(deref(c_options), ticket.ticket, &reader)) + check_status(self.client.get().DoGet( + deref(c_options), ticket.ticket, &reader)) result = FlightRecordBatchReader() result.reader.reset(reader.release()) return result @@ -506,12 +506,12 @@ cdef class FlightClient: shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) unique_ptr[CRecordBatchWriter] writer CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) - FlightDescriptor c_descriptor = \ + CFlightDescriptor c_descriptor = \ FlightDescriptor.unwrap(descriptor) with nogil: check_status(self.client.get().DoPut( - deref(c_options), c_descriptor.descriptor, c_schema, &writer)) + deref(c_options), c_descriptor, c_schema, &writer)) result = FlightRecordBatchWriter() result.writer.reset(writer.release()) return result @@ -520,7 +520,7 @@ cdef class FlightClient: cdef class FlightDataStream: """Abstract base class for Flight data streams.""" - cdef CFlightDataStream* to_stream(self): + cdef CFlightDataStream* to_stream(self) except *: """Create the C++ data stream for the backing Python object. We don't expose the C++ object to Python, so we can manage its @@ -547,7 +547,7 @@ cdef class RecordBatchStream(FlightDataStream): "but got: {}".format(type(data_source))) self.data_source = data_source - cdef CFlightDataStream* to_stream(self): + cdef CFlightDataStream* to_stream(self) except *: cdef: shared_ptr[CRecordBatchReader] reader if isinstance(self.data_source, _CRecordBatchReader): @@ -585,7 +585,7 @@ cdef class GeneratorStream(FlightDataStream): self.schema = pyarrow_unwrap_schema(schema) self.generator = iter(generator) - cdef CFlightDataStream* to_stream(self): + cdef CFlightDataStream* to_stream(self) except *: cdef: function[cb_data_stream_next] callback = &_data_stream_next return new CPyGeneratorFlightDataStream(self, self.schema, callback) @@ -985,20 +985,30 @@ cdef class FlightServerBase: cdef: unique_ptr[PyFlightServer] server - def run(self, location, auth_handler=None): + def run(self, location, auth_handler=None, **kwargs): cdef: PyFlightServerVtable vtable = PyFlightServerVtable() PyFlightServer* c_server - CLocation c_location = Location.unwrap(location) - unique_ptr[CServerAuthHandler] c_auth_handler + unique_ptr[CFlightServerOptions] c_options + + c_options.reset(new CFlightServerOptions(Location.unwrap(location))) if auth_handler: if not isinstance(auth_handler, ServerAuthHandler): raise TypeError("auth_handler must be a ServerAuthHandler, " "not a '{}'".format(type(auth_handler))) - c_auth_handler.reset( + c_options.get().auth_handler.reset( ( auth_handler).to_handler()) + if "tls_cert_chain" in kwargs: + if "tls_private_key" not in kwargs: + raise ValueError( + "Must provide both cert chain and private key") + c_options.get().tls_cert_chain = tobytes( + kwargs["tls_cert_chain"]) + c_options.get().tls_private_key = tobytes( + kwargs["tls_private_key"]) + vtable.list_flights = &_list_flights vtable.get_flight_info = &_get_flight_info vtable.do_put = &_do_put @@ -1009,7 +1019,7 @@ cdef class FlightServerBase: c_server = new PyFlightServer(self, vtable) self.server.reset(c_server) with nogil: - check_status(c_server.Init(move(c_auth_handler), c_location)) + check_status(c_server.Init(deref(c_options))) check_status(c_server.ServeWithSignals()) def list_flights(self, context, criteria): diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index fe2075e3675..d3ff4f259b0 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -153,9 +153,21 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightCallOptions() CTimeoutDuration timeout + cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions": + CFlightServerOptions(const CLocation& location) + CLocation location + unique_ptr[CServerAuthHandler] auth_handler + c_string tls_cert_chain + c_string tls_private_key + + cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions": + CFlightClientOptions() + c_string tls_root_certs + cdef cppclass CFlightClient" arrow::flight::FlightClient": @staticmethod CStatus Connect(const CLocation& location, + const CFlightClientOptions& options, unique_ptr[CFlightClient]* client) CStatus Authenticate(CFlightCallOptions& options, @@ -229,8 +241,7 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: cdef cppclass PyFlightServer: PyFlightServer(object server, PyFlightServerVtable vtable) - CStatus Init(unique_ptr[CServerAuthHandler] auth_handler, - const CLocation& location) + CStatus Init(CFlightServerOptions& options) CStatus ServeWithSignals() except * void Shutdown() diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 03e7a40fe1f..acdaacfaafe 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -397,3 +397,13 @@ def test_token_auth_invalid(): client = flight.FlightClient.connect(server_location) with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"): client.authenticate(TokenClientAuthHandler('test', 'wrong')) + + +def test_location_invalid(): + """Test constructing invalid URIs.""" + with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"): + flight.FlightClient.connect("%") + + server = ConstantFlightServer() + with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"): + server.run("%") From 5c127631b41986afa5a54d55512d677f95ecf39c Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 13 May 2019 15:57:43 -0400 Subject: [PATCH 4/5] Make Python Flight bindings more complete --- cpp/src/arrow/flight/types.cc | 4 ++ cpp/src/arrow/flight/types.h | 4 +- cpp/src/arrow/python/flight.cc | 2 +- cpp/src/arrow/python/flight.h | 2 +- python/pyarrow/_flight.pyx | 23 +++++++- python/pyarrow/includes/libarrow_flight.pxd | 9 +-- python/pyarrow/tests/test_flight.py | 64 +++++++++++++++++++++ 7 files changed, 99 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 524b0ee716e..dadb51066cf 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -112,6 +112,10 @@ std::string Location::scheme() const { return scheme; } +bool Location::Equals(const Location& other) const { + return ToString() == other.ToString(); +} + SimpleFlightListing::SimpleFlightListing(const std::vector& flights) : position_(0), flights_(flights) {} diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 9caee997ae1..ba07b999635 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -157,8 +157,10 @@ struct Location { /// \brief Get the scheme of this URI. std::string scheme() const; + bool Equals(const Location& other) const; + friend bool operator==(const Location& left, const Location& right) { - return left.ToString() == right.ToString(); + return left.Equals(right); } friend bool operator!=(const Location& left, const Location& right) { return !(left == right); diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index b033b341ff2..4db31570d81 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -217,7 +217,7 @@ Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) { Status CreateFlightInfo(const std::shared_ptr& schema, const arrow::flight::FlightDescriptor& descriptor, const std::vector& endpoints, - uint64_t total_records, uint64_t total_bytes, + int64_t total_records, int64_t total_bytes, std::unique_ptr* out) { arrow::flight::FlightInfo::Data flight_data; RETURN_NOT_OK(arrow::flight::internal::SchemaToString(*schema, &flight_data.schema)); diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index 19fbb02c592..432885cb764 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -197,7 +197,7 @@ ARROW_PYTHON_EXPORT Status CreateFlightInfo(const std::shared_ptr& schema, const arrow::flight::FlightDescriptor& descriptor, const std::vector& endpoints, - uint64_t total_records, uint64_t total_bytes, + int64_t total_records, int64_t total_bytes, std::unique_ptr* out); } // namespace flight diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 430f2b2abc5..82745d34c8c 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -184,6 +184,10 @@ cdef class Ticket: def __init__(self, ticket): self.ticket.ticket = tobytes(ticket) + @property + def ticket(self): + return self.ticket.ticket + def __repr__(self): return ''.format(self.ticket.ticket) @@ -199,6 +203,18 @@ cdef class Location: def __repr__(self): return ''.format(self.location.ToString()) + @property + def uri(self): + return self.location.ToString() + + def equals(self, Location other): + return self == other + + def __eq__(self, other): + if not isinstance(other, Location): + return NotImplemented + return self.location.Equals(( other).location) + @staticmethod def for_grpc_tcp(host, port): """Create a Location for a TCP-based gRPC service.""" @@ -265,8 +281,11 @@ cdef class FlightEndpoint: self.endpoint.ticket.ticket = tobytes(ticket) for location in locations: - c_location = CLocation() - check_status(CLocation.Parse(tobytes(location), &c_location)) + if isinstance(location, Location): + c_location = ( location).location + else: + c_location = CLocation() + check_status(CLocation.Parse(tobytes(location), &c_location)) self.endpoint.locations.push_back(c_location) @property diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index d3ff4f259b0..4b749903d3d 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -81,6 +81,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CLocation" arrow::flight::Location": CLocation() c_string ToString() + c_bool Equals(const CLocation& other) @staticmethod CStatus Parse(c_string& uri_string, CLocation* location) @@ -97,8 +98,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CFlightInfo" arrow::flight::FlightInfo": CFlightInfo(CFlightInfo info) - uint64_t total_records() - uint64_t total_bytes() + int64_t total_records() + int64_t total_bytes() CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) CFlightDescriptor& descriptor() const vector[CFlightEndpoint]& endpoints() @@ -274,8 +275,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: shared_ptr[CSchema] schema, CFlightDescriptor& descriptor, vector[CFlightEndpoint] endpoints, - uint64_t total_records, - uint64_t total_bytes, + int64_t total_records, + int64_t total_bytes, unique_ptr[CFlightInfo]* out) cdef extern from "" namespace "std": diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index acdaacfaafe..bbe5507ffdd 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -103,6 +103,42 @@ def do_action(self, context, action): raise NotImplementedError +class GetInfoFlightServer(flight.FlightServerBase): + """A Flight server that tests GetFlightInfo.""" + + def get_flight_info(self, context, descriptor): + return flight.FlightInfo( + pa.schema([('a', pa.int32())]), + descriptor, + [ + flight.FlightEndpoint(b'', ['grpc://test']), + flight.FlightEndpoint( + b'', + [flight.Location.for_grpc_tcp('localhost', 5005)], + ), + ], + -1, + -1, + ) + + +class CheckTicketFlightServer(flight.FlightServerBase): + """A Flight server that compares the given ticket to an expected value.""" + + def __init__(self, expected_ticket): + super(CheckTicketFlightServer, self).__init__() + self.expected_ticket = expected_ticket + + def do_get(self, context, ticket): + assert self.expected_ticket == ticket.ticket + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + return flight.RecordBatchStream(table) + + def do_put(self, context, descriptor, reader): + self.last_message = reader.read_all() + + class InvalidStreamFlightServer(flight.FlightServerBase): """A Flight server that tries to return messages with differing schemas.""" @@ -255,6 +291,34 @@ def test_flight_do_get_dicts(): assert data.equals(table) +def test_flight_do_get_ticket(): + """Make sure Tickets get passed to the server.""" + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + with flight_server( + CheckTicketFlightServer, + expected_ticket=b'the-ticket', + ) as server_location: + client = flight.FlightClient.connect(server_location) + data = client.do_get(flight.Ticket(b'the-ticket')).read_all() + assert data.equals(table) + + +def test_flight_get_info(): + """Make sure FlightEndpoint accepts string and object URIs.""" + with flight_server(GetInfoFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) + info = client.get_flight_info(flight.FlightDescriptor.for_command(b'')) + assert info.total_records == -1 + assert info.total_bytes == -1 + assert info.schema == pa.schema([('a', pa.int32())]) + assert len(info.endpoints) == 2 + assert len(info.endpoints[0].locations) == 1 + assert info.endpoints[0].locations[0] == flight.Location('grpc://test') + assert info.endpoints[1].locations[0] == \ + flight.Location.for_grpc_tcp('localhost', 5005) + + def test_flight_domain_socket(): """Try a simple do_get call over a domain socket.""" data = [ From 870f6ebb8d07797a7a408249cc350f1baf7f14a1 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 15 May 2019 08:35:13 -0400 Subject: [PATCH 5/5] Add more builder options for Java Flight servers --- .../arrow/flight/test-integration-client.cc | 1 - java/flight/pom.xml | 56 +++++++ .../org/apache/arrow/flight/FlightClient.java | 108 ++++++++++++- .../org/apache/arrow/flight/FlightServer.java | 144 +++++++++++++++--- .../org/apache/arrow/flight/Location.java | 66 +++++++- .../apache/arrow/flight/LocationSchemes.java | 6 +- .../apache/arrow/flight/FlightTestUtil.java | 22 +++ .../arrow/flight/TestBasicOperation.java | 2 +- .../arrow/flight/TestServerOptions.java | 63 ++++++++ .../apache/arrow/flight/auth/TestAuth.java | 2 +- python/pyarrow/_flight.pyx | 31 ++-- python/pyarrow/tests/test_flight.py | 4 +- 12 files changed, 455 insertions(+), 50 deletions(-) create mode 100644 java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index d0b734007c4..abaa3bc4221 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -102,7 +102,6 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); std::unique_ptr client; - std::stringstream uri; arrow::flight::Location location; ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); diff --git a/java/flight/pom.xml b/java/flight/pom.xml index c9c06b497d7..7d01a6e118e 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -73,6 +73,16 @@ io.netty netty-buffer + + io.netty + netty-handler + ${dep.netty.version} + + + io.netty + netty-transport + ${dep.netty.version} + com.google.guava guava @@ -302,4 +312,50 @@ + + + linux-netty-native + + + linux + + + + + io.netty + netty-transport-native-unix-common + ${dep.netty.version} + ${os.detected.name}-${os.detected.arch} + + + io.netty + netty-transport-native-epoll + ${dep.netty.version} + ${os.detected.name}-${os.detected.arch} + + + + + mac-netty-native + + + mac + + + + + io.netty + netty-transport-native-unix-common + ${dep.netty.version} + ${os.detected.name}-${os.detected.arch} + + + io.netty + netty-transport-native-kqueue + ${dep.netty.version} + ${os.detected.name}-${os.detected.arch} + + + + diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 3a6fbae87bb..e74c0eefb69 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -20,11 +20,14 @@ import static io.grpc.stub.ClientCalls.asyncClientStreamingCall; import static io.grpc.stub.ClientCalls.asyncServerStreamingCall; +import java.io.InputStream; import java.net.URISyntaxException; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import javax.net.ssl.SSLException; + import org.apache.arrow.flight.auth.BasicClientAuthHandler; import org.apache.arrow.flight.auth.ClientAuthHandler; import org.apache.arrow.flight.auth.ClientAuthInterceptor; @@ -50,13 +53,19 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; + /** - * Client for flight servers. + * Client for Flight services. */ public class FlightClient implements AutoCloseable { private static final int PENDING_REQUESTS = 5; @@ -315,6 +324,13 @@ public void close() throws InterruptedException { allocator.close(); } + /** + * Create a builder for a Flight client. + */ + public static Builder builder() { + return new Builder(); + } + /** * Create a builder for a Flight client. * @param allocator The allocator to use for the client. @@ -329,13 +345,20 @@ public static Builder builder(BufferAllocator allocator, Location location) { */ public static final class Builder { - private final BufferAllocator allocator; - private final Location location; + private BufferAllocator allocator; + private Location location; private boolean forceTls = false; + private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE; + private InputStream trustedCertificates = null; + private InputStream clientCertificate = null; + private InputStream clientKey = null; + + private Builder() { + } private Builder(BufferAllocator allocator, Location location) { - this.allocator = allocator; - this.location = location; + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); } /** @@ -346,6 +369,37 @@ public Builder useTls() { return this; } + /** Set the maximum inbound message size. */ + public Builder maxInboundMessageSize(int maxSize) { + Preconditions.checkArgument(maxSize > 0); + this.maxInboundMessageSize = maxSize; + return this; + } + + /** Set the trusted TLS certificates. */ + public Builder trustedCertificates(final InputStream stream) { + this.trustedCertificates = Preconditions.checkNotNull(stream); + return this; + } + + /** Set the trusted TLS certificates. */ + public Builder clientCertificate(final InputStream clientCertificate, final InputStream clientKey) { + Preconditions.checkNotNull(clientKey); + this.clientCertificate = Preconditions.checkNotNull(clientCertificate); + this.clientKey = Preconditions.checkNotNull(clientKey); + return this; + } + + public Builder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public Builder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + /** * Create the client from this builder. */ @@ -356,7 +410,32 @@ public FlightClient build() { case LocationSchemes.GRPC: case LocationSchemes.GRPC_INSECURE: case LocationSchemes.GRPC_TLS: { - builder = NettyChannelBuilder.forAddress(location.getUri().getHost(), location.getUri().getPort()); + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + break; + } + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + (Class) Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .newInstance(); + builder.eventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + (Class) Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .newInstance(); + builder.eventLoopGroup(elg); + } + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } break; } default: @@ -365,13 +444,28 @@ public FlightClient build() { if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { builder.useTransportSecurity(); + + if (this.trustedCertificates != null || this.clientCertificate != null || this.clientKey != null) { + final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + if (this.trustedCertificates != null) { + sslContextBuilder.trustManager(this.trustedCertificates); + } + if (this.clientCertificate != null && this.clientKey != null) { + sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); + } + try { + builder.sslContext(sslContextBuilder.build()); + } catch (SSLException e) { + throw new RuntimeException(e); + } + } } else { builder.usePlaintext(); } builder .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) - .maxInboundMessageSize(FlightServer.MAX_GRPC_MESSAGE_SIZE); + .maxInboundMessageSize(maxInboundMessageSize); return new FlightClient(allocator, builder.build()); } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java index 8c9ae1c2e4e..0812a04637d 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -19,22 +19,26 @@ import java.io.File; import java.io.FileInputStream; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; -import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executor; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.auth.ServerAuthHandler; import org.apache.arrow.flight.auth.ServerAuthInterceptor; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.util.Preconditions; import io.grpc.Server; import io.grpc.ServerInterceptors; import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; + public class FlightServer implements AutoCloseable { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class); @@ -88,45 +92,80 @@ public void close() throws InterruptedException { } } + /** Create a builder for a Flight server. */ + public static Builder builder() { + return new Builder(); + } + + /** Create a builder for a Flight server. */ public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) { return new Builder(allocator, location, producer); } + /** A builder for Flight servers. */ public static final class Builder { - - private final BufferAllocator allocator; - private final Location location; - private final FlightProducer producer; + private BufferAllocator allocator; + private Location location; + private FlightProducer producer; + private final Map builderOptions; private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP; - + private Executor executor = null; + private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE; private InputStream certChain; private InputStream key; + Builder() { + builderOptions = new HashMap<>(); + } + Builder(BufferAllocator allocator, Location location, FlightProducer producer) { - this.allocator = allocator; - this.location = location; - this.producer = producer; + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); + this.producer = Preconditions.checkNotNull(producer); + builderOptions = new HashMap<>(); } + /** Create the server for this builder. */ public FlightServer build() { final NettyServerBuilder builder; switch (location.getUri().getScheme()) { case LocationSchemes.GRPC_DOMAIN_SOCKET: { - // TODO: need reflection to check if domain sockets are available - throw new UnsupportedOperationException("Domain sockets are not available."); + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + (Class) Class + .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .newInstance(); + builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + (Class) Class + .forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .newInstance(); + builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); + } + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } + break; } case LocationSchemes.GRPC: case LocationSchemes.GRPC_INSECURE: { - builder = NettyServerBuilder - .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); break; } case LocationSchemes.GRPC_TLS: { if (certChain == null) { throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS"); } - builder = NettyServerBuilder - .forAddress(new InetSocketAddress(location.getUri().getHost(), location.getUri().getPort())); + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); break; } default: @@ -137,15 +176,33 @@ public FlightServer build() { builder.useTransportSecurity(certChain, key); } - final Server server = builder - .executor(new ForkJoinPool()) - .maxInboundMessageSize(MAX_GRPC_MESSAGE_SIZE) + builder + .executor(executor != null ? executor : new ForkJoinPool()) + .maxInboundMessageSize(maxInboundMessageSize) .addService( ServerInterceptors.intercept( new FlightBindingService(allocator, producer, authHandler), - new ServerAuthInterceptor(authHandler))) - .build(); - return new FlightServer(server); + new ServerAuthInterceptor(authHandler))); + + // Allow setting some Netty-specific options + builderOptions.computeIfPresent("netty.bossEventLoopGroup", (key, elg) -> { + builder.bossEventLoopGroup((EventLoopGroup) elg); + return null; + }); + builderOptions.computeIfPresent("netty.workerEventLoopGroup", (key, elg) -> { + builder.workerEventLoopGroup((EventLoopGroup) elg); + return null; + }); + + return new FlightServer(builder.build()); + } + + /** + * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying transport. + */ + public Builder maxInboundMessageSize(int maxMessageSize) { + this.maxInboundMessageSize = maxMessageSize; + return this; } /** @@ -159,15 +216,54 @@ public Builder useTls(final File certChain, final File key) throws IOException { return this; } + /** + * Enable TLS on the server. + * @param certChain The certificate chain to use. + * @param key The private key to use. + */ public Builder useTls(final InputStream certChain, final InputStream key) { this.certChain = certChain; this.key = key; return this; } - public Builder setAuthHandler(ServerAuthHandler authHandler) { + /** + * Set the executor used by the server. + */ + public Builder executor(Executor executor) { + this.executor = executor; + return this; + } + + /** + * Set the authentication handler. + */ + public Builder authHandler(ServerAuthHandler authHandler) { this.authHandler = authHandler; return this; } + + /** + * Provide a transport-specific option. Not guaranteed to have any effect. + */ + public Builder transportHint(final String key, Object option) { + builderOptions.put(key, option); + return this; + } + + public Builder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public Builder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + + public Builder producer(FlightProducer producer) { + this.producer = Preconditions.checkNotNull(producer); + return this; + } } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/src/main/java/org/apache/arrow/flight/Location.java index 377edeba747..b4169fc9030 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/Location.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/Location.java @@ -17,6 +17,9 @@ package org.apache.arrow.flight; +import java.lang.reflect.InvocationTargetException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -30,21 +33,69 @@ public class Location { * Constructs a new instance. * * @param uri the URI of the Flight service + * @throws IllegalArgumentException if the URI scheme is unsupported */ public Location(String uri) throws URISyntaxException { - super(); - this.uri = new URI(uri); + this(new URI(uri)); } + /** + * Construct a new instance from an existing URI. + * + * @param uri the URI of the Flight service + * @throws IllegalArgumentException if the URI scheme is unsupported + */ public Location(URI uri) { super(); this.uri = uri; + // Validate the scheme + switch (uri.getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_DOMAIN_SOCKET: + case LocationSchemes.GRPC_INSECURE: + case LocationSchemes.GRPC_TLS: { + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + this.uri); + } } public URI getUri() { return uri; } + /** + * Helper method to turn this Location into a SocketAddress. + * + * @return null if could not be converted + */ + SocketAddress toSocketAddress() { + switch (uri.getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_TLS: + case LocationSchemes.GRPC_INSECURE: { + return new InetSocketAddress(uri.getHost(), uri.getPort()); + } + + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + try { + // This dependency is not available on non-Unix platforms. + return (SocketAddress) Class.forName("io.netty.channel.unix.DomainSocketAddress") + .getConstructor(String.class) + .newInstance(uri.getPath()); + } catch (InstantiationException | ClassNotFoundException | InvocationTargetException | + NoSuchMethodException | IllegalAccessException e) { + return null; + } + } + + default: { + return null; + } + } + } + /** * Convert this Location into its protocol-level representation. */ @@ -69,4 +120,15 @@ public static Location forGrpcTls(String host, int port) { throw new IllegalArgumentException(e); } } + + /** + * Construct a URI for a Flight+gRPC server over a Unix domain socket. + */ + public static Location forGrpcDomainSocket(String path) { + try { + return new Location(new URI(LocationSchemes.GRPC_DOMAIN_SOCKET, null, path, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java index b652fed48a7..872e5b1c22d 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/LocationSchemes.java @@ -20,9 +20,13 @@ /** * Constants representing well-known URI schemes for Flight services. */ -public class LocationSchemes { +public final class LocationSchemes { public static final String GRPC = "grpc"; public static final String GRPC_INSECURE = "grpc+tcp"; public static final String GRPC_DOMAIN_SOCKET = "grpc+unix"; public static final String GRPC_TLS = "grpc+tls"; + + private LocationSchemes() { + throw new AssertionError("Do not instantiate this class."); + } } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java index 11b476ccceb..f6b9e867807 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -62,6 +62,28 @@ public static T getStartedServer(Function newServerFromPort) thr return server; } + static boolean isEpollAvailable() { + try { + Class epoll = Class.forName("io.netty.channel.epoll.Epoll"); + return (Boolean) epoll.getMethod("isAvailable").invoke(null); + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + return false; + } + } + + static boolean isKqueueAvailable() { + try { + Class kqueue = Class.forName("io.netty.channel.kqueue.KQueue"); + return (Boolean) kqueue.getMethod("isAvailable").invoke(null); + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + return false; + } + } + + static boolean isNativeTransportAvailable() { + return isEpollAvailable() || isKqueueAvailable(); + } + private FlightTestUtil() { } } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index ab31eb8df92..dfb77ada666 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -157,7 +157,7 @@ private void test(BiConsumer consumer) throws Exc /** * An example FlightProducer for test purposes. */ - public class Producer implements FlightProducer, AutoCloseable { + public static class Producer implements FlightProducer, AutoCloseable { private final BufferAllocator allocator; diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java new file mode 100644 index 00000000000..e3ac3908941 --- /dev/null +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestServerOptions.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.File; + +import org.apache.arrow.flight.TestBasicOperation.Producer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TestServerOptions { + + @Test + public void domainSocket() throws Exception { + Assume.assumeTrue("We have a native transport available", FlightTestUtil.isNativeTransportAvailable()); + final File domainSocket = File.createTempFile("flight-unit-test-", ".sock"); + Assert.assertTrue(domainSocket.delete()); + final Location location = Location.forGrpcDomainSocket(domainSocket.getAbsolutePath()); + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(a); + FlightServer s = + FlightTestUtil.getStartedServer( + (port) -> FlightServer.builder(a, location, producer).build() + )) { + try (FlightClient c = FlightClient.builder(a, location).build()) { + FlightStream stream = c.getStream(new Ticket(new byte[0])); + VectorSchemaRoot root = stream.getRoot(); + IntVector iv = (IntVector) root.getVector("c1"); + int value = 0; + while (stream.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + Assert.assertEquals(value, iv.get(i)); + value++; + } + } + } + } + } +} diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java index df16a11938a..39b2924c620 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java @@ -149,7 +149,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l root.clear(); listener.completed(); } - }).setAuthHandler(new BasicServerAuthHandler(validator)).build()); + }).authHandler(new BasicServerAuthHandler(validator)).build()); client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) .build(); } diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 82745d34c8c..806796f37f1 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -400,16 +400,26 @@ cdef class FlightClient: .format(self.__class__.__name__)) @staticmethod - def connect(location, **kwargs): - """Connect to a Flight service on the given host and port.""" + def connect(location, tls_root_certs=None): + """ + Connect to a Flight service on the given host and port. + + Parameters + ---------- + location : Location + location to connect to + + tls_root_certs : bytes + PEM-encoded + """ cdef: FlightClient result = FlightClient.__new__(FlightClient) int c_port = 0 CLocation c_location = Location.unwrap(location) CFlightClientOptions c_options - if "tls_root_certs" in kwargs: - c_options.tls_root_certs = tobytes(kwargs["tls_root_certs"]) + if tls_root_certs: + c_options.tls_root_certs = tobytes(tls_root_certs) with nogil: check_status(CFlightClient.Connect(c_location, c_options, @@ -1004,7 +1014,8 @@ cdef class FlightServerBase: cdef: unique_ptr[PyFlightServer] server - def run(self, location, auth_handler=None, **kwargs): + def run(self, location, auth_handler=None, + tls_cert_chain=None, tls_private_key=None): cdef: PyFlightServerVtable vtable = PyFlightServerVtable() PyFlightServer* c_server @@ -1019,14 +1030,12 @@ cdef class FlightServerBase: c_options.get().auth_handler.reset( ( auth_handler).to_handler()) - if "tls_cert_chain" in kwargs: - if "tls_private_key" not in kwargs: + if tls_cert_chain: + if not tls_private_key: raise ValueError( "Must provide both cert chain and private key") - c_options.get().tls_cert_chain = tobytes( - kwargs["tls_cert_chain"]) - c_options.get().tls_private_key = tobytes( - kwargs["tls_private_key"]) + c_options.get().tls_cert_chain = tobytes(tls_cert_chain) + c_options.get().tls_private_key = tobytes(tls_private_key) vtable.list_flights = &_list_flights vtable.get_flight_info = &_get_flight_info diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index bbe5507ffdd..6ab512565cd 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -465,9 +465,9 @@ def test_token_auth_invalid(): def test_location_invalid(): """Test constructing invalid URIs.""" - with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"): + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): flight.FlightClient.connect("%") server = ConstantFlightServer() - with pytest.raises(pa.ArrowException, match=".*Cannot parse URI:.*"): + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): server.run("%")