Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/cmake_modules/DefineOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF)
define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF)

define_option(ARROW_WITH_UCX
"Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)"
OFF)

define_option(ARROW_WITH_UTF8PROC
"Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)"
ON)
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,10 @@ endif()

if(ARROW_FLIGHT)
add_subdirectory(flight)

if(ARROW_WITH_UCX)
add_subdirectory(flight/transport/ucx)
endif()
endif()

if(ARROW_FLIGHT_SQL)
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS)
add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)

add_dependencies(arrow_flight arrow-flight-benchmark)

if(ARROW_WITH_UCX)
if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static)
target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static)
else()
target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared)
target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared)
endif()
endif()
endif(ARROW_BUILD_BENCHMARKS)
35 changes: 34 additions & 1 deletion cpp/src/arrow/flight/flight_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,20 @@
#include "arrow/flight/test_util.h"

#ifdef ARROW_CUDA
#include <cuda.h>
#include "arrow/gpu/cuda_api.h"
#endif
#ifdef ARROW_WITH_UCX
#include "arrow/flight/transport/ucx/ucx.h"
#endif

DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
"The network transport to use. Supported: \"grpc\" (default).");
"The network transport to use. Supported: \"grpc\" (default)"
#ifdef ARROW_WITH_UCX
", \"ucx\""
#endif // ARROW_WITH_UCX
".");
DEFINE_string(server_host, "",
"An existing performance server to benchmark against (leave blank to spawn "
"one automatically)");
Expand Down Expand Up @@ -497,6 +505,21 @@ int main(int argc, char** argv) {
options.disable_server_verification = true;
}
}
} else if (FLAGS_transport == "ucx") {
#ifdef ARROW_WITH_UCX
arrow::flight::transport::ucx::InitializeFlightUcx();
if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
<< std::endl;
return EXIT_FAILURE;
}
ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
std::to_string(FLAGS_server_port))
.Value(&location));
#else
std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
Expand All @@ -514,6 +537,16 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager));
ABORT_NOT_OK(manager->GetDevice(0).Value(&device));
call_options.memory_manager = device->default_memory_manager();

// Needed to prevent UCX warning
// cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid
// device context
std::shared_ptr<arrow::cuda::CudaContext> context;
ABORT_NOT_OK(device->GetContext().Value(&context));
auto cuda_status = cuCtxPushCurrent(reinterpret_cast<CUcontext>(context->handle()));
if (cuda_status != CUDA_SUCCESS) {
ARROW_LOG(WARNING) << "CUDA error " << cuda_status;
}
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
Expand Down
36 changes: 34 additions & 2 deletions cpp/src/arrow/flight/perf_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <signal.h>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
Expand All @@ -43,10 +44,17 @@
#ifdef ARROW_CUDA
#include "arrow/gpu/cuda_api.h"
#endif
#ifdef ARROW_WITH_UCX
#include "arrow/flight/transport/ucx/ucx.h"
#endif

DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
"The network transport to use. Supported: \"grpc\" (default).");
"The network transport to use. Supported: \"grpc\" (default)"
#ifdef ARROW_WITH_UCX
", \"ucx\""
#endif // ARROW_WITH_UCX
".");
DEFINE_string(server_host, "localhost", "Host where the server is running on");
DEFINE_int32(port, 31337, "Server port to listen on");
DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
Expand Down Expand Up @@ -97,7 +105,7 @@ class PerfDataStream : public FlightDataStream {
if (records_sent_ >= total_records_) {
// Signal that iteration is over
payload.ipc_message.metadata = nullptr;
return Status::OK();
return payload;
}

if (verify_) {
Expand Down Expand Up @@ -274,6 +282,29 @@ int main(int argc, char** argv) {
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix)
.Value(&connect_location));
}
} else if (FLAGS_transport == "ucx") {
#ifdef ARROW_WITH_UCX
arrow::flight::transport::ucx::InitializeFlightUcx();
if (FLAGS_server_unix.empty()) {
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
}
ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
std::to_string(FLAGS_port))
.Value(&bind_location));
ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
std::to_string(FLAGS_port))
.Value(&connect_location));
} else {
std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
<< std::endl;
return EXIT_FAILURE;
}
#else
std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
Expand Down Expand Up @@ -308,6 +339,7 @@ int main(int argc, char** argv) {
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server transport: " << FLAGS_transport << std::endl;
std::cout << "Server location: " << connect_location.ToString() << std::endl;
if (FLAGS_server_unix.empty()) {
std::cout << "Server host: " << FLAGS_server_host << std::endl;
std::cout << "Server port: " << FLAGS_port << std::endl;
Expand Down
26 changes: 13 additions & 13 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ using arrow::internal::checked_cast;
void ConnectivityTest::TestGetPort() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
}
void ConnectivityTest::TestBuilderHook() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
bool builder_hook_run = false;
options.builder_hook = [&builder_hook_run](void* builder) {
Expand All @@ -68,7 +68,7 @@ void ConnectivityTest::TestBuilderHook() {
void ConnectivityTest::TestShutdown() {
// Regression test for ARROW-15181
constexpr int kIterations = 10;
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();

Expand All @@ -84,7 +84,7 @@ void ConnectivityTest::TestShutdown() {
void ConnectivityTest::TestShutdownWithDeadline() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
Expand All @@ -96,13 +96,13 @@ void ConnectivityTest::TestShutdownWithDeadline() {
}
void ConnectivityTest::TestBrokenConnection() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));

std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(location,
Location::ForScheme(transport(), "localhost", server->port()));
Location::ForScheme(transport(), "127.0.0.1", server->port()));
ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location));

ASSERT_OK(server->Shutdown());
Expand All @@ -117,7 +117,7 @@ void ConnectivityTest::TestBrokenConnection() {
void DataTest::SetUp() {
server_ = ExampleTestServer();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server_->Init(options));

Expand All @@ -129,7 +129,7 @@ void DataTest::TearDown() {
}
Status DataTest::ConnectClient() {
ARROW_ASSIGN_OR_RAISE(auto location,
Location::ForScheme(transport(), "localhost", server_->port()));
Location::ForScheme(transport(), "127.0.0.1", server_->port()));
ARROW_ASSIGN_OR_RAISE(client_, FlightClient::Connect(location));
return Status::OK();
}
Expand Down Expand Up @@ -638,7 +638,7 @@ class DoPutTestServer : public FlightServerBase {
};

void DoPutTest::SetUp() {
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<DoPutTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
Expand Down Expand Up @@ -766,7 +766,7 @@ void DoPutTest::TestLargeBatch() {
void DoPutTest::TestSizeLimit() {
const int64_t size_limit = 4096;
ASSERT_OK_AND_ASSIGN(auto location,
Location::ForScheme(transport(), "localhost", server_->port()));
Location::ForScheme(transport(), "127.0.0.1", server_->port()));
auto client_options = FlightClientOptions::Defaults();
client_options.write_size_limit_bytes = size_limit;
ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, client_options));
Expand Down Expand Up @@ -866,7 +866,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context,
}

void AppMetadataTest::SetUp() {
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<AppMetadataTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
Expand Down Expand Up @@ -1045,7 +1045,7 @@ class IpcOptionsTestServer : public FlightServerBase {
};

void IpcOptionsTest::SetUp() {
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<IpcOptionsTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
Expand Down Expand Up @@ -1241,7 +1241,7 @@ void CudaDataTest::SetUp() {
impl_->device = std::move(device);
impl_->context = std::move(context);

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<CudaTestServer>(
location, &server_, &client_,
[this](FlightServerOptions* options) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Status MakeServer(const Location& location, std::unique_ptr<FlightServerBase>* s
RETURN_NOT_OK(make_server_options(&server_options));
RETURN_NOT_OK((*server)->Init(server_options));
std::string uri =
location.scheme() + "://localhost:" + std::to_string((*server)->port());
location.scheme() + "://127.0.0.1:" + std::to_string((*server)->port());
ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri));
FlightClientOptions client_options = FlightClientOptions::Defaults();
RETURN_NOT_OK(make_client_options(&client_options));
Expand Down
77 changes: 77 additions & 0 deletions cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

add_custom_target(arrow_flight_transport_ucx)
arrow_install_all_headers("arrow/flight/transport/ucx")

find_package(PkgConfig REQUIRED)
pkg_check_modules(UCX REQUIRED IMPORTED_TARGET ucx)

set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS
ucx_client.cc
ucx_server.cc
ucx.cc
ucx_internal.cc
util_internal.cc)
set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS)

include_directories(SYSTEM ${UCX_INCLUDE_DIRS})
list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS PkgConfig::UCX)

add_arrow_lib(arrow_flight_transport_ucx
# CMAKE_PACKAGE_NAME
# ArrowFlightTransportUcx
# PKG_CONFIG_NAME
# arrow-flight-transport-ucx
SOURCES
${ARROW_FLIGHT_TRANSPORT_UCX_SRCS}
PRECOMPILED_HEADERS
"$<$<COMPILE_LANGUAGE:CXX>:arrow/flight/transport/ucx/pch.h>"
DEPENDENCIES
SHARED_LINK_FLAGS
${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
SHARED_LINK_LIBS
arrow_shared
arrow_flight_shared
${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}
STATIC_LINK_LIBS
arrow_static
arrow_flight_static
${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS})

if(ARROW_BUILD_TESTS)
if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
arrow_static
arrow_flight_static
arrow_flight_testing_static
arrow_flight_transport_ucx_static
${ARROW_TEST_LINK_LIBS})
else()
set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
arrow_shared
arrow_flight_shared
arrow_flight_testing_shared
arrow_flight_transport_ucx_shared
${ARROW_TEST_LINK_LIBS})
endif()
add_arrow_test(flight_transport_ucx_test
STATIC_LINK_LIBS
${ARROW_FLIGHT_UCX_TEST_LINK_LIBS}
LABELS
"arrow_flight")
endif()
Loading