Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 54 additions & 44 deletions cpp/src/arrow/flight/flight_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#endif

DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
"The network transport to use. Supported: \"grpc\" (default).");
DEFINE_string(server_host, "",
"An existing performance server to benchmark against (leave blank to spawn "
"one automatically)");
Expand Down Expand Up @@ -330,7 +332,8 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op
// Check that number of rows read / written is as expected
int64_t records_for_run = stats->total_records - start_total_records;
if (records_for_run != static_cast<int64_t>(plan->total_records())) {
return Status::Invalid("Did not consume expected number of records");
return Status::Invalid("Did not consume expected number of records, got: ",
records_for_run, " but expected: ", plan->total_records());
}
}
return Status::OK();
Expand Down Expand Up @@ -433,56 +436,63 @@ int main(int argc, char** argv) {

std::unique_ptr<arrow::flight::TestServer> server;
std::vector<std::string> server_args;
server_args.push_back("-transport");
server_args.push_back(FLAGS_transport);
arrow::flight::Location location;
auto options = arrow::flight::FlightClientOptions::Defaults();
if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
if (FLAGS_server_unix == "") {
FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
std::cout << "Using spawned Unix server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
if (FLAGS_transport == "grpc") {
if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
if (FLAGS_server_unix == "") {
FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
std::cout << "Using spawned Unix server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
} else {
std::cout << "Using standalone Unix server" << std::endl;
}
std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
} else {
std::cout << "Using standalone Unix server" << std::endl;
}
std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
} else {
if (FLAGS_server_host == "") {
FLAGS_server_host = "localhost";
std::cout << "Using spawned TCP server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
std::cout << "Enabling TLS for spawned server" << std::endl;
server_args.push_back("-cert_file");
server_args.push_back(FLAGS_cert_file);
server_args.push_back("-key_file");
server_args.push_back(FLAGS_key_file);
} else {
std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
return 1;
if (FLAGS_server_host == "") {
FLAGS_server_host = "localhost";
std::cout << "Using spawned TCP server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
std::cout << "Enabling TLS for spawned server" << std::endl;
server_args.push_back("-cert_file");
server_args.push_back(FLAGS_cert_file);
server_args.push_back("-key_file");
server_args.push_back(FLAGS_key_file);
} else {
std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
return 1;
}
}
} else {
std::cout << "Using standalone TCP server" << std::endl;
}
} else {
std::cout << "Using standalone TCP server" << std::endl;
}
if (server) {
if (FLAGS_cuda && FLAGS_test_put) {
server_args.push_back("-cuda");
if (server) {
if (FLAGS_cuda && FLAGS_test_put) {
server_args.push_back("-cuda");
}
server->Start(server_args);
}
std::cout << "Server host: " << FLAGS_server_host << std::endl
<< "Server port: " << FLAGS_server_port << std::endl;
if (FLAGS_cert_file.empty()) {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
FLAGS_server_port, &location));
} else {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
FLAGS_server_port, &location));
options.disable_server_verification = true;
}
server->Start(server_args);
}
std::cout << "Server host: " << FLAGS_server_host << std::endl
<< "Server port: " << FLAGS_server_port << std::endl;
if (FLAGS_cert_file.empty()) {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
FLAGS_server_port, &location));
} else {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
FLAGS_server_port, &location));
options.disable_server_verification = true;
}
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
}

if (FLAGS_cuda) {
Expand Down
45 changes: 27 additions & 18 deletions cpp/src/arrow/flight/perf_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
#endif

DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
"The network transport to use. Supported: \"grpc\" (default).");
DEFINE_string(server_host, "localhost", "Host where the server is running on");
DEFINE_int32(port, 31337, "Server port to listen on");
DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
Expand Down Expand Up @@ -191,7 +193,8 @@ class FlightPerfServer : public FlightServerBase {
std::unique_ptr<FlightDataStream>* data_stream) override {
perf::Token token;
CHECK_PARSE(token.ParseFromString(request.ticket));
return GetPerfBatches(token, perf_schema_, false, data_stream);
// This must also be set in flight_benchmark.cc
return GetPerfBatches(token, perf_schema_, /*verify=*/false, data_stream);
}

Status DoPut(const ServerCallContext& context,
Expand Down Expand Up @@ -241,28 +244,33 @@ int main(int argc, char** argv) {

arrow::flight::Location bind_location;
arrow::flight::Location connect_location;
if (FLAGS_server_unix.empty()) {
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
if (FLAGS_transport == "grpc") {
if (FLAGS_server_unix.empty()) {
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location));
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(
FLAGS_server_host, FLAGS_port, &connect_location));
} else {
std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
return EXIT_FAILURE;
}
} else {
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location));
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port,
arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location));
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port,
&connect_location));
} else {
std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
return 1;
}
} else {
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location));
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port,
&connect_location));
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location));
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location));
}
} else {
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location));
ARROW_CHECK_OK(
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location));
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
}
arrow::flight::FlightServerOptions options(bind_location);
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
Expand All @@ -286,13 +294,14 @@ int main(int argc, char** argv) {
options.memory_manager = device->default_memory_manager();
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
return EXIT_FAILURE;
#endif
}

ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server transport: " << FLAGS_transport << std::endl;
if (FLAGS_server_unix.empty()) {
std::cout << "Server host: " << FLAGS_server_host << std::endl;
std::cout << "Server port: " << FLAGS_port << std::endl;
Expand All @@ -301,5 +310,5 @@ int main(int argc, char** argv) {
}
g_server->SetLocation(connect_location);
ARROW_CHECK_OK(g_server->Serve());
return 0;
return EXIT_SUCCESS;
}