diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 46c790a2f00..6649de52cd9 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -44,6 +44,8 @@ #endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); +DEFINE_string(transport, "grpc", + "The network transport to use. Supported: \"grpc\" (default)."); DEFINE_string(server_host, "", "An existing performance server to benchmark against (leave blank to spawn " "one automatically)"); @@ -330,7 +332,8 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op // Check that number of rows read / written is as expected int64_t records_for_run = stats->total_records - start_total_records; if (records_for_run != static_cast(plan->total_records())) { - return Status::Invalid("Did not consume expected number of records"); + return Status::Invalid("Did not consume expected number of records, got: ", + records_for_run, " but expected: ", plan->total_records()); } } return Status::OK(); @@ -433,56 +436,63 @@ int main(int argc, char** argv) { std::unique_ptr server; std::vector server_args; + server_args.push_back("-transport"); + server_args.push_back(FLAGS_transport); arrow::flight::Location location; auto options = arrow::flight::FlightClientOptions::Defaults(); - if (FLAGS_test_unix || !FLAGS_server_unix.empty()) { - if (FLAGS_server_unix == "") { - FLAGS_server_unix = "/tmp/flight-bench-spawn.sock"; - std::cout << "Using spawned Unix server" << std::endl; - server.reset( - new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix)); + if (FLAGS_transport == "grpc") { + if (FLAGS_test_unix || !FLAGS_server_unix.empty()) { + if (FLAGS_server_unix == "") { + FLAGS_server_unix = "/tmp/flight-bench-spawn.sock"; + std::cout << "Using spawned Unix server" << std::endl; + server.reset( + new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix)); + } else { + std::cout << "Using standalone Unix server" << std::endl; + } + std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl; + ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location)); } else { - std::cout << "Using standalone Unix server" << std::endl; - } - std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl; - ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location)); - } else { - if (FLAGS_server_host == "") { - FLAGS_server_host = "localhost"; - std::cout << "Using spawned TCP server" << std::endl; - server.reset( - new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port)); - if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { - if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { - std::cout << "Enabling TLS for spawned server" << std::endl; - server_args.push_back("-cert_file"); - server_args.push_back(FLAGS_cert_file); - server_args.push_back("-key_file"); - server_args.push_back(FLAGS_key_file); - } else { - std::cerr << "If providing TLS cert/key, must provide both" << std::endl; - return 1; + if (FLAGS_server_host == "") { + FLAGS_server_host = "localhost"; + std::cout << "Using spawned TCP server" << std::endl; + server.reset( + new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port)); + if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { + if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { + std::cout << "Enabling TLS for spawned server" << std::endl; + server_args.push_back("-cert_file"); + server_args.push_back(FLAGS_cert_file); + server_args.push_back("-key_file"); + server_args.push_back(FLAGS_key_file); + } else { + std::cerr << "If providing TLS cert/key, must provide both" << std::endl; + return 1; + } } + } else { + std::cout << "Using standalone TCP server" << std::endl; } - } else { - std::cout << "Using standalone TCP server" << std::endl; - } - if (server) { - if (FLAGS_cuda && FLAGS_test_put) { - server_args.push_back("-cuda"); + if (server) { + if (FLAGS_cuda && FLAGS_test_put) { + server_args.push_back("-cuda"); + } + server->Start(server_args); + } + std::cout << "Server host: " << FLAGS_server_host << std::endl + << "Server port: " << FLAGS_server_port << std::endl; + if (FLAGS_cert_file.empty()) { + ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, + FLAGS_server_port, &location)); + } else { + ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, + FLAGS_server_port, &location)); + options.disable_server_verification = true; } - server->Start(server_args); - } - std::cout << "Server host: " << FLAGS_server_host << std::endl - << "Server port: " << FLAGS_server_port << std::endl; - if (FLAGS_cert_file.empty()) { - ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, - FLAGS_server_port, &location)); - } else { - ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, - FLAGS_server_port, &location)); - options.disable_server_verification = true; } + } else { + std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; } if (FLAGS_cuda) { diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index ae89f88dde1..ae2b2a485cb 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -45,6 +45,8 @@ #endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); +DEFINE_string(transport, "grpc", + "The network transport to use. Supported: \"grpc\" (default)."); DEFINE_string(server_host, "localhost", "Host where the server is running on"); DEFINE_int32(port, 31337, "Server port to listen on"); DEFINE_string(server_unix, "", "Unix socket path where the server is running on"); @@ -191,7 +193,8 @@ class FlightPerfServer : public FlightServerBase { std::unique_ptr* data_stream) override { perf::Token token; CHECK_PARSE(token.ParseFromString(request.ticket)); - return GetPerfBatches(token, perf_schema_, false, data_stream); + // This must also be set in flight_benchmark.cc + return GetPerfBatches(token, perf_schema_, /*verify=*/false, data_stream); } Status DoPut(const ServerCallContext& context, @@ -241,28 +244,33 @@ int main(int argc, char** argv) { arrow::flight::Location bind_location; arrow::flight::Location connect_location; - if (FLAGS_server_unix.empty()) { - if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { - if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { + if (FLAGS_transport == "grpc") { + if (FLAGS_server_unix.empty()) { + if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { + if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { + ARROW_CHECK_OK( + arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls( + FLAGS_server_host, FLAGS_port, &connect_location)); + } else { + std::cerr << "If providing TLS cert/key, must provide both" << std::endl; + return EXIT_FAILURE; + } + } else { ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port, + arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port, &connect_location)); - } else { - std::cerr << "If providing TLS cert/key, must provide both" << std::endl; - return 1; } } else { ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port, - &connect_location)); + arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location)); + ARROW_CHECK_OK( + arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location)); } } else { - ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location)); - ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location)); + std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; } arrow::flight::FlightServerOptions options(bind_location); if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { @@ -286,13 +294,14 @@ int main(int argc, char** argv) { options.memory_manager = device->default_memory_manager(); #else std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl; - return 1; + return EXIT_FAILURE; #endif } ARROW_CHECK_OK(g_server->Init(options)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); + std::cout << "Server transport: " << FLAGS_transport << std::endl; if (FLAGS_server_unix.empty()) { std::cout << "Server host: " << FLAGS_server_host << std::endl; std::cout << "Server port: " << FLAGS_port << std::endl; @@ -301,5 +310,5 @@ int main(int argc, char** argv) { } g_server->SetLocation(connect_location); ARROW_CHECK_OK(g_server->Serve()); - return 0; + return EXIT_SUCCESS; }