From b40bb2a0fc8436e66c39d6956837c5edf2da77e4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 20 Dec 2021 14:56:41 -0500
Subject: [PATCH 01/16] ARROW-15706: [C++][FlightRPC] Implement Flight UCX
transport
---
cpp/cmake_modules/DefineOptions.cmake | 4 +
cpp/src/arrow/CMakeLists.txt | 4 +
cpp/src/arrow/flight/CMakeLists.txt | 10 +
cpp/src/arrow/flight/flight_benchmark.cc | 35 +-
cpp/src/arrow/flight/perf_server.cc | 34 +-
.../arrow/flight/transport/ucx/CMakeLists.txt | 77 ++
.../ucx/flight_transport_ucx_test.cc | 399 ++++++
cpp/src/arrow/flight/transport/ucx/ucx.cc | 43 +
cpp/src/arrow/flight/transport/ucx/ucx.h | 35 +
.../arrow/flight/transport/ucx/ucx_client.cc | 714 ++++++++++
.../flight/transport/ucx/ucx_internal.cc | 1148 +++++++++++++++++
.../arrow/flight/transport/ucx/ucx_internal.h | 352 +++++
.../arrow/flight/transport/ucx/ucx_server.cc | 660 ++++++++++
.../flight/transport/ucx/util_internal.cc | 212 +++
.../flight/transport/ucx/util_internal.h | 58 +
cpp/src/arrow/flight/transport_server.cc | 5 +-
cpp/src/arrow/util/config.h.cmake | 1 +
docs/source/cpp/flight.rst | 35 +
docs/source/status.rst | 78 +-
19 files changed, 3883 insertions(+), 21 deletions(-)
create mode 100644 cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
create mode 100644 cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx.h
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_client.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_internal.h
create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_server.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/util_internal.cc
create mode 100644 cpp/src/arrow/flight/transport/ucx/util_internal.h
diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake
index 05fc14bbc72..ec1e0b6352a 100644
--- a/cpp/cmake_modules/DefineOptions.cmake
+++ b/cpp/cmake_modules/DefineOptions.cmake
@@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF)
define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF)
+ define_option(ARROW_WITH_UCX
+ "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)"
+ OFF)
+
define_option(ARROW_WITH_UTF8PROC
"Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)"
ON)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e9e826097b3..b6f1e2481fa 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -747,6 +747,10 @@ endif()
if(ARROW_FLIGHT)
add_subdirectory(flight)
+
+ if(ARROW_WITH_UCX)
+ add_subdirectory(flight/transport/ucx)
+ endif()
endif()
if(ARROW_FLIGHT_SQL)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 7447e675e08..f9d135654b4 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS)
add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
add_dependencies(arrow_flight arrow-flight-benchmark)
+
+ if(ARROW_WITH_UCX)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static)
+ else()
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared)
+ endif()
+ endif()
endif(ARROW_BUILD_BENCHMARKS)
diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc
index 872c67c80b7..b5594e2d056 100644
--- a/cpp/src/arrow/flight/flight_benchmark.cc
+++ b/cpp/src/arrow/flight/flight_benchmark.cc
@@ -40,12 +40,20 @@
#include "arrow/flight/test_util.h"
#ifdef ARROW_CUDA
+#include
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "",
"An existing performance server to benchmark against (leave blank to spawn "
"one automatically)");
@@ -497,6 +505,21 @@ int main(int argc, char** argv) {
options.disable_server_verification = true;
}
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse(
+ "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_server_port),
+ &location));
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -514,6 +537,16 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager));
ABORT_NOT_OK(manager->GetDevice(0).Value(&device));
call_options.memory_manager = device->default_memory_manager();
+
+ // Needed to prevent UCX warning
+ // cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid
+ // device context
+ std::shared_ptr context;
+ ABORT_NOT_OK(device->GetContext().Value(&context));
+ auto cuda_status = cuCtxPushCurrent(reinterpret_cast(context->handle()));
+ if (cuda_status != CUDA_SUCCESS) {
+ ARROW_LOG(WARNING) << "CUDA error " << cuda_status;
+ }
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc
index cc42ffedd68..6bb30f62ddd 100644
--- a/cpp/src/arrow/flight/perf_server.cc
+++ b/cpp/src/arrow/flight/perf_server.cc
@@ -19,6 +19,7 @@
#include
#include
+#include
#include
#include
#include
@@ -43,10 +44,17 @@
#ifdef ARROW_CUDA
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "localhost", "Host where the server is running on");
DEFINE_int32(port, 31337, "Server port to listen on");
DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
@@ -274,6 +282,29 @@ int main(int argc, char** argv) {
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix)
.Value(&connect_location));
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_server_unix.empty()) {
+ if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
+ std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse(
+ "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port),
+ &bind_location));
+ ARROW_CHECK_OK(arrow::flight::Location::Parse(
+ "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port),
+ &connect_location));
+ } else {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -308,6 +339,7 @@ int main(int argc, char** argv) {
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server transport: " << FLAGS_transport << std::endl;
+ std::cout << "Server location: " << connect_location.ToString() << std::endl;
if (FLAGS_server_unix.empty()) {
std::cout << "Server host: " << FLAGS_server_host << std::endl;
std::cout << "Server port: " << FLAGS_port << std::endl;
diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
new file mode 100644
index 00000000000..2c0e71de59f
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_flight_transport_ucx)
+arrow_install_all_headers("arrow/flight/transport/ucx")
+
+find_package(PkgConfig REQUIRED)
+pkg_check_modules(UCX REQUIRED ucx)
+
+set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS
+ ucx_client.cc
+ ucx_server.cc
+ ucx.cc
+ ucx_internal.cc
+ util_internal.cc)
+set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS)
+
+include_directories(SYSTEM ${UCX_INCLUDE_DIRS})
+list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS ${UCX_LIBRARIES})
+
+add_arrow_lib(arrow_flight_transport_ucx
+ # CMAKE_PACKAGE_NAME
+ # ArrowFlightTransportUcx
+ # PKG_CONFIG_NAME
+ # arrow-flight-transport-ucx
+ SOURCES
+ ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$:arrow/flight/transport/ucx/pch.h>"
+ DEPENDENCIES
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}
+ STATIC_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS})
+
+if(ARROW_BUILD_TESTS)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ arrow_flight_testing_static
+ arrow_flight_transport_ucx_static
+ ${ARROW_TEST_LINK_LIBS})
+ else()
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ arrow_flight_testing_shared
+ arrow_flight_transport_ucx_shared
+ ${ARROW_TEST_LINK_LIBS})
+ endif()
+ add_arrow_test(flight_transport_ucx_test
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+endif()
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
new file mode 100644
index 00000000000..bb9e15fe1bf
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -0,0 +1,399 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+
+#include "arrow/array/array_base.h"
+#include "arrow/flight/test_definitions.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/transport/ucx/ucx.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/config.h"
+
+#ifdef UCP_API_VERSION
+#error "UCX headers should not be in public API"
+#endif
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#ifdef ARROW_CUDA
+#include "arrow/gpu/cuda_api.h"
+#endif
+
+namespace arrow {
+namespace flight {
+
+class UcxEnvironment : public ::testing::Environment {
+ public:
+ void SetUp() override { transport::ucx::InitializeFlightUcx(); }
+};
+
+testing::Environment* const kUcxEnvironment =
+ testing::AddGlobalTestEnvironment(new UcxEnvironment());
+
+//------------------------------------------------------------
+// Common transport tests
+
+class UcxConnectivityTest : public ConnectivityTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest);
+
+class UcxDataTest : public DataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DATA(UcxDataTest);
+
+class UcxDoPutTest : public DoPutTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest);
+
+class UcxAppMetadataTest : public AppMetadataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest);
+
+class UcxIpcOptionsTest : public IpcOptionsTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest);
+
+class UcxCudaDataTest : public CudaDataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest);
+
+//------------------------------------------------------------
+// UCX internals tests
+
+constexpr std::initializer_list kStatusCodes = {
+ StatusCode::OK,
+ StatusCode::OutOfMemory,
+ StatusCode::KeyError,
+ StatusCode::TypeError,
+ StatusCode::Invalid,
+ StatusCode::IOError,
+ StatusCode::CapacityError,
+ StatusCode::IndexError,
+ StatusCode::Cancelled,
+ StatusCode::UnknownError,
+ StatusCode::NotImplemented,
+ StatusCode::SerializationError,
+ StatusCode::RError,
+ StatusCode::CodeGenError,
+ StatusCode::ExpressionValidationError,
+ StatusCode::ExecutionError,
+ StatusCode::AlreadyExists,
+};
+
+constexpr std::initializer_list kFlightStatusCodes = {
+ FlightStatusCode::Internal, FlightStatusCode::TimedOut,
+ FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
+ FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable,
+ FlightStatusCode::Failed,
+};
+
+class TestStatusDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return "test-status-detail"; }
+ std::string ToString() const override { return "Custom status detail"; }
+};
+
+namespace transport {
+namespace ucx {
+
+static constexpr std::initializer_list kFrameTypes = {
+ FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader,
+ FrameType::kPayloadBody, FrameType::kDisconnect,
+};
+
+TEST(FrameHeader, Basics) {
+ for (const auto frame_type : kFrameTypes) {
+ FrameHeader header;
+ ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535));
+ if (frame_type == FrameType::kDisconnect) {
+ ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size()));
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size()));
+ ASSERT_EQ(frame->type, frame_type);
+ ASSERT_EQ(frame->counter, 42);
+ ASSERT_EQ(frame->size, 65535);
+ }
+ }
+}
+
+TEST(FrameHeader, FrameType) {
+ for (const auto frame_type : kFrameTypes) {
+ ASSERT_LE(static_cast(frame_type), static_cast(FrameType::kMaxFrameType));
+ }
+}
+
+TEST(HeadersFrame, Parse) {
+ const char* data =
+ ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar"
+ "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01");
+ constexpr int64_t size = 34;
+
+ {
+ std::unique_ptr buffer(
+ new Buffer(reinterpret_cast(data), size));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer)));
+ ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo"));
+ ASSERT_EQ(foo, "bar");
+ ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin"));
+ ASSERT_EQ(bin, "\x01");
+ }
+ {
+ std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 3));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected number of headers"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 7));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of key 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr buffer(
+ new Buffer(reinterpret_cast(data), 10));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of value 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr buffer(
+ new Buffer(reinterpret_cast(data), 12));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr buffer(
+ new Buffer(reinterpret_cast(data), 17));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "expected value 1 to have length 3, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+}
+
+TEST(HeadersFrame, RoundTripStatus) {
+ for (const auto code : kStatusCodes) {
+ {
+ Status expected = code == StatusCode::OK ? Status() : Status(code, "foo");
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ if (code == StatusCode::OK) continue;
+
+ // Attach a generic status detail
+ {
+ auto detail = std::make_shared();
+ Status original(code, "foo", detail);
+ Status expected(code, "foo",
+ std::make_shared(FlightStatusCode::Internal,
+ detail->ToString()));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "foo",
+ std::make_shared(flight_code, "extra"));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+ }
+}
+} // namespace ucx
+} // namespace transport
+
+//------------------------------------------------------------
+// Ad-hoc UCX-specific tests
+
+class SimpleTestServer : public FlightServerBase {
+ public:
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr* info) override {
+ if (request.path.size() > 0 && request.path[0] == "error") {
+ return status_;
+ }
+ auto examples = ExampleFlightInfo();
+ *info = std::unique_ptr(new FlightInfo(examples[0]));
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr* data_stream) override {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ auto batch_reader = std::make_shared(batches[0]->schema(), batches);
+ *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader));
+ return Status::OK();
+ }
+
+ void set_error_status(Status st) { status_ = std::move(st); }
+
+ private:
+ Status status_;
+};
+
+class TestUcx : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "localhost", 0));
+ ASSERT_OK(MakeServer(
+ location, &server_, &client_,
+ [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
+
+ protected:
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+TEST_F(TestUcx, GetFlightInfo) {
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ std::unique_ptr info;
+ ASSERT_OK(client_->GetFlightInfo(descriptor, &info));
+ // Test that we can reuse the connection
+ ASSERT_OK(client_->GetFlightInfo(descriptor, &info));
+}
+
+TEST_F(TestUcx, SequentialClients) {
+ std::unique_ptr client2;
+ ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(),
+ &client2));
+
+ Ticket ticket{"a"};
+
+ std::unique_ptr stream1, stream2;
+ std::shared_ptr table1, table2;
+
+ ASSERT_OK(client_->DoGet(ticket, &stream1));
+ ASSERT_OK(stream1->ReadAll(&table1));
+
+ ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK(stream2->ReadAll(&table2));
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, ConcurrentClients) {
+ std::unique_ptr client2;
+ ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(),
+ &client2));
+
+ Ticket ticket{"a"};
+
+ std::unique_ptr stream1, stream2;
+ std::shared_ptr table1, table2;
+
+ ASSERT_OK(client_->DoGet(ticket, &stream1));
+ ASSERT_OK(client2->DoGet(ticket, &stream2));
+
+ ASSERT_OK(stream1->ReadAll(&table1));
+ ASSERT_OK(stream2->ReadAll(&table2));
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, Errors) {
+ auto descriptor = FlightDescriptor::Path({"error", "bar"});
+ std::unique_ptr info;
+
+ auto* server = reinterpret_cast(server_.get());
+ for (const auto code : kStatusCodes) {
+ if (code == StatusCode::OK) continue;
+
+ Status expected(code, "Error message");
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor, &info);
+ ASSERT_EQ(actual, expected);
+
+ // Attach a generic status detail
+ {
+ auto detail = std::make_shared();
+ server->set_error_status(Status(code, "foo", detail));
+ Status expected(code, "foo",
+ std::make_shared(FlightStatusCode::Internal,
+ detail->ToString()));
+ Status actual = client_->GetFlightInfo(descriptor, &info);
+ ASSERT_EQ(actual, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "Error message",
+ std::make_shared(flight_code, "extra"));
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor, &info);
+ ASSERT_EQ(actual, expected);
+ }
+ }
+}
+
+TEST(TestUcxIpV6, DISABLED_IpV6Port) {
+ // TODO(lidavidm): while we can listen on IPv6 fine, we can't
+ // actually connect to it (ucp_conn_request_h appears to point to a
+ // port where nobody is listening)
+
+ // Also, disabled in CI as machines lack an IPv6 interface
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0));
+
+ std::unique_ptr server(new SimpleTestServer());
+ FlightServerOptions server_options(location);
+ ASSERT_OK(server->Init(server_options));
+
+ std::unique_ptr client;
+ FlightClientOptions client_options = FlightClientOptions::Defaults();
+ ASSERT_OK(FlightClient::Connect(server->location(), client_options, &client));
+
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ std::unique_ptr info;
+ ASSERT_OK(client->GetFlightInfo(descriptor, &info));
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc
new file mode 100644
index 00000000000..0b61dbb93e9
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx.h"
+
+#include
+
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+#include "arrow/flight/transport_server.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+std::once_flag kInitializeOnce;
+void InitializeFlightUcx() {
+ std::call_once(kInitializeOnce, []() {
+ auto* registry = flight::internal::GetDefaultTransportRegistry();
+ DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl));
+ DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl));
+ });
+}
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h
new file mode 100644
index 00000000000..dda2c83035c
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Experimental UCX-based transport for Flight.
+
+#pragma once
+
+#include "arrow/flight/visibility.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+ARROW_FLIGHT_EXPORT
+void InitializeFlightUcx();
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
new file mode 100644
index 00000000000..04122942ba5
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// The client-side implementation of a UCX-based transport for
+/// Flight.
+///
+/// Each UCX driver is used to support one call at a time. This gives
+/// the greatest throughput for data plane methods, but is relatively
+/// expensive in terms of other resources, both for the server and the
+/// client. (UCX drivers have multiple threading modes: single-thread
+/// access, serialized access, and multi-thread access. Testing found
+/// that multi-thread access incurred high synchronization costs.)
+/// Hence, for concurrent calls in a single client, we must maintain
+/// multiple drivers, and so unlike gRPC, there is no real difference
+/// between using one client concurrently and using multiple
+/// independent clients.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "arrow/buffer.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+class UcxClientImpl;
+
+namespace {
+
+Status MergeStatuses(Status server_status, Status transport_status) {
+ if (server_status.ok()) {
+ if (transport_status.ok()) return server_status;
+ return transport_status;
+ } else if (transport_status.ok()) {
+ return server_status;
+ }
+ return Status::FromDetailAndArgs(server_status.code(), server_status.detail(),
+ server_status.message(),
+ ". Transport context: ", transport_status.ToString());
+}
+
+/// \brief An individual connection to the server.
+class ClientConnection {
+ public:
+ ClientConnection() = default;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection);
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection);
+ ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; }
+
+ Status Init(std::shared_ptr ucp_context, const arrow::internal::Uri& uri) {
+ auto status = InitImpl(std::move(ucp_context), uri);
+ // Clean up after-the-fact if we fail to initialize
+ if (!status.ok()) {
+ if (driver_) {
+ status = MergeStatuses(std::move(status), driver_->Close());
+ driver_.reset();
+ remote_endpoint_ = nullptr;
+ }
+ if (ucp_worker_) ucp_worker_.reset();
+ }
+ return status;
+ }
+
+ Status InitImpl(std::shared_ptr ucp_context,
+ const arrow::internal::Uri& uri) {
+ {
+ ucs_status_t status;
+ ucp_worker_params_t worker_params;
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+ worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
+
+ ucp_worker_h ucp_worker;
+ status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker);
+ RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status));
+ ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker));
+ }
+ {
+ // Create endpoint for remote worker
+ struct sockaddr_storage connect_addr;
+ ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr));
+
+ ucp_ep_params_t params;
+ params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME |
+ UCP_EP_PARAM_FIELD_SOCK_ADDR;
+ params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER;
+ params.name = "UcxClientImpl";
+ params.sockaddr.addr = reinterpret_cast(&connect_addr);
+ params.sockaddr.addrlen = addrlen;
+
+ auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_);
+ RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status));
+ }
+
+ driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_));
+
+ {
+ // Set up Active Message (AM) handler
+ ucp_am_handler_param_t handler_params;
+ handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID |
+ UCP_AM_HANDLER_PARAM_FIELD_CB |
+ UCP_AM_HANDLER_PARAM_FIELD_ARG;
+ handler_params.id = kUcpAmHandlerId;
+ handler_params.cb = HandleIncomingActiveMessage;
+ handler_params.arg = driver_.get();
+ ucs_status_t status =
+ ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params);
+ RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status));
+ }
+
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (!driver_) return Status::OK();
+
+ auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0);
+ status = MergeStatuses(std::move(status), driver_->Close());
+
+ driver_.reset();
+ remote_endpoint_ = nullptr;
+ ucp_worker_.reset();
+ return status;
+ }
+
+ UcpCallDriver* driver() {
+ DCHECK(driver_);
+ return driver_.get();
+ }
+
+ private:
+ static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header,
+ size_t header_length, void* data,
+ size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ auto* driver = reinterpret_cast(self);
+ return driver->RecvActiveMessage(header, header_length, data, data_length, param);
+ }
+
+ std::shared_ptr ucp_worker_;
+ ucp_ep_h remote_endpoint_;
+ std::unique_ptr driver_;
+};
+
+class UcxClientStream : public internal::ClientDataStream {
+ public:
+ UcxClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : impl_(impl),
+ conn_(std::move(conn)),
+ driver_(conn_.driver()),
+ writes_done_(false),
+ finished_(false) {}
+
+ protected:
+ Status DoFinish() override;
+
+ UcxClientImpl* impl_;
+ ClientConnection conn_;
+ UcpCallDriver* driver_;
+ bool writes_done_;
+ bool finished_;
+ Status io_status_;
+ Status server_status_;
+};
+
+class GetClientStream : public UcxClientStream {
+ public:
+ GetClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : UcxClientStream(impl, std::move(conn)) {
+ writes_done_ = true;
+ }
+
+ bool ReadData(internal::FlightData* data) override {
+ if (finished_) return false;
+
+ bool success = true;
+ io_status_ = ReadImpl(data).Value(&success);
+
+ if (!io_status_.ok() || !success) {
+ finished_ = true;
+ }
+ return success;
+ }
+
+ private:
+ ::arrow::Result ReadImpl(internal::FlightData* data) {
+ ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame());
+
+ if (frame->type == FrameType::kHeaders) {
+ // Trailers, stream is over
+ ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer)));
+ RETURN_NOT_OK(headers.GetStatus(&server_status_));
+ return false;
+ }
+
+ RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader));
+ PayloadHeaderFrame payload_header(std::move(frame->buffer));
+ RETURN_NOT_OK(payload_header.ToFlightData(data));
+
+ // DoGet does not support metadata-only messages, so we can always
+ // assume we have an IPC payload
+ ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr));
+
+ if (ipc::Message::HasBody(message->type())) {
+ ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame());
+ RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody));
+ data->body = std::move(frame->buffer);
+ }
+ return true;
+ }
+};
+
+class WriteClientStream : public UcxClientStream {
+ public:
+ WriteClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : UcxClientStream(impl, std::move(conn)) {
+ std::thread t(&WriteClientStream::DriveWorker, this);
+ driver_thread_.swap(t);
+ }
+ arrow::Result WriteData(const FlightPayload& payload) override {
+ std::unique_lock guard(driver_mutex_);
+ if (finished_ || writes_done_) return Status::Invalid("Already done writing");
+ outgoing_ = driver_->SendFlightPayload(payload);
+ working_cv_.notify_all();
+ received_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+ auto status = outgoing_.status();
+ outgoing_ = Future<>();
+ RETURN_NOT_OK(status);
+ return true;
+ }
+ Status WritesDone() override {
+ std::unique_lock guard(driver_mutex_);
+ if (!writes_done_) {
+ ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({}));
+ outgoing_ =
+ driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer());
+ working_cv_.notify_all();
+ received_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+ writes_done_ = true;
+ auto status = outgoing_.status();
+ outgoing_ = Future<>();
+ RETURN_NOT_OK(status);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ void JoinThread() {
+ try {
+ driver_thread_.join();
+ } catch (const std::system_error&) {
+ // Ignore
+ }
+ }
+ void DriveWorker() {
+ while (true) {
+ {
+ std::unique_lock guard(driver_mutex_);
+ working_cv_.wait(guard, [this] {
+ return finished_ || incoming_.is_valid() || outgoing_.is_valid();
+ });
+ if (finished_) return;
+ }
+
+ while (true) {
+ std::unique_lock guard(driver_mutex_);
+ if (!incoming_.is_valid() && !outgoing_.is_valid()) break;
+ if (incoming_.is_valid() && incoming_.is_finished()) {
+ if (!incoming_.status().ok()) {
+ io_status_ = incoming_.status();
+ finished_ = true;
+ } else {
+ HandleIncomingMessage(*incoming_.result());
+ }
+ incoming_ = Future>();
+ received_cv_.notify_all();
+ break;
+ }
+ if (outgoing_.is_valid() && outgoing_.is_finished()) {
+ received_cv_.notify_all();
+ break;
+ }
+ driver_->MakeProgress();
+ if (finished_) return;
+ }
+ }
+ }
+
+ virtual void HandleIncomingMessage(const std::shared_ptr& frame) {}
+
+ std::mutex driver_mutex_;
+ std::thread driver_thread_;
+ std::condition_variable received_cv_;
+ std::condition_variable working_cv_;
+ Future> incoming_;
+ Future<> outgoing_;
+};
+
+class PutClientStream : public WriteClientStream {
+ public:
+ using WriteClientStream::WriteClientStream;
+ bool ReadPutMetadata(std::shared_ptr* out) override {
+ std::unique_lock guard(driver_mutex_);
+ if (finished_) {
+ *out = nullptr;
+ guard.unlock();
+ JoinThread();
+ return false;
+ }
+ next_metadata_ = nullptr;
+ incoming_ = driver_->ReadFrameAsync();
+ working_cv_.notify_all();
+ received_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; });
+
+ if (finished_) {
+ *out = nullptr;
+ guard.unlock();
+ JoinThread();
+ return false;
+ }
+ *out = std::move(next_metadata_);
+ return true;
+ }
+
+ private:
+ void HandleIncomingMessage(const std::shared_ptr& frame) override {
+ // No lock here, since this is called from DriveWorker() which is
+ // holding the lock
+ if (frame->type == FrameType::kBuffer) {
+ next_metadata_ = std::move(frame->buffer);
+ } else if (frame->type == FrameType::kHeaders) {
+ // Trailers, stream is over
+ finished_ = true;
+ HeadersFrame headers;
+ io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+ io_status_ = headers.GetStatus(&server_status_);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+ } else {
+ finished_ = true;
+ io_status_ =
+ Status::IOError("Unexpected frame type ", static_cast(frame->type));
+ }
+ }
+ std::shared_ptr next_metadata_;
+};
+
+class ExchangeClientStream : public WriteClientStream {
+ public:
+ ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {}
+
+ bool ReadData(internal::FlightData* data) override {
+ std::unique_lock guard(driver_mutex_);
+ if (finished_) {
+ guard.unlock();
+ JoinThread();
+ return false;
+ }
+
+ // Drive the read loop here. (We can't recursively call
+ // ReadFrameAsync below since the internal mutex is not
+ // recursive.)
+ read_state_ = ReadState::kExpectHeader;
+ incoming_ = driver_->ReadFrameAsync();
+ working_cv_.notify_all();
+ received_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; });
+ if (read_state_ != ReadState::kFinished) {
+ incoming_ = driver_->ReadFrameAsync();
+ working_cv_.notify_all();
+ received_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; });
+ }
+
+ if (finished_) {
+ guard.unlock();
+ JoinThread();
+ return false;
+ }
+ *data = std::move(next_data_);
+ return true;
+ }
+
+ private:
+ enum class ReadState {
+ kFinished,
+ kExpectHeader,
+ kExpectBody,
+ };
+
+ std::string DebugExpectingString() {
+ switch (read_state_) {
+ case ReadState::kFinished:
+ return "(not expecting a frame)";
+ case ReadState::kExpectHeader:
+ return "payload header frame";
+ case ReadState::kExpectBody:
+ return "payload body frame";
+ }
+ return "(unknown or invalid state)";
+ }
+
+ void HandleIncomingMessage(const std::shared_ptr& frame) override {
+ // No lock here, since this is called from MakeProgress()
+ // which is called under the lock already
+ if (frame->type == FrameType::kPayloadHeader) {
+ if (read_state_ != ReadState::kExpectHeader) {
+ finished_ = true;
+ io_status_ = Status::IOError("Got unexpected payload header frame, expected: ",
+ DebugExpectingString());
+ return;
+ }
+
+ PayloadHeaderFrame payload_header(std::move(frame->buffer));
+ io_status_ = payload_header.ToFlightData(&next_data_);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+
+ if (next_data_.metadata) {
+ std::unique_ptr message;
+ io_status_ = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+ if (ipc::Message::HasBody(message->type())) {
+ read_state_ = ReadState::kExpectBody;
+ return;
+ }
+ }
+ read_state_ = ReadState::kFinished;
+ } else if (frame->type == FrameType::kPayloadBody) {
+ next_data_.body = std::move(frame->buffer);
+ read_state_ = ReadState::kFinished;
+ } else if (frame->type == FrameType::kHeaders) {
+ // Trailers, stream is over
+ finished_ = true;
+ read_state_ = ReadState::kFinished;
+ HeadersFrame headers;
+ io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+ io_status_ = headers.GetStatus(&server_status_);
+ if (!io_status_.ok()) {
+ finished_ = true;
+ return;
+ }
+ } else {
+ finished_ = true;
+ io_status_ =
+ Status::IOError("Unexpected frame type ", static_cast(frame->type));
+ read_state_ = ReadState::kFinished;
+ }
+ }
+
+ internal::FlightData next_data_;
+ ReadState read_state_;
+};
+} // namespace
+
+class ARROW_FLIGHT_EXPORT UcxClientImpl
+ : public arrow::flight::internal::ClientTransport {
+ public:
+ UcxClientImpl() {}
+
+ virtual ~UcxClientImpl() {
+ if (!ucp_context_) return;
+ auto status = Close();
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "UcxClientImpl errored in Close() in destructor: "
+ << status.ToString();
+ }
+ }
+
+ Status Init(const FlightClientOptions& options, const Location& location,
+ const arrow::internal::Uri& uri) override {
+ RETURN_NOT_OK(uri_.Parse(uri.ToString()));
+ {
+ ucp_config_t* ucp_config;
+ ucp_params_t ucp_params;
+ ucs_status_t status;
+
+ status = ucp_config_read(nullptr, nullptr, &ucp_config);
+ RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status));
+
+ std::memset(&ucp_params, 0, sizeof(ucp_params));
+ ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
+ ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP;
+
+ ucp_context_h ucp_context;
+ status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+ ucp_config_release(ucp_config);
+ RETURN_NOT_OK(FromUcsStatus("ucp_init", status));
+ ucp_context_.reset(new UcpContext(ucp_context));
+ }
+
+ RETURN_NOT_OK(MakeConnection());
+ return Status::OK();
+ }
+
+ Status Close() override {
+ std::unique_lock connections_mutex_;
+ while (!connections_.empty()) {
+ RETURN_NOT_OK(connections_.front().Close());
+ connections_.pop_front();
+ }
+ return Status::OK();
+ }
+
+ Status GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr* info) override {
+ ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+ UcpCallDriver* driver = connection.driver();
+
+ auto impl = [&]() {
+ RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo));
+
+ ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString());
+
+ RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+ reinterpret_cast(payload.data()),
+ static_cast(payload.size())));
+
+ ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame());
+ if (incoming_message->type == FrameType::kBuffer) {
+ ARROW_ASSIGN_OR_RAISE(
+ *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer)));
+ ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame());
+ }
+ RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders));
+ ARROW_ASSIGN_OR_RAISE(auto headers,
+ HeadersFrame::Parse(std::move(incoming_message->buffer)));
+ Status status;
+ RETURN_NOT_OK(headers.GetStatus(&status));
+ return status;
+ };
+ auto status = impl();
+ return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+ }
+
+ Status DoExchange(const FlightCallOptions& options,
+ std::unique_ptr* out) override {
+ ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+ UcpCallDriver* driver = connection.driver();
+
+ auto status = driver->StartCall(kMethodDoExchange);
+ if (ARROW_PREDICT_TRUE(status.ok())) {
+ *out =
+ arrow::internal::make_unique(this, std::move(connection));
+ return Status::OK();
+ }
+ return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+ }
+
+ Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr* stream) override {
+ ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+ UcpCallDriver* driver = connection.driver();
+
+ auto impl = [&]() {
+ RETURN_NOT_OK(driver->StartCall(kMethodDoGet));
+ ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString());
+ RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+ reinterpret_cast(payload.data()),
+ static_cast(payload.size())));
+ *stream =
+ arrow::internal::make_unique(this, std::move(connection));
+ return Status::OK();
+ };
+
+ auto status = impl();
+ if (ARROW_PREDICT_TRUE(status.ok())) return status;
+ return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+ }
+
+ Status DoPut(const FlightCallOptions& options,
+ std::unique_ptr* out) override {
+ ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+ UcpCallDriver* driver = connection.driver();
+
+ auto status = driver->StartCall(kMethodDoPut);
+ if (ARROW_PREDICT_TRUE(status.ok())) {
+ *out = arrow::internal::make_unique(this, std::move(connection));
+ return Status::OK();
+ }
+ return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+ }
+
+ Status DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr* results) override {
+ // XXX: fake this for now to get the perf test to work
+ return Status::OK();
+ }
+
+ Status MakeConnection() {
+ ClientConnection conn;
+ RETURN_NOT_OK(conn.Init(ucp_context_, uri_));
+ connections_.push_back(std::move(conn));
+ return Status::OK();
+ }
+
+ arrow::Result CheckoutConnection(const FlightCallOptions& options) {
+ std::unique_lock connections_mutex_;
+ if (connections_.empty()) RETURN_NOT_OK(MakeConnection());
+ ClientConnection conn = std::move(connections_.front());
+ conn.driver()->set_memory_manager(options.memory_manager);
+ conn.driver()->set_read_memory_pool(options.read_options.memory_pool);
+ conn.driver()->set_write_memory_pool(options.write_options.memory_pool);
+ connections_.pop_front();
+ return conn;
+ }
+
+ Status ReturnConnection(ClientConnection conn) {
+ std::unique_lock connections_mutex_;
+ // TODO(lidavidm): for future improvement: reclaim clients
+ // asynchronously in the background (try to avoid issues like
+ // constantly opening/closing clients because the application is
+ // just barely over the limit of open connections)
+ if (connections_.size() >= kMaxOpenConnections) {
+ RETURN_NOT_OK(conn.Close());
+ return Status::OK();
+ }
+ connections_.push_back(std::move(conn));
+ return Status::OK();
+ }
+
+ private:
+ static constexpr size_t kMaxOpenConnections = 3;
+
+ arrow::internal::Uri uri_;
+ std::shared_ptr ucp_context_;
+ std::mutex connections_mutex_;
+ std::deque connections_;
+};
+
+Status UcxClientStream::DoFinish() {
+ RETURN_NOT_OK(WritesDone());
+ if (!finished_) {
+ internal::FlightData message;
+ std::shared_ptr metadata;
+ while (ReadData(&message)) {
+ }
+ while (ReadPutMetadata(&metadata)) {
+ }
+ finished_ = true;
+ }
+ if (impl_) {
+ auto status = impl_->ReturnConnection(std::move(conn_));
+ impl_ = nullptr;
+ driver_ = nullptr;
+ if (!status.ok()) {
+ if (io_status_.ok()) {
+ io_status_ = std::move(status);
+ } else {
+ io_status_ = Status::FromDetailAndArgs(
+ io_status_.code(), io_status_.detail(), io_status_.message(),
+ ". Transport context: ", status.ToString());
+ }
+ }
+ }
+ return MergeStatuses(server_status_, io_status_);
+}
+
+std::unique_ptr MakeUcxClientImpl() {
+ return arrow::internal::make_unique();
+}
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
new file mode 100644
index 00000000000..5baabc5b805
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
@@ -0,0 +1,1148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include
+#include
+#include
+#include
+
+#include "arrow/buffer.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Defines to test different implementation strategies
+// Enable the CONTIG path for CPU-only data
+// #define ARROW_FLIGHT_UCX_SEND_CONTIG
+// Enable ucp_mem_map in IOV path
+// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP
+
+constexpr char kHeaderMethod[] = ":method:";
+
+namespace {
+Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) {
+ if (ARROW_PREDICT_FALSE(in < 0)) {
+ return Status::Invalid("Length cannot be negative");
+ } else if (ARROW_PREDICT_FALSE(
+ in > static_cast(std::numeric_limits::max()))) {
+ return Status::Invalid("Length cannot exceed uint32_t");
+ }
+ UInt32ToBytesBe(static_cast(in), out);
+ return Status::OK();
+}
+ucs_memory_type InferMemoryType(const Buffer& buffer) {
+ if (!buffer.is_cpu()) {
+ return UCS_MEMORY_TYPE_CUDA;
+ }
+ return UCS_MEMORY_TYPE_UNKNOWN;
+}
+void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size,
+ ucs_memory_type memory_type, ucp_mem_h* memh_p) {
+ ucp_mem_map_params_t map_param;
+ std::memset(&map_param, 0, sizeof(map_param));
+ map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
+ UCP_MEM_MAP_PARAM_FIELD_LENGTH |
+ UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
+ map_param.address = const_cast(buffer);
+ map_param.length = size;
+ map_param.memory_type = memory_type;
+ auto ucs_status = ucp_mem_map(context, &map_param, memh_p);
+ if (ucs_status != UCS_OK) {
+ *memh_p = nullptr;
+ ARROW_LOG(WARNING) << "Could not map memory: "
+ << FromUcsStatus("ucp_mem_map", ucs_status);
+ }
+}
+void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) {
+ TryMapBuffer(context, reinterpret_cast(buffer.address()),
+ static_cast(buffer.size()), InferMemoryType(buffer), memh_p);
+}
+void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) {
+ if (memh_p) {
+ auto ucs_status = ucp_mem_unmap(context, memh_p);
+ if (ucs_status != UCS_OK) {
+ ARROW_LOG(WARNING) << "Could not unmap memory: "
+ << FromUcsStatus("ucp_mem_unmap", ucs_status);
+ }
+ }
+}
+
+/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA
+/// buffer).
+///
+/// Owns a reference to the associated worker to avoid undefined
+/// behavior.
+class UcxDataBuffer : public Buffer {
+ public:
+ explicit UcxDataBuffer(std::shared_ptr worker, void* data, size_t size)
+ : Buffer(const_cast(reinterpret_cast(data)),
+ static_cast(size)),
+ worker_(std::move(worker)) {}
+
+ ~UcxDataBuffer() {
+ ucp_am_data_release(worker_->get(),
+ const_cast(reinterpret_cast(data())));
+ }
+
+ private:
+ std::shared_ptr worker_;
+};
+}; // namespace
+
+constexpr size_t FrameHeader::kFrameHeaderBytes;
+constexpr uint8_t FrameHeader::kFrameVersion;
+
+Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) {
+ header[0] = kFrameVersion;
+ header[1] = static_cast(frame_type);
+ UInt32ToBytesBe(counter, header.data() + 4);
+ RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8));
+ return Status::OK();
+}
+
+arrow::Result> Frame::ParseHeader(const void* header,
+ size_t header_length) {
+ if (header_length < FrameHeader::kFrameHeaderBytes) {
+ return Status::IOError("Header is too short, must be at least ",
+ FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length);
+ }
+
+ const uint8_t* frame_header = reinterpret_cast(header);
+ if (frame_header[0] != FrameHeader::kFrameVersion) {
+ return Status::IOError("Expected frame version ",
+ static_cast(FrameHeader::kFrameVersion), " but got ",
+ static_cast(frame_header[0]));
+ } else if (frame_header[1] > static_cast(FrameType::kMaxFrameType)) {
+ return Status::IOError("Unknown frame type ", static_cast(frame_header[1]));
+ }
+
+ const FrameType frame_type = static_cast(frame_header[1]);
+ const uint32_t frame_counter = BytesToUInt32Be(frame_header + 4);
+ const uint32_t frame_size = BytesToUInt32Be(frame_header + 8);
+
+ if (frame_type == FrameType::kDisconnect) {
+ return Status::Cancelled("Client initiated disconnect");
+ }
+
+ return std::make_shared(frame_type, frame_size, frame_counter, nullptr);
+}
+
+arrow::Result HeadersFrame::Parse(std::unique_ptr buffer) {
+ HeadersFrame result;
+ const uint8_t* payload = buffer->data();
+ const uint8_t* end = payload + buffer->size();
+ if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+ return Status::Invalid("Buffer underflow, expected number of headers");
+ }
+ const uint32_t num_headers = BytesToUInt32Be(payload);
+ payload += 4;
+ for (uint32_t i = 0; i < num_headers; i++) {
+ if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+ return Status::Invalid("Buffer underflow, expected length of key ", i + 1);
+ }
+ const uint32_t key_length = BytesToUInt32Be(payload);
+ payload += 4;
+
+ if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+ return Status::Invalid("Buffer underflow, expected length of value ", i + 1);
+ }
+ const uint32_t value_length = BytesToUInt32Be(payload);
+ payload += 4;
+
+ if (ARROW_PREDICT_FALSE((end - payload) < key_length)) {
+ return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ",
+ key_length, ", but only ", (end - payload), " bytes remain");
+ }
+ const util::string_view key(reinterpret_cast(payload), key_length);
+ payload += key_length;
+
+ if (ARROW_PREDICT_FALSE((end - payload) < value_length)) {
+ return Status::Invalid("Buffer underflow, expected value ", i + 1,
+ " to have length ", value_length, ", but only ",
+ (end - payload), " bytes remain");
+ }
+ const util::string_view value(reinterpret_cast(payload), value_length);
+ payload += value_length;
+ result.headers_.emplace_back(key, value);
+ }
+
+ result.buffer_ = std::move(buffer);
+ return result;
+}
+arrow::Result HeadersFrame::Make(
+ const std::vector>& headers) {
+ int32_t total_length = 4 /* # of headers */;
+ for (const auto& header : headers) {
+ total_length += 4 /* key length */ + 4 /* value length */ +
+ header.first.size() /* key */ + header.second.size();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length));
+ uint8_t* payload = buffer->mutable_data();
+
+ RETURN_NOT_OK(SizeToUInt32BytesBe(headers.size(), payload));
+ payload += 4;
+ for (const auto& header : headers) {
+ RETURN_NOT_OK(SizeToUInt32BytesBe(header.first.size(), payload));
+ payload += 4;
+ RETURN_NOT_OK(SizeToUInt32BytesBe(header.second.size(), payload));
+ payload += 4;
+ std::memcpy(payload, header.first.data(), header.first.size());
+ payload += header.first.size();
+ std::memcpy(payload, header.second.data(), header.second.size());
+ payload += header.second.size();
+ }
+ return Parse(std::move(buffer));
+}
+arrow::Result HeadersFrame::Make(
+ const Status& status,
+ const std::vector>& headers) {
+ auto all_headers = headers;
+ all_headers.emplace_back(kHeaderStatusCode,
+ std::to_string(static_cast(status.code())));
+ all_headers.emplace_back(kHeaderStatusMessage, status.message());
+ if (status.detail()) {
+ auto fsd = FlightStatusDetail::UnwrapStatus(status);
+ if (fsd) {
+ all_headers.emplace_back(kHeaderStatusDetailCode,
+ std::to_string(static_cast(fsd->code())));
+ all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info());
+ } else {
+ all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
+ }
+ }
+ return Make(all_headers);
+}
+
+arrow::Result HeadersFrame::Get(const std::string& key) {
+ for (const auto& pair : headers_) {
+ if (pair.first == key) return pair.second;
+ }
+ return Status::KeyError(key);
+}
+
+Status HeadersFrame::GetStatus(Status* out) {
+ util::string_view code_str, message_str;
+ auto status = Get(kHeaderStatusCode).Value(&code_str);
+ if (!status.ok()) {
+ return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
+ }
+
+ StatusCode status_code = StatusCode::OK;
+ auto code = std::strtol(code_str.data(), nullptr, /*base=*/10);
+ switch (code) {
+ case 0:
+ status_code = StatusCode::OK;
+ break;
+ case 1:
+ status_code = StatusCode::OutOfMemory;
+ break;
+ case 2:
+ status_code = StatusCode::KeyError;
+ break;
+ case 3:
+ status_code = StatusCode::TypeError;
+ break;
+ case 4:
+ status_code = StatusCode::Invalid;
+ break;
+ case 5:
+ status_code = StatusCode::IOError;
+ break;
+ case 6:
+ status_code = StatusCode::CapacityError;
+ break;
+ case 7:
+ status_code = StatusCode::IndexError;
+ break;
+ case 8:
+ status_code = StatusCode::Cancelled;
+ break;
+ case 9:
+ status_code = StatusCode::UnknownError;
+ break;
+ case 10:
+ status_code = StatusCode::NotImplemented;
+ break;
+ case 11:
+ status_code = StatusCode::SerializationError;
+ break;
+ case 13:
+ status_code = StatusCode::RError;
+ break;
+ case 40:
+ status_code = StatusCode::CodeGenError;
+ break;
+ case 41:
+ status_code = StatusCode::ExpressionValidationError;
+ break;
+ case 42:
+ status_code = StatusCode::ExecutionError;
+ break;
+ case 45:
+ status_code = StatusCode::AlreadyExists;
+ break;
+ default:
+ status_code = StatusCode::UnknownError;
+ break;
+ }
+ if (status_code == StatusCode::OK) {
+ *out = Status::OK();
+ return Status::OK();
+ }
+
+ status = Get(kHeaderStatusMessage).Value(&message_str);
+ if (!status.ok()) {
+ *out = Status(status_code, "Server did not send status message header", nullptr);
+ return Status::OK();
+ }
+
+ util::string_view detail_code_str, detail_str;
+ FlightStatusCode detail_code = FlightStatusCode::Internal;
+
+ if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) {
+ auto detail_code_int = std::strtol(detail_code_str.data(), nullptr, /*base=*/10);
+ switch (detail_code_int) {
+ case 1:
+ detail_code = FlightStatusCode::TimedOut;
+ break;
+ case 2:
+ detail_code = FlightStatusCode::Cancelled;
+ break;
+ case 3:
+ detail_code = FlightStatusCode::Unauthenticated;
+ break;
+ case 4:
+ detail_code = FlightStatusCode::Unauthorized;
+ break;
+ case 5:
+ detail_code = FlightStatusCode::Unavailable;
+ break;
+ case 6:
+ detail_code = FlightStatusCode::Failed;
+ break;
+ case 0:
+ default:
+ detail_code = FlightStatusCode::Internal;
+ break;
+ }
+ }
+ ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str));
+
+ std::shared_ptr detail = nullptr;
+ if (!detail_str.empty()) {
+ detail = std::make_shared(detail_code, std::string(detail_str));
+ }
+ *out = Status(status_code, std::string(message_str), std::move(detail));
+ return Status::OK();
+}
+
+namespace {
+static constexpr uint32_t kMissingFieldSentinel = std::numeric_limits::max();
+static constexpr uint32_t kInt32Max =
+ static_cast(std::numeric_limits::max());
+arrow::Result PayloadHeaderFieldSize(const std::string& field,
+ const std::shared_ptr& data,
+ uint32_t* total_size) {
+ if (!data) return kMissingFieldSentinel;
+ if (data->size() > kInt32Max) {
+ return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size());
+ }
+ *total_size += static_cast(data->size());
+ // Check for underflow
+ if (*total_size < 0) return Status::Invalid("Payload header must fit in a uint32_t");
+ return static_cast(data->size());
+}
+uint8_t* PackField(uint32_t size, const std::shared_ptr& data, uint8_t* out) {
+ UInt32ToBytesBe(size, out);
+ if (size != kMissingFieldSentinel) {
+ std::memcpy(out + 4, data->data(), size);
+ return out + 4 + size;
+ } else {
+ return out + 4;
+ }
+}
+} // namespace
+
+arrow::Result PayloadHeaderFrame::Make(const FlightPayload& payload,
+ MemoryPool* memory_pool) {
+ // Assemble all non-data fields here. Presumably this is much less
+ // than data size so we will pay the copy.
+
+ // Structure per field: [4 byte length][data]. If a field is not
+ // present, UINT32_MAX is used as the sentinel (since 0-sized fields
+ // are acceptable)
+ uint32_t header_size = 12;
+ ARROW_ASSIGN_OR_RAISE(
+ const uint32_t descriptor_size,
+ PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size));
+ ARROW_ASSIGN_OR_RAISE(
+ const uint32_t app_metadata_size,
+ PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size));
+ ARROW_ASSIGN_OR_RAISE(
+ const uint32_t ipc_metadata_size,
+ PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata,
+ &header_size));
+
+ ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool));
+ uint8_t* payload_header = header_buffer->mutable_data();
+
+ payload_header = PackField(descriptor_size, payload.descriptor, payload_header);
+ payload_header = PackField(app_metadata_size, payload.app_metadata, payload_header);
+ payload_header =
+ PackField(ipc_metadata_size, payload.ipc_message.metadata, payload_header);
+
+ return PayloadHeaderFrame(std::move(header_buffer));
+}
+Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) {
+ std::shared_ptr buffer = std::move(buffer_);
+
+ // Unpack the descriptor
+ uint32_t offset = 0;
+ uint32_t size = BytesToUInt32Be(buffer->data());
+ offset += 4;
+ if (size != kMissingFieldSentinel) {
+ if (static_cast(offset + size) > buffer->size()) {
+ return Status::Invalid("Buffer is too small: expected ", offset + size,
+ " bytes but have ", buffer->size());
+ }
+ util::string_view desc(reinterpret_cast(buffer->data() + offset), size);
+ data->descriptor.reset(new FlightDescriptor());
+ ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc));
+ offset += size;
+ } else {
+ data->descriptor = nullptr;
+ }
+
+ // Unpack app_metadata
+ size = BytesToUInt32Be(buffer->data() + offset);
+ offset += 4;
+ // While we properly handle zero-size vs nullptr metadata here, gRPC
+ // doesn't (Protobuf doesn't differentiate between the two)
+ if (size != kMissingFieldSentinel) {
+ if (static_cast(offset + size) > buffer->size()) {
+ return Status::Invalid("Buffer is too small: expected ", offset + size,
+ " bytes but have ", buffer->size());
+ }
+ data->app_metadata = SliceBuffer(buffer, offset, size);
+ offset += size;
+ } else {
+ data->app_metadata = nullptr;
+ }
+
+ // Unpack the IPC header
+ size = BytesToUInt32Be(buffer->data() + offset);
+ offset += 4;
+ if (size != kMissingFieldSentinel) {
+ if (static_cast(offset + size) > buffer->size()) {
+ return Status::Invalid("Buffer is too small: expected ", offset + size,
+ " bytes but have ", buffer->size());
+ }
+ data->metadata = SliceBuffer(std::move(buffer), offset, size);
+ } else {
+ data->metadata = nullptr;
+ }
+ data->body = nullptr;
+ return Status::OK();
+}
+
+// pImpl the driver since async methods require a stable address
+class UcpCallDriver::Impl {
+ public:
+#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG)
+ constexpr static bool kEnableContigSend = true;
+#else
+ constexpr static bool kEnableContigSend = false;
+#endif
+
+ Impl(std::shared_ptr worker, ucp_ep_h endpoint)
+ : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}),
+ worker_(std::move(worker)),
+ endpoint_(endpoint),
+ read_memory_pool_(default_memory_pool()),
+ write_memory_pool_(default_memory_pool()),
+ memory_manager_(CPUDevice::Instance()->default_memory_manager()),
+ counter_(0) {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(),
+ UCS_MEMORY_TYPE_HOST, &padding_memh_p_);
+#endif
+
+ ucp_ep_attr_t attrs;
+ std::memset(&attrs, 0, sizeof(attrs));
+ attrs.field_mask = UCP_EP_ATTR_FIELD_NAME;
+ if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) {
+ name_ = attrs.name;
+ } else {
+ name_ = "(unknown remote)";
+ }
+ }
+
+ ~Impl() {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryUnmapBuffer(worker_->context().get(), padding_memh_p_);
+#endif
+ }
+
+ arrow::Result> ReadNextFrame() {
+ auto fut = ReadFrameAsync();
+ while (!fut.is_finished()) MakeProgress();
+ RETURN_NOT_OK(fut.status());
+ return fut.MoveResult();
+ }
+
+ Future> ReadFrameAsync() {
+ RETURN_NOT_OK(CheckClosed());
+
+ std::unique_lock guard(frame_mutex_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) return status_;
+
+ const uint32_t counter_value = next_counter_++;
+ auto it = frames_.find(counter_value);
+ if (it != frames_.end()) {
+ Future> fut = it->second;
+ frames_.erase(it);
+ return fut;
+ }
+ auto pair = frames_.insert({counter_value, Future>::Make()});
+ DCHECK(pair.second);
+ return pair.first->second;
+ }
+
+ Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) {
+ static uint8_t kZeroes[1] = {0};
+
+ RETURN_NOT_OK(CheckClosed());
+
+ void* request = nullptr;
+ ucp_request_param_t request_param;
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ // Send frame header
+ FrameHeader header;
+ RETURN_NOT_OK(header.Set(frame_type, counter_++, size));
+ if (size == 0) {
+ // UCX appears to crash on zero-byte payloads
+ request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(),
+ kZeroes,
+ /*size=*/1, &request_param);
+ } else {
+ request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(),
+ data, size, &request_param);
+ }
+ RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request));
+
+ return Status::OK();
+ }
+
+ Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t request_param;
+ std::memset(&request_param, 0, sizeof(request_param));
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+ request_param.cb.send = AmSendCallback;
+ request_param.datatype = ucp_dt_make_contig(1);
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ const int64_t size = buffer->size();
+ if (size == 0) {
+ // UCX appears to crash on zero-byte payloads
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(1, write_memory_pool_));
+ }
+
+ std::unique_ptr pending_send(new PendingContigSend());
+ RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size));
+ pending_send->ipc_message = std::move(buffer);
+ pending_send->driver = this;
+ pending_send->completed = Future<>::Make();
+ pending_send->memh_p = nullptr;
+
+ request_param.user_data = pending_send.release();
+ {
+ auto* pending_send = reinterpret_cast(request_param.user_data);
+
+ void* request = ucp_am_send_nbx(
+ endpoint_, kUcpAmHandlerId, pending_send->header.data(),
+ pending_send->header.size(),
+ reinterpret_cast(pending_send->ipc_message->mutable_data()),
+ static_cast(pending_send->ipc_message->size()), &request_param);
+ if (!request) {
+ // Request completed immediately
+ delete pending_send;
+ return Status::OK();
+ } else if (UCS_PTR_IS_ERR(request)) {
+ delete pending_send;
+ return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+ }
+ return pending_send->completed;
+ }
+ }
+
+ Future<> SendFlightPayload(const FlightPayload& payload) {
+ static const int64_t kMaxBatchSize = std::numeric_limits::max();
+ RETURN_NOT_OK(CheckClosed());
+
+ if (payload.ipc_message.body_length > kMaxBatchSize) {
+ return Status::Invalid("Cannot send record batches exceeding 2GiB yet");
+ }
+
+ {
+ ARROW_ASSIGN_OR_RAISE(auto frame,
+ PayloadHeaderFrame::Make(payload, write_memory_pool_));
+ RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size()));
+ }
+
+ if (!ipc::Message::HasBody(payload.ipc_message.type)) {
+ return Status::OK();
+ }
+
+ // While IOV (scatter-gather) might seem like it avoids a memcpy,
+ // profiling shows that at least for the TCP/SHM/RDMA transports,
+ // UCX just does a memcpy internally. Furthermore, on the receiver
+ // side, a sender-side IOV send prevents optimizations based on
+ // mapped buffers (UCX will memcpy to the destination buffer
+ // regardless of whether it's mapped or not).
+
+ // If all buffers are on the CPU, concatenate them ourselves and
+ // do a regular send to avoid this. Else, use IOV and let UCX
+ // figure out what to do.
+
+ // Weirdness: UCX prefers TCP over shared memory for CONTIG? We
+ // can avoid this by setting UCX_RNDV_THRESH=inf, this will make
+ // UCX prefer shared memory again. However, we still want to avoid
+ // the CONTIG path when shared memory is available, because the
+ // total amount of time spent in memcpy is greater than using IOV
+ // and letting UCX handle it.
+
+ // Consider: if we can figure out how to make IOV always as fast
+ // as CONTIG, we can just send the metadata fields as part of the
+ // IOV payload and avoid having to send two distinct messages.
+
+ bool all_cpu = true;
+ int32_t total_buffers = 0;
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+ all_cpu = all_cpu && buffer->is_cpu();
+ total_buffers++;
+
+ // Arrow IPC requires that we align buffers to 8 byte boundary
+ const auto remainder = static_cast(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) total_buffers++;
+ }
+
+ ucp_request_param_t request_param;
+ std::memset(&request_param, 0, sizeof(request_param));
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+ request_param.cb.send = AmSendCallback;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ std::unique_ptr pending_send;
+ void* send_data = nullptr;
+ size_t send_size = 0;
+
+ if (!all_cpu) {
+ request_param.op_attr_mask =
+ request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ // XXX: UCX doesn't appear to autodetect this correctly if we
+ // use UNKNOWN
+ request_param.memory_type = UCS_MEMORY_TYPE_CUDA;
+ }
+
+ if (kEnableContigSend && all_cpu) {
+ // CONTIG - concatenate buffers into one before sending
+
+ // TODO(lidavidm): this needs to be pipelined since it can be expensive.
+ // Preliminary profiling shows ~5% overhead just from mapping the buffer
+ // alone (on Infiniband; it seems to be trivial for shared memory)
+ request_param.datatype = ucp_dt_make_contig(1);
+ pending_send = arrow::internal::make_unique();
+ auto* pending_contig = reinterpret_cast(pending_send.get());
+
+ ARROW_ASSIGN_OR_RAISE(
+ pending_contig->ipc_message,
+ AllocateBuffer(payload.ipc_message.body_length, write_memory_pool_));
+ TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message,
+ &pending_contig->memh_p);
+
+ uint8_t* ipc_message = pending_contig->ipc_message->mutable_data();
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ std::memcpy(ipc_message, buffer->data(), buffer->size());
+ ipc_message += buffer->size();
+
+ const auto remainder = static_cast(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ std::memset(ipc_message, 0, remainder);
+ ipc_message += remainder;
+ }
+ }
+
+ send_data = reinterpret_cast(pending_contig->ipc_message->mutable_data());
+ send_size = static_cast(pending_contig->ipc_message->size());
+ } else {
+ // IOV - let UCX use scatter-gather path
+ request_param.datatype = UCP_DATATYPE_IOV;
+ pending_send = arrow::internal::make_unique();
+ auto* pending_iov = reinterpret_cast(pending_send.get());
+
+ pending_iov->payload = payload;
+ pending_iov->iovs.resize(total_buffers);
+ ucp_dt_iov_t* iov = pending_iov->iovs.data();
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ // XXX: this seems to have no benefits in tests so far
+ pending_iov->memh_ps.resize(total_buffers);
+ ucp_mem_h* memh_p = pending_iov->memh_ps.data();
+#endif
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ iov->buffer = const_cast(reinterpret_cast(buffer->address()));
+ iov->length = buffer->size();
+ ++iov;
+
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryMapBuffer(worker_->context().get(), *buffer, memh_p);
+ memh_p++;
+#endif
+
+ const auto remainder = static_cast(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ iov->buffer =
+ const_cast(reinterpret_cast(padding_bytes_.data()));
+ iov->length = remainder;
+ ++iov;
+ }
+ }
+
+ send_data = pending_iov->iovs.data();
+ send_size = pending_iov->iovs.size();
+ }
+
+ RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++,
+ payload.ipc_message.body_length));
+ pending_send->driver = this;
+ pending_send->completed = Future<>::Make();
+
+ request_param.user_data = pending_send.release();
+ {
+ auto* pending_send = reinterpret_cast(request_param.user_data);
+
+ void* request = ucp_am_send_nbx(
+ endpoint_, kUcpAmHandlerId, pending_send->header.data(),
+ pending_send->header.size(), send_data, send_size, &request_param);
+ if (!request) {
+ // Request completed immediately
+ delete pending_send;
+ return Status::OK();
+ } else if (UCS_PTR_IS_ERR(request)) {
+ delete pending_send;
+ return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+ }
+ return pending_send->completed;
+ }
+ }
+
+ Status Close() {
+ if (!endpoint_) return Status::OK();
+
+ for (auto& item : frames_) {
+ item.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed"));
+ }
+ frames_.clear();
+
+ void* request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH);
+ ucs_status_t status = UCS_OK;
+ std::string origin = "ucp_ep_close_nb";
+ if (UCS_PTR_IS_ERR(request)) {
+ status = UCS_PTR_STATUS(request);
+ } else if (UCS_PTR_IS_PTR(request)) {
+ origin = "ucp_request_check_status";
+ while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) {
+ MakeProgress();
+ }
+ ucp_request_free(request);
+ } else {
+ DCHECK(!request);
+ }
+
+ endpoint_ = nullptr;
+ if (status != UCS_OK && status != UCS_ERR_ENDPOINT_TIMEOUT &&
+ status != UCS_ERR_NOT_CONNECTED) {
+ // Ignore timeout, not connected
+ return FromUcsStatus(origin, status);
+ }
+ return Status::OK();
+ }
+
+ void MakeProgress() { ucp_worker_progress(worker_->get()); }
+
+ void Push(std::shared_ptr frame) {
+ std::unique_lock guard(frame_mutex_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) return;
+ auto pair = frames_.insert({frame->counter, frame});
+ if (!pair.second) {
+ pair.first->second.MarkFinished(std::move(frame));
+ frames_.erase(pair.first);
+ }
+ }
+
+ void Push(Status status) {
+ std::unique_lock guard(frame_mutex_);
+ status_ = std::move(status);
+ for (auto& item : frames_) {
+ item.second.MarkFinished(status_);
+ }
+ frames_.clear();
+ }
+
+ ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ auto maybe_status =
+ RecvActiveMessageImpl(header, header_length, data, data_length, param);
+ if (!maybe_status.ok()) {
+ Push(maybe_status.status());
+ return UCS_OK;
+ }
+ return maybe_status.MoveValueUnsafe();
+ }
+
+ const std::shared_ptr& memory_manager() const { return memory_manager_; }
+ void set_memory_manager(std::shared_ptr memory_manager) {
+ if (memory_manager) {
+ memory_manager_ = std::move(memory_manager);
+ } else {
+ memory_manager_ = CPUDevice::Instance()->default_memory_manager();
+ }
+ }
+ void set_read_memory_pool(MemoryPool* pool) {
+ read_memory_pool_ = pool ? pool : default_memory_pool();
+ }
+ void set_write_memory_pool(MemoryPool* pool) {
+ write_memory_pool_ = pool ? pool : default_memory_pool();
+ }
+
+ private:
+ class PendingAmSend {
+ public:
+ virtual ~PendingAmSend() = default;
+ UcpCallDriver::Impl* driver;
+ Future<> completed;
+ FrameHeader header;
+ };
+
+ class PendingContigSend : public PendingAmSend {
+ public:
+ std::unique_ptr ipc_message;
+ ucp_mem_h memh_p;
+
+ virtual ~PendingContigSend() {
+ TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+ }
+ };
+
+ class PendingIovSend : public PendingAmSend {
+ public:
+ FlightPayload payload;
+ std::vector iovs;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ std::vector memh_ps;
+
+ virtual ~PendingIovSend() {
+ for (ucp_mem_h memh_p : memh_ps) {
+ TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+ }
+ }
+#endif
+ };
+
+ struct PendingAmRecv {
+ UcpCallDriver::Impl* driver;
+ std::shared_ptr frame;
+ ucp_mem_h memh_p;
+
+ PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr frame_)
+ : driver(driver_), frame(std::move(frame_)) {}
+
+ ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); }
+ };
+
+ static void AmSendCallback(void* request, ucs_status_t status, void* user_data) {
+ auto* pending_send = reinterpret_cast(user_data);
+ if (status == UCS_OK) {
+ pending_send->completed.MarkFinished();
+ } else {
+ pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status));
+ }
+ // TODO(lidavidm): delete should occur on a background thread if there's mapped
+ // buffers, since unmapping can be nontrivial and we don't want to block
+ // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.)
+ delete pending_send;
+ ucp_request_free(request);
+ }
+
+ static void AmRecvCallback(void* request, ucs_status_t status, size_t length,
+ void* user_data) {
+ auto* pending_recv = reinterpret_cast(user_data);
+ ucp_request_free(request);
+ if (status != UCS_OK) {
+ pending_recv->driver->Push(
+ FromUcsStatus("ucp_am_recv_data_nbx (callback)", status));
+ } else {
+ pending_recv->driver->Push(std::move(pending_recv->frame));
+ }
+ delete pending_recv;
+ }
+
+ arrow::Result RecvActiveMessageImpl(const void* header,
+ size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP);
+
+ if (data_length > static_cast(std::numeric_limits::max())) {
+ return Status::Invalid(
+ "Cannot allocate buffer greater than int64_t max, requested: ", data_length);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length));
+ if (data_length < frame->size) {
+ return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ",
+ data_length);
+ }
+
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) &&
+ (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu())) {
+ // Zero-copy path. UCX-allocated buffer must be freed later.
+
+ // XXX: this buffer can NOT be freed until AFTER we return from
+ // this handler. Otherwise, UCX won't have fully set up its
+ // internal data structures (allocated just before the buffer)
+ // and we'll crash when we free the buffer. Effectively: we can
+ // never use Then/AddCallback on a Future<> from ReadFrameAsync,
+ // because we might run the callback synchronously (which might
+ // free the buffer) when we call Push here.
+ frame->buffer =
+ arrow::internal::make_unique(worker_, data, data_length);
+ Push(std::move(frame));
+ return UCS_INPROGRESS;
+ }
+
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) ||
+ (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) {
+ // Rendezvous protocol (RNDV), or unpack to destination (DATA).
+
+ // We want to map/pin/register the buffer for faster transfer
+ // where possible. (It gets unmapped in ~PendingAmRecv.)
+ // TODO(lidavidm): This takes non-trivial time, so return
+ // UCS_INPROGRESS, kick off the allocation in the background,
+ // and recv the data later (is it allowed to call
+ // ucp_am_recv_data_nbx asynchronously?).
+ if (frame->type == FrameType::kPayloadBody) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ memory_manager_->AllocateBuffer(data_length));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ }
+
+ PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame));
+ TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer,
+ &pending_recv->memh_p);
+
+ ucp_request_param_t recv_param;
+ recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+ UCP_OP_ATTR_FIELD_MEMORY_TYPE |
+ UCP_OP_ATTR_FIELD_USER_DATA;
+ recv_param.cb.recv_am = AmRecvCallback;
+ recv_param.user_data = pending_recv;
+ recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer);
+
+ void* dest =
+ reinterpret_cast(pending_recv->frame->buffer->mutable_address());
+ void* request =
+ ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param);
+ if (UCS_PTR_IS_ERR(request)) {
+ delete pending_recv;
+ return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request));
+ } else if (!request) {
+ // Request completed instantly
+ Push(std::move(pending_recv->frame));
+ delete pending_recv;
+ }
+ return UCS_OK;
+ } else {
+ // Data will be freed after callback returns - copy to buffer
+ if (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu()) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ std::memcpy(frame->buffer->mutable_data(), data, data_length);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ frame->buffer,
+ MemoryManager::CopyNonOwned(Buffer(reinterpret_cast(data),
+ static_cast(data_length)),
+ memory_manager_));
+ }
+ Push(std::move(frame));
+ return UCS_OK;
+ }
+ }
+
+ Status CompleteRequestBlocking(const std::string& context, void* request) {
+ if (UCS_PTR_IS_ERR(request)) {
+ return FromUcsStatus(context, UCS_PTR_STATUS(request));
+ } else if (UCS_PTR_IS_PTR(request)) {
+ while (true) {
+ auto status = ucp_request_check_status(request);
+ if (status == UCS_OK) {
+ break;
+ } else if (status != UCS_INPROGRESS) {
+ ucp_request_release(request);
+ return FromUcsStatus("ucp_request_check_status", status);
+ }
+ MakeProgress();
+ }
+ ucp_request_free(request);
+ } else {
+ // Send was completed instantly
+ DCHECK(!request);
+ }
+ return Status::OK();
+ }
+
+ Status CheckClosed() {
+ if (!endpoint_) {
+ return Status::Invalid("UcpCallDriver is closed");
+ }
+ return Status::OK();
+ }
+
+ const std::array padding_bytes_;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ ucp_mem_h padding_memh_p_;
+#endif
+
+ std::shared_ptr worker_;
+ ucp_ep_h endpoint_;
+ MemoryPool* read_memory_pool_;
+ MemoryPool* write_memory_pool_;
+ std::shared_ptr memory_manager_;
+
+ // Internal name for logging/tracing
+ std::string name_;
+ // Counter used to reorder messages
+ uint32_t counter_ = 0;
+
+ std::mutex frame_mutex_;
+ Status status_;
+ std::unordered_map>> frames_;
+ uint32_t next_counter_ = 0;
+};
+
+UcpCallDriver::UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint)
+ : impl_(new Impl(std::move(worker), endpoint)) {}
+UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default;
+UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default;
+UcpCallDriver::~UcpCallDriver() = default;
+
+arrow::Result> UcpCallDriver::ReadNextFrame() {
+ return impl_->ReadNextFrame();
+}
+
+Future> UcpCallDriver::ReadFrameAsync() {
+ return impl_->ReadFrameAsync();
+}
+
+Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) {
+ if (frame.type != type) {
+ return Status::IOError("Expected frame type ", static_cast(type),
+ ", but got frame type ", static_cast(frame.type));
+ }
+ return Status::OK();
+}
+
+Status UcpCallDriver::StartCall(const std::string& method) {
+ std::vector> headers;
+ headers.emplace_back(kHeaderMethod, method);
+ ARROW_ASSIGN_OR_RAISE(auto frame, HeadersFrame::Make(headers));
+ auto buffer = std::move(frame).GetBuffer();
+ RETURN_NOT_OK(impl_->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size()));
+ return Status::OK();
+}
+
+Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) {
+ return impl_->SendFlightPayload(payload);
+}
+
+Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data,
+ const int64_t size) {
+ return impl_->SendFrame(frame_type, data, size);
+}
+
+Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type,
+ std::unique_ptr buffer) {
+ return impl_->SendFrameAsync(frame_type, std::move(buffer));
+}
+
+Status UcpCallDriver::Close() { return impl_->Close(); }
+
+void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); }
+
+ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length,
+ void* data, const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ return impl_->RecvActiveMessage(header, header_length, data, data_length, param);
+}
+
+const std::shared_ptr& UcpCallDriver::memory_manager() const {
+ return impl_->memory_manager();
+}
+
+void UcpCallDriver::set_memory_manager(std::shared_ptr memory_manager) {
+ impl_->set_memory_manager(std::move(memory_manager));
+}
+void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) {
+ impl_->set_read_memory_pool(pool);
+}
+void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) {
+ impl_->set_write_memory_pool(pool);
+}
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
new file mode 100644
index 00000000000..389ab2ea2a8
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
@@ -0,0 +1,352 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Common implementation of UCX communication primitives.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+//------------------------------------------------------------
+// Protocol Constants
+
+static constexpr char kMethodDoExchange[] = "DoExchange";
+static constexpr char kMethodDoGet[] = "DoGet";
+static constexpr char kMethodDoPut[] = "DoPut";
+static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo";
+
+static constexpr char kHeaderStatusCode[] = "flight-status-code";
+static constexpr char kHeaderStatusMessage[] = "flight-status-message";
+static constexpr char kHeaderStatusDetail[] = "flight-status-detail";
+static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code";
+
+//------------------------------------------------------------
+// UCX Helpers
+
+/// \brief A wrapper around a ucp_context_h.
+///
+/// Used so that multiple resources can share ownership of the
+/// context. UCX has zero-copy optimizations where an application can
+/// directly use a UCX buffer, but the lifetime of such buffers is
+/// tied to the UCX context and worker, so ownership needs to be
+/// preserved.
+class UcpContext final {
+ public:
+ UcpContext() : ucp_context_(nullptr) {}
+ explicit UcpContext(ucp_context_h context) : ucp_context_(context) {}
+ ~UcpContext() {
+ if (ucp_context_) ucp_cleanup(ucp_context_);
+ ucp_context_ = nullptr;
+ }
+ ucp_context_h get() const {
+ DCHECK(ucp_context_);
+ return ucp_context_;
+ }
+
+ private:
+ ucp_context_h ucp_context_;
+};
+
+/// \brief A wrapper around a ucp_worker_h.
+class UcpWorker final {
+ public:
+ UcpWorker() : ucp_worker_(nullptr) {}
+ UcpWorker(std::shared_ptr context, ucp_worker_h worker)
+ : ucp_context_(std::move(context)), ucp_worker_(worker) {}
+ ~UcpWorker() {
+ if (ucp_worker_) ucp_worker_destroy(ucp_worker_);
+ ucp_worker_ = nullptr;
+ }
+ ucp_worker_h get() const {
+ DCHECK(ucp_worker_);
+ return ucp_worker_;
+ }
+ const UcpContext& context() const { return *ucp_context_; }
+
+ private:
+ std::shared_ptr ucp_context_;
+ ucp_worker_h ucp_worker_;
+};
+
+//------------------------------------------------------------
+// Message Framing
+
+/// \brief The message type.
+enum class FrameType : uint8_t {
+ /// Key-value headers. Sent at the beginning (client->server) and
+ /// end (server->client) of a call. Also, for client-streaming calls
+ /// (e.g. DoPut), the client should send a headers frame to signal
+ /// end-of-stream.
+ kHeaders = 0,
+ /// Binary blob, does not contain Arrow data.
+ kBuffer,
+ /// Binary blob. Contains IPC metadata, app metadata.
+ kPayloadHeader,
+ /// Binary blob. Contains IPC body. Body is sent separately since it
+ /// may use a different memory type.
+ kPayloadBody,
+ /// Ask server to disconnect (to avoid client/server waiting on each
+ /// other and getting stuck).
+ kDisconnect,
+ /// Keep at end.
+ kMaxFrameType = kDisconnect,
+};
+
+/// \brief The header of a message frame. Used when sending only.
+///
+/// A frame is expected to be sent over UCP Active Messages and
+/// consists of a header (of kFrameHeaderBytes bytes) and a body.
+///
+/// The header is as follows:
+/// +-------+---------------------------------+
+/// | Bytes | Function |
+/// +=======+=================================+
+/// | 0 | Version tag (see kFrameVersion) |
+/// | 1 | Frame type (see FrameType) |
+/// | 2-3 | Unused, reserved |
+/// | 4-7 | Frame counter (big-endian) |
+/// | 8-11 | Body size (big-endian) |
+/// +-------+---------------------------------+
+///
+/// The frame counter lets the receiver ensure messages are processed
+/// in-order. (The message receive callback may use
+/// ucp_am_recv_data_nbx which is asynchronous.)
+///
+/// The body size reports the expected message size (UCX chokes on
+/// zero-size payloads which we occasionally want to send, so the size
+/// field in the header lets us know when a payload was meant to be
+/// empty).
+struct FrameHeader {
+ /// \brief The size of a frame header.
+ static constexpr size_t kFrameHeaderBytes = 12;
+ /// \brief The expected version tag in the header.
+ static constexpr uint8_t kFrameVersion = 0x01;
+
+ FrameHeader() = default;
+ /// \brief Initialize the frame header.
+ Status Set(FrameType frame_type, uint32_t counter, int64_t body_size);
+ void* data() const { return header.data(); }
+ size_t size() const { return kFrameHeaderBytes; }
+
+ // mutable since UCX expects void* not const void*
+ mutable std::array header = {0};
+};
+
+/// \brief A single message received via UCX. Used when receiving only.
+struct Frame {
+ /// \brief The message type.
+ FrameType type;
+ /// \brief The message length.
+ uint32_t size;
+ /// \brief An incrementing message counter (may wrap over).
+ uint32_t counter;
+ /// \brief The message contents.
+ std::unique_ptr buffer;
+
+ Frame() = default;
+ Frame(FrameType type_, uint32_t size_, uint32_t counter_,
+ std::unique_ptr buffer_)
+ : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {}
+
+ util::string_view view() const {
+ return util::string_view(reinterpret_cast(buffer->data()), size);
+ }
+
+ /// \brief Parse a UCX active message header. This will not
+ /// initialize the buffer field.
+ static arrow::Result> ParseHeader(const void* header,
+ size_t header_length);
+};
+
+/// \brief The active message handler callback ID.
+static constexpr uint32_t kUcpAmHandlerId = 0x1024;
+
+/// \brief A collection of key-value headers.
+///
+/// This should be stored in a frame of type kHeaders.
+///
+/// Format:
+/// +-------+----------------------------------+
+/// | Bytes | Contents |
+/// +=======+==================================+
+/// | 0-4 | # of headers (big-endian) |
+/// | 4-8 | Header key length (big-endian) |
+/// | 2-3 | Header value length (big-endian) |
+/// | (...) | Header key |
+/// | (...) | Header value |
+/// | (...) | (repeat from row 2) |
+/// +-------+----------------------------------+
+class HeadersFrame {
+ public:
+ /// \brief Get a header value (or an error if it was not found)
+ arrow::Result Get(const std::string& key);
+ /// \brief Extract the server-sent status.
+ Status GetStatus(Status* out);
+ /// \brief Parse the headers from the buffer.
+ static arrow::Result Parse(std::unique_ptr buffer);
+ /// \brief Create a new frame with the given headers.
+ static arrow::Result Make(
+ const std::vector>& headers);
+ /// \brief Create a new frame with the given headers and the given status.
+ static arrow::Result Make(
+ const Status& status,
+ const std::vector>& headers);
+
+ /// \brief Take ownership of the underlying buffer.
+ std::unique_ptr GetBuffer() && { return std::move(buffer_); }
+
+ private:
+ std::unique_ptr buffer_;
+ std::vector> headers_;
+};
+
+/// \brief A representation of a kPayloadHeader frame (i.e. all of the
+/// metadata in a FlightPayload/FlightData).
+///
+/// Data messages are sent in two parts: one containing all metadata
+/// (the Flatbuffers header, FlightDescriptor, and app_metadata
+/// fields) and one containing the actual data. This was done to avoid
+/// having to concatenate these fields with the data itself (in the
+/// cases where we are not using IOV).
+///
+/// Format:
+/// +--------+----------------------------------+
+/// | Bytes | Contents |
+/// +========+==================================+
+/// | 0-4 | Descriptor length (big-endian) |
+/// | 4..a | Descriptor bytes |
+/// | a-a+4 | app_metadata length (big-endian) |
+/// | a+4..b | app_metadata bytes |
+/// | b-b+4 | ipc_metadata length (big-endian) |
+/// | b+4..c | ipc_metadata bytes |
+/// +--------+----------------------------------+
+///
+/// If a field is not present, its length is still there, but is set
+/// to UINT32_MAX.
+class PayloadHeaderFrame {
+ public:
+ explicit PayloadHeaderFrame(std::unique_ptr buffer)
+ : buffer_(std::move(buffer)) {}
+ /// \brief Unpack the internal buffer into a FlightData.
+ Status ToFlightData(internal::FlightData* data);
+ /// \brief Pack a payload into the internal buffer.
+ static arrow::Result Make(const FlightPayload& payload,
+ MemoryPool* memory_pool);
+ const uint8_t* data() const { return buffer_->data(); }
+ int64_t size() const { return buffer_->size(); }
+
+ private:
+ std::unique_ptr buffer_;
+};
+
+/// \brief Manage the state of a UCX connection.
+class UcpCallDriver {
+ public:
+ UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint);
+
+ UcpCallDriver(const UcpCallDriver&) = delete;
+ UcpCallDriver(UcpCallDriver&&);
+ void operator=(const UcpCallDriver&) = delete;
+ UcpCallDriver& operator=(UcpCallDriver&&);
+
+ ~UcpCallDriver();
+
+ /// \brief Start a call by sending a headers frame. Client side only.
+ ///
+ /// \param[in] method The RPC method.
+ Status StartCall(const std::string& method);
+
+ /// \brief Synchronously send a generic message with binary payload.
+ Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size);
+ /// \brief Asynchronously send a generic message with binary payload.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer);
+ /// \brief Asynchronously send a data message.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFlightPayload(const FlightPayload& payload);
+
+ /// \brief Synchronously read the next frame.
+ arrow::Result> ReadNextFrame();
+ /// \brief Asynchronously read the next frame.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future> ReadFrameAsync();
+
+ /// \brief Validate that the frame is of the given type.
+ Status ExpectFrameType(const Frame& frame, FrameType type);
+
+ /// \brief Disconnect the other side of the connection. Note, this
+ /// can cause deadlock.
+ Status Close();
+
+ /// \brief Synchronously make progress (to adapt async to sync APIs)
+ void MakeProgress();
+
+ /// \brief Get the associated memory manager.
+ const std::shared_ptr& memory_manager() const;
+ /// \brief Set the associated memory manager.
+ void set_memory_manager(std::shared_ptr memory_manager);
+ /// \brief Set memory pool for scratch space used during reading.
+ void set_read_memory_pool(MemoryPool* memory_pool);
+ /// \brief Set memory pool for scratch space used during writing.
+ void set_write_memory_pool(MemoryPool* memory_pool);
+
+ /// \brief Process an incoming active message. This will unblock the
+ /// corresponding call to ReadFrameAsync/ReadNextFrame.
+ ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param);
+
+ private:
+ class Impl;
+ std::unique_ptr impl_;
+};
+
+ARROW_FLIGHT_EXPORT
+std::unique_ptr MakeUcxClientImpl();
+
+ARROW_FLIGHT_EXPORT
+std::unique_ptr MakeUcxServerImpl(
+ FlightServerBase* base, std::shared_ptr memory_manager);
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
new file mode 100644
index 00000000000..b9612a19d0a
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -0,0 +1,660 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include