Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"
#include "arrow/util/uri.h"

#include "arrow/flight/client_auth.h"
#include "arrow/flight/internal.h"
Expand Down Expand Up @@ -230,19 +231,37 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {

class FlightClient::FlightClientImpl {
public:
Status Connect(const std::string& host, int port) {
// TODO(wesm): Support other kinds of GRPC ChannelCredentials
std::stringstream ss;
ss << host << ":" << port;
std::string uri = ss.str();
Status Connect(const Location& location, const FlightClientOptions& options) {
const std::string& scheme = location.scheme();

std::stringstream grpc_uri;
std::shared_ptr<grpc::ChannelCredentials> creds;
if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
grpc_uri << location.uri_->host() << ":" << location.uri_->port_text();

if (scheme == "grpc+tls") {
grpc::SslCredentialsOptions ssl_options;
if (!options.tls_root_certs.empty()) {
ssl_options.pem_root_certs = options.tls_root_certs;
}
creds = grpc::SslCredentials(ssl_options);
} else {
creds = grpc::InsecureChannelCredentials();
}
} else if (scheme == kSchemeGrpcUnix) {
grpc_uri << "unix://" << location.uri_->path();
creds = grpc::InsecureChannelCredentials();
} else {
return Status::NotImplemented("Flight scheme " + scheme + " is not supported.");
}

grpc::ChannelArguments args;
// Try to reconnect quickly at first, in case the server is still starting up
args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 100);
// Receive messages of any size
args.SetMaxReceiveMessageSize(-1);
stub_ = pb::FlightService::NewStub(
grpc::CreateCustomChannel(ss.str(), grpc::InsecureChannelCredentials(), args));
grpc::CreateCustomChannel(grpc_uri.str(), creds, args));
return Status::OK();
}

Expand Down Expand Up @@ -383,10 +402,15 @@ FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }

FlightClient::~FlightClient() {}

Status FlightClient::Connect(const std::string& host, int port,
Status FlightClient::Connect(const Location& location,
std::unique_ptr<FlightClient>* client) {
return Connect(location, {}, client);
}

Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
std::unique_ptr<FlightClient>* client) {
client->reset(new FlightClient);
return (*client)->impl_->Connect(host, port);
return (*client)->impl_->Connect(location, options);
}

Status FlightClient::Authenticate(const FlightCallOptions& options,
Expand Down
18 changes: 15 additions & 3 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,31 @@ class ARROW_EXPORT FlightCallOptions {
TimeoutDuration timeout;
};

class ARROW_EXPORT FlightClientOptions {
public:
std::string tls_root_certs;
};

/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_EXPORT FlightClient {
public:
~FlightClient();

/// \brief Connect to an unauthenticated flight service
/// \param[in] host the hostname or IP address
/// \param[in] port the port on the host
/// \param[in] location the URI
/// \param[out] client the created FlightClient
/// \return Status OK status may not indicate that the connection was
/// successful
static Status Connect(const Location& location, std::unique_ptr<FlightClient>* client);

/// \brief Connect to an unauthenticated flight service
/// \param[in] location the URI
/// \param[in] options Other options for setting up the client
/// \param[out] client the created FlightClient
/// \return Status OK status may not indicate that the connection was
/// successful
static Status Connect(const std::string& host, int port,
static Status Connect(const Location& location, const FlightClientOptions& options,
std::unique_ptr<FlightClient>* client);

/// \brief Authenticate to the server using the given handler.
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/arrow/flight/flight-benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {

// Construct client and plan the query
std::unique_ptr<FlightClient> client;
RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client));
Location location;
RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location));
RETURN_NOT_OK(FlightClient::Connect(location, &client));

FlightDescriptor descriptor;
descriptor.type = FlightDescriptor::CMD;
Expand All @@ -97,7 +99,9 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
auto ConsumeStream = [&stats, &hostname, &port](const FlightEndpoint& endpoint) {
// TODO(wesm): Use location from endpoint, same host/port for now
std::unique_ptr<FlightClient> client;
RETURN_NOT_OK(FlightClient::Connect(hostname, port, &client));
Location location;
RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port, &location));
RETURN_NOT_OK(FlightClient::Connect(location, &client));

perf::Token token;
token.ParseFromString(endpoint.ticket.ticket);
Expand Down
59 changes: 43 additions & 16 deletions cpp/src/arrow/flight/flight-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ void AssertEqual(const Ticket& expected, const Ticket& actual) {
}

void AssertEqual(const Location& expected, const Location& actual) {
ASSERT_EQ(expected.host, actual.host);
ASSERT_EQ(expected.port, actual.port);
ASSERT_EQ(expected, actual);
}

void AssertEqual(const std::vector<FlightEndpoint>& expected,
Expand Down Expand Up @@ -146,6 +145,20 @@ TEST(TestFlight, StartStopTestServer) {
ASSERT_FALSE(server.IsRunning());
}

TEST(TestFlight, ConnectUri) {
TestServer server("flight-test-server", 30000);
server.Start();
ASSERT_TRUE(server.IsRunning());

std::unique_ptr<FlightClient> client;
Location location1;
Location location2;
ASSERT_OK(Location::Parse("grpc://localhost:30000", &location1));
ASSERT_OK(Location::Parse("grpc://localhost:30000", &location2));
ASSERT_OK(FlightClient::Connect(location1, &client));
ASSERT_OK(FlightClient::Connect(location2, &client));
}

// ----------------------------------------------------------------------
// Client tests

Expand All @@ -169,7 +182,11 @@ class TestFlightClient : public ::testing::Test {

void TearDown() { server_->Stop(); }

Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
Status ConnectClient() {
Location location;
RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
return FlightClient::Connect(location, &client_);
}

template <typename EndpointCheckFunc>
void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
Expand Down Expand Up @@ -238,38 +255,47 @@ class DoPutTestServer : public FlightServerBase {
class TestAuthHandler : public ::testing::Test {
public:
void SetUp() {
port_ = 30000;
server_.reset(new InProcessTestServer(
std::unique_ptr<FlightServerBase>(new AuthTestServer), port_));
ASSERT_OK(server_->Start(std::unique_ptr<ServerAuthHandler>(
new TestServerAuthHandler("user", "p4ssw0rd"))));
Location location;
std::unique_ptr<FlightServerBase> server(new AuthTestServer);

ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));
FlightServerOptions options(location);
options.auth_handler =
std::unique_ptr<ServerAuthHandler>(new TestServerAuthHandler("user", "p4ssw0rd"));
ASSERT_OK(server->Init(options));

server_.reset(new InProcessTestServer(std::move(server), location));
ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}

void TearDown() { server_->Stop(); }

Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }

protected:
int port_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<InProcessTestServer> server_;
};

class TestDoPut : public ::testing::Test {
public:
void SetUp() {
port_ = 30000;
Location location;
ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));

do_put_server_ = new DoPutTestServer();
server_.reset(new InProcessTestServer(
std::unique_ptr<FlightServerBase>(do_put_server_), port_));
ASSERT_OK(server_->Start({}));
std::unique_ptr<FlightServerBase>(do_put_server_), location));
FlightServerOptions options(location);
ASSERT_OK(do_put_server_->Init(options));
ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}

void TearDown() { server_->Stop(); }

Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }

void CheckBatches(FlightDescriptor expected_descriptor,
const BatchVector& expected_batches) {
Expand All @@ -293,7 +319,6 @@ class TestDoPut : public ::testing::Test {
}

protected:
int port_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<InProcessTestServer> server_;
DoPutTestServer* do_put_server_;
Expand Down Expand Up @@ -423,7 +448,9 @@ TEST_F(TestFlightClient, Issue5095) {
TEST_F(TestFlightClient, TimeoutFires) {
// Server does not exist on this port, so call should fail
std::unique_ptr<FlightClient> client;
ASSERT_OK(FlightClient::Connect("localhost", 30001, &client));
Location location;
ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location));
ASSERT_OK(FlightClient::Connect(location, &client));
FlightCallOptions options;
options.timeout = TimeoutDuration{0.2};
std::unique_ptr<FlightInfo> info;
Expand Down
7 changes: 2 additions & 5 deletions cpp/src/arrow/flight/internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,11 @@ Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) {
// Location

Status FromProto(const pb::Location& pb_location, Location* location) {
location->host = pb_location.host();
location->port = pb_location.port();
return Status::OK();
return Location::Parse(pb_location.uri(), location);
}

void ToProto(const Location& location, pb::Location* pb_location) {
pb_location->set_host(location.host);
pb_location->set_port(location.port);
pb_location->set_uri(location.ToString());
}

// Ticket
Expand Down
10 changes: 7 additions & 3 deletions cpp/src/arrow/flight/perf-server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ Status GetPerfBatches(const perf::Token& token, const std::shared_ptr<Schema>& s

class FlightPerfServer : public FlightServerBase {
public:
FlightPerfServer() : location_(Location{"localhost", FLAGS_port}) {
FlightPerfServer() : location_() {
DCHECK_OK(Location::ForGrpcTcp("localhost", FLAGS_port, &location_));
perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()),
field("d", int64())});
}
Expand Down Expand Up @@ -204,8 +205,11 @@ int main(int argc, char** argv) {

g_server.reset(new arrow::flight::FlightPerfServer);

ARROW_CHECK_OK(
g_server->Init(std::unique_ptr<arrow::flight::NoOpAuthHandler>(), FLAGS_port));
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
arrow::flight::FlightServerOptions options(location);

ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server port: " << FLAGS_port << std::endl;
Expand Down
42 changes: 37 additions & 5 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <atomic>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

Expand All @@ -40,6 +41,7 @@
#include "arrow/status.h"
#include "arrow/util/logging.h"
#include "arrow/util/stl.h"
#include "arrow/util/uri.h"

#include "arrow/flight/internal.h"
#include "arrow/flight/serialization-internal.h"
Expand Down Expand Up @@ -410,7 +412,6 @@ class FlightServiceImpl : public FlightService::Service {
#endif

struct FlightServerBase::Impl {
std::string address_;
std::unique_ptr<FlightServiceImpl> service_;
std::unique_ptr<grpc::Server> server_;

Expand Down Expand Up @@ -438,21 +439,52 @@ void FlightServerBase::Impl::HandleSignal(int signum) {
}
}

FlightServerOptions::FlightServerOptions(const Location& location_)
: location(location_), auth_handler(nullptr) {}

FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }

FlightServerBase::~FlightServerBase() {}

Status FlightServerBase::Init(std::unique_ptr<ServerAuthHandler> auth_handler, int port) {
impl_->address_ = "localhost:" + std::to_string(port);
std::shared_ptr<ServerAuthHandler> handler = std::move(auth_handler);
Status FlightServerBase::Init(FlightServerOptions& options) {
std::shared_ptr<ServerAuthHandler> handler = std::move(options.auth_handler);
impl_->service_.reset(new FlightServiceImpl(handler, this));

grpc::ServerBuilder builder;
// Allow uploading messages of any length
builder.SetMaxReceiveMessageSize(-1);
builder.AddListeningPort(impl_->address_, grpc::InsecureServerCredentials());

const Location& location = options.location;
const std::string scheme = location.scheme();
if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
std::stringstream address;
address << location.uri_->host() << ':' << location.uri_->port_text();

std::shared_ptr<grpc::ServerCredentials> creds;
if (scheme == kSchemeGrpcTls) {
grpc::SslServerCredentialsOptions ssl_options;
ssl_options.pem_key_cert_pairs.push_back(
{options.tls_private_key, options.tls_cert_chain});
creds = grpc::SslServerCredentials(ssl_options);
} else {
creds = grpc::InsecureServerCredentials();
}

builder.AddListeningPort(address.str(), creds);
} else if (scheme == kSchemeGrpcUnix) {
std::stringstream address;
address << "unix:" << location.uri_->path();
builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
} else {
return Status::NotImplemented("Scheme is not supported: " + scheme);
}

builder.RegisterService(impl_->service_.get());

// Disable SO_REUSEPORT - it makes debugging/testing a pain as
// leftover processes can handle requests on accident
builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);

impl_->server_ = builder.BuildAndStart();
if (!impl_->server_) {
return Status::UnknownError("Server did not start properly");
Expand Down
Loading